mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-12 07:18:27 -05:00
Compare commits
543 Commits
github-pag
...
20230223.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67ba634958 | ||
|
|
9b544491e0 | ||
|
|
9c5415b598 | ||
|
|
040dbc317f | ||
|
|
65775046d8 | ||
|
|
b18bc36127 | ||
|
|
f01c526efd | ||
|
|
16168ab6b3 | ||
|
|
4233218629 | ||
|
|
b63fb36dc0 | ||
|
|
4e92304b89 | ||
|
|
2ae047f1a8 | ||
|
|
6d2a485264 | ||
|
|
4f045db024 | ||
|
|
5b33597b6d | ||
|
|
962470f610 | ||
|
|
ba8c116380 | ||
|
|
ad7330eae4 | ||
|
|
cf126e4839 | ||
|
|
c96d25c3e2 | ||
|
|
006aa0dae2 | ||
|
|
5b204bee86 | ||
|
|
d98b2afbe9 | ||
|
|
681332ef32 | ||
|
|
c3a4fdcbfc | ||
|
|
aac5de5b02 | ||
|
|
13a255afad | ||
|
|
3bffda52f9 | ||
|
|
d4e62ce557 | ||
|
|
9738483b18 | ||
|
|
143492fe94 | ||
|
|
ecc5c662c4 | ||
|
|
d973ba191d | ||
|
|
0198b183a2 | ||
|
|
0d44a3527b | ||
|
|
2147b6a397 | ||
|
|
6b5b4ba27b | ||
|
|
67005bf57c | ||
|
|
0430c741c6 | ||
|
|
1ce02e365d | ||
|
|
eae862adc2 | ||
|
|
dffa89524a | ||
|
|
2af1102441 | ||
|
|
c4b472842a | ||
|
|
750a7d806f | ||
|
|
bc7333f1e5 | ||
|
|
55ae50f991 | ||
|
|
a590c331ef | ||
|
|
8c241b06cb | ||
|
|
9c072c8068 | ||
|
|
ebd8b5122a | ||
|
|
055e484a40 | ||
|
|
912c4a1d12 | ||
|
|
c203b65bf1 | ||
|
|
307f0334ee | ||
|
|
5167df08b9 | ||
|
|
dd2e482214 | ||
|
|
87fd13d8eb | ||
|
|
dd423bc6de | ||
|
|
899cb9cc1f | ||
|
|
0464c7e558 | ||
|
|
f64e1fb926 | ||
|
|
ef7d31293d | ||
|
|
6d54eb68dc | ||
|
|
30eb10c990 | ||
|
|
591bbcd058 | ||
|
|
99aa77d036 | ||
|
|
9c13f1e635 | ||
|
|
24af983cfb | ||
|
|
67842a7525 | ||
|
|
3159a6f3e1 | ||
|
|
b2f3c96835 | ||
|
|
6582475955 | ||
|
|
41ee65b377 | ||
|
|
83fe477066 | ||
|
|
4ca84ee4ee | ||
|
|
c28cc4c919 | ||
|
|
e9864cb3f7 | ||
|
|
83c69ecd49 | ||
|
|
3595b4aaff | ||
|
|
3a9cfe113a | ||
|
|
c9966127da | ||
|
|
51300d33a7 | ||
|
|
5af124c5a5 | ||
|
|
eeb20b531a | ||
|
|
9dca842c22 | ||
|
|
1eb9436836 | ||
|
|
9604d9ce81 | ||
|
|
481d0553d8 | ||
|
|
60035cd63a | ||
|
|
d35f992ace | ||
|
|
157ae64f9d | ||
|
|
ffa17f6057 | ||
|
|
d695a43e32 | ||
|
|
01f6b4e6f0 | ||
|
|
7cf31a6ae4 | ||
|
|
fbd6224b04 | ||
|
|
8115b26079 | ||
|
|
820586ac68 | ||
|
|
4a7441ed07 | ||
|
|
383741f284 | ||
|
|
2bbc4e0e9f | ||
|
|
a7237244b0 | ||
|
|
1d38d49162 | ||
|
|
a783c089a9 | ||
|
|
e7907dc532 | ||
|
|
394413679d | ||
|
|
37189f14cb | ||
|
|
0b1ee81901 | ||
|
|
00cf73f9b8 | ||
|
|
5a5f285493 | ||
|
|
7f2ea454b6 | ||
|
|
7c14002118 | ||
|
|
3e9554f0a1 | ||
|
|
e11ffec544 | ||
|
|
8a47ddbe99 | ||
|
|
821108c7bd | ||
|
|
339738f8a3 | ||
|
|
9b90672f63 | ||
|
|
ba07e94a5e | ||
|
|
b3fc0f29cc | ||
|
|
5c7deb3611 | ||
|
|
15604e374f | ||
|
|
7cfc0fa55b | ||
|
|
a90812133b | ||
|
|
e26a70aa4f | ||
|
|
6a32a4e26c | ||
|
|
e853abf98b | ||
|
|
51e81e6ef8 | ||
|
|
e355000ceb | ||
|
|
e374074013 | ||
|
|
81e3d1c2c6 | ||
|
|
ab0cbb4475 | ||
|
|
1c64e40722 | ||
|
|
8cafe56eb4 | ||
|
|
3eceeb7b23 | ||
|
|
1a37675435 | ||
|
|
198ebede8d | ||
|
|
a504903dd5 | ||
|
|
842adef29c | ||
|
|
7edcaf5a06 | ||
|
|
c124b76328 | ||
|
|
e9c744ee5d | ||
|
|
83302930d8 | ||
|
|
a4634632ba | ||
|
|
d17e8dc5ad | ||
|
|
9fe63de4d4 | ||
|
|
8111f8bf35 | ||
|
|
fcd62513cf | ||
|
|
c3c701e654 | ||
|
|
6bf991edf6 | ||
|
|
9644e78545 | ||
|
|
c911189ef0 | ||
|
|
1118b4b651 | ||
|
|
4be75d4418 | ||
|
|
fb6beae27c | ||
|
|
fee73b0b63 | ||
|
|
9bbffa519e | ||
|
|
c3a641f0ab | ||
|
|
aafe7c4701 | ||
|
|
9a0b082cf8 | ||
|
|
8265e34a29 | ||
|
|
8ef8ae097f | ||
|
|
c3d14293c0 | ||
|
|
d55d8be504 | ||
|
|
03543030d3 | ||
|
|
fc6b474b92 | ||
|
|
a5db785dd7 | ||
|
|
1c1c5cd611 | ||
|
|
6ed02f70ec | ||
|
|
cb78cd8ac0 | ||
|
|
0c4590b45a | ||
|
|
d2e2ee6efa | ||
|
|
6a380a0b48 | ||
|
|
e5d5acbf1f | ||
|
|
00e38abbf0 | ||
|
|
e3e4ea5443 | ||
|
|
a3e4ea3228 | ||
|
|
56f16d6baf | ||
|
|
7a55ab900e | ||
|
|
137643fe72 | ||
|
|
d6e59c6241 | ||
|
|
458eb5d34c | ||
|
|
8259f08864 | ||
|
|
b3ab0a1843 | ||
|
|
f09f217478 | ||
|
|
e842c8c19b | ||
|
|
f6c3112d44 | ||
|
|
7059610632 | ||
|
|
2d272930d9 | ||
|
|
6c470d8131 | ||
|
|
30b29ce8cd | ||
|
|
1a9933002f | ||
|
|
c4a9365aa1 | ||
|
|
9d3af37104 | ||
|
|
7b3d57cff7 | ||
|
|
a802270da9 | ||
|
|
dd194a8758 | ||
|
|
6de02de221 | ||
|
|
85259750bf | ||
|
|
1249f0007d | ||
|
|
db0514d3fa | ||
|
|
dce42a7fad | ||
|
|
ec0b380194 | ||
|
|
7f27b61c98 | ||
|
|
f0b3557b02 | ||
|
|
2a1d1c1001 | ||
|
|
df7eb80e5b | ||
|
|
b9d947ce6f | ||
|
|
e6589d2454 | ||
|
|
0f5ac6afcf | ||
|
|
bc1bb1d188 | ||
|
|
3af2dd10ce | ||
|
|
dd22c65855 | ||
|
|
48137ced19 | ||
|
|
6eb47c12d1 | ||
|
|
5a1fc6675a | ||
|
|
6f80825814 | ||
|
|
f0dd48ed2a | ||
|
|
15e2df0db0 | ||
|
|
4ad0109769 | ||
|
|
ee0009d4b8 | ||
|
|
9d851c3346 | ||
|
|
5d117af8ae | ||
|
|
bb41c2d15e | ||
|
|
eba138ee4a | ||
|
|
3b2bbb74f8 | ||
|
|
dbc0f81211 | ||
|
|
d0b613d22e | ||
|
|
72f29b67d5 | ||
|
|
9570045cc3 | ||
|
|
e4efdb5cbb | ||
|
|
187f0fa70c | ||
|
|
472185c3e4 | ||
|
|
f94a571773 | ||
|
|
183e447d35 | ||
|
|
12f844d93a | ||
|
|
47a119a37f | ||
|
|
ee56559b9a | ||
|
|
00e594deea | ||
|
|
6ad9b213b9 | ||
|
|
e4375e8195 | ||
|
|
487bf8e29b | ||
|
|
fea1694e74 | ||
|
|
4102c124a9 | ||
|
|
135bad3280 | ||
|
|
b604f36881 | ||
|
|
782b449c71 | ||
|
|
017dcab685 | ||
|
|
e60b4568c6 | ||
|
|
4ee3d95a5a | ||
|
|
f18725bacc | ||
|
|
f6064a2b84 | ||
|
|
2e90cb7b95 | ||
|
|
2c09d63cd9 | ||
|
|
cc6fbdb0c3 | ||
|
|
ecfdec12f3 | ||
|
|
45af40fd14 | ||
|
|
d11cf42501 | ||
|
|
c3c1e3b055 | ||
|
|
7c5e3b1d99 | ||
|
|
ed6cec71e7 | ||
|
|
d6bcdd069c | ||
|
|
a26347826d | ||
|
|
5d1c099b31 | ||
|
|
220bee1365 | ||
|
|
1261074d95 | ||
|
|
136021424c | ||
|
|
fee4ba3746 | ||
|
|
a5b70335d4 | ||
|
|
5cf4976054 | ||
|
|
1aa3255061 | ||
|
|
b01f29f10d | ||
|
|
2673abca88 | ||
|
|
7eeb7f0715 | ||
|
|
37262a2479 | ||
|
|
de6e304959 | ||
|
|
234475bbc7 | ||
|
|
abbd9f7cfc | ||
|
|
dfd6ba67b3 | ||
|
|
1595254eab | ||
|
|
6964c5eeba | ||
|
|
2befe771b3 | ||
|
|
b133a035a4 | ||
|
|
726c062327 | ||
|
|
9083672de3 | ||
|
|
cdbaf880af | ||
|
|
9434981cdc | ||
|
|
8b3706f557 | ||
|
|
0d5173833d | ||
|
|
bf1178eb79 | ||
|
|
abcd3fa94a | ||
|
|
62aa1614b6 | ||
|
|
7027356126 | ||
|
|
5ebe13a13d | ||
|
|
c3bed9a2b7 | ||
|
|
f865222882 | ||
|
|
e2fe2e4095 | ||
|
|
0532a95f08 | ||
|
|
ff536f6015 | ||
|
|
097d0f27bb | ||
|
|
2257f87edf | ||
|
|
a17800da00 | ||
|
|
059c1b3a19 | ||
|
|
9a36816d27 | ||
|
|
7986b9b20b | ||
|
|
b2b3a0a62b | ||
|
|
3173b7d1d9 | ||
|
|
9d716d70d6 | ||
|
|
e1901a8608 | ||
|
|
7d0cbd8d90 | ||
|
|
59358361f9 | ||
|
|
7fea2d3b68 | ||
|
|
b6d3ff26bd | ||
|
|
523e63f5c1 | ||
|
|
10630ab597 | ||
|
|
2bc6de650d | ||
|
|
ffef1681e3 | ||
|
|
d935006a4a | ||
|
|
660cb5946e | ||
|
|
10160a066a | ||
|
|
72976a2ece | ||
|
|
831f206cd0 | ||
|
|
72648aa9f2 | ||
|
|
35e623deaf | ||
|
|
6263636738 | ||
|
|
535d012ded | ||
|
|
c73eed2e51 | ||
|
|
30fdc99f37 | ||
|
|
acb905f0cc | ||
|
|
bba06d0142 | ||
|
|
a14a47af12 | ||
|
|
73457336bc | ||
|
|
a14c53ad31 | ||
|
|
e7e763551a | ||
|
|
2928179331 | ||
|
|
24a16a4cfe | ||
|
|
6aed4423b2 | ||
|
|
6508e3fcc9 | ||
|
|
a15cb140ae | ||
|
|
898bc9e009 | ||
|
|
e67ea31ee2 | ||
|
|
986c126a5c | ||
|
|
0eee7616b9 | ||
|
|
5ddce749b8 | ||
|
|
d946cffabc | ||
|
|
fe618811ee | ||
|
|
09c45bfb80 | ||
|
|
e9e9ccd379 | ||
|
|
a9b27c78a3 | ||
|
|
bc17c29b2e | ||
|
|
aaf60bdee6 | ||
|
|
d913453e57 | ||
|
|
08e373aef4 | ||
|
|
4cb50a3d06 | ||
|
|
b03038222d | ||
|
|
5f5e0766dd | ||
|
|
48ec11c514 | ||
|
|
8ae76d18b5 | ||
|
|
e5be1790e5 | ||
|
|
e64aa40b17 | ||
|
|
eb8114ece8 | ||
|
|
616ee9b824 | ||
|
|
57c94f8f80 | ||
|
|
2a59c4f670 | ||
|
|
192ff487c4 | ||
|
|
b62ee3fcb9 | ||
|
|
0225292a44 | ||
|
|
589a7ed02f | ||
|
|
b3a42cd0b1 | ||
|
|
e3e1ca7cc6 | ||
|
|
57e417d174 | ||
|
|
1699db79b5 | ||
|
|
dab9403b8f | ||
|
|
9a14298146 | ||
|
|
40eea21863 | ||
|
|
d2475ec169 | ||
|
|
b3bcf4bf44 | ||
|
|
6049f86bc4 | ||
|
|
ff649b52ef | ||
|
|
e9e138c757 | ||
|
|
1096936a15 | ||
|
|
29cc478525 | ||
|
|
05e9eb40b5 | ||
|
|
c4444ff695 | ||
|
|
27b34f3929 | ||
|
|
2b8d784660 | ||
|
|
18f447d8d8 | ||
|
|
d7e1078d68 | ||
|
|
6be592653f | ||
|
|
8859853b41 | ||
|
|
3c46021102 | ||
|
|
bba8646669 | ||
|
|
b0dc19a910 | ||
|
|
df79ebd0f2 | ||
|
|
e19a97f316 | ||
|
|
482ffd6275 | ||
|
|
5117e50602 | ||
|
|
83b138208d | ||
|
|
1870cb4557 | ||
|
|
42ad5b9c5c | ||
|
|
333975eb8f | ||
|
|
aa0195e4ef | ||
|
|
56109fe09b | ||
|
|
e74046478b | ||
|
|
aa5a60812f | ||
|
|
ebb60019aa | ||
|
|
6393dc5d14 | ||
|
|
8c158f2452 | ||
|
|
8c3eabdcee | ||
|
|
8aa0ce6a24 | ||
|
|
a27ee141b3 | ||
|
|
1106456651 | ||
|
|
8856878cbd | ||
|
|
a9bac0287d | ||
|
|
efbd3dc778 | ||
|
|
a0d0eaa408 | ||
|
|
e2bf734b67 | ||
|
|
a333a90441 | ||
|
|
6dc0057d3d | ||
|
|
0f9e69d48c | ||
|
|
e6a7c019ab | ||
|
|
1d32eabd14 | ||
|
|
53d03f06a6 | ||
|
|
a2d8c40455 | ||
|
|
4f7d950c8d | ||
|
|
cac54b8c26 | ||
|
|
cd0e881d7d | ||
|
|
fee406e220 | ||
|
|
128342f47f | ||
|
|
024487c5fe | ||
|
|
879ba27ccb | ||
|
|
6d6d9627e7 | ||
|
|
af4bc82543 | ||
|
|
439a18bcc3 | ||
|
|
e12a1e0444 | ||
|
|
4400b0d3c3 | ||
|
|
5dff28ff99 | ||
|
|
d5ac841a1a | ||
|
|
232ce12e9b | ||
|
|
9a8638a6d0 | ||
|
|
a5445866b8 | ||
|
|
e8ded71a7b | ||
|
|
a14c615def | ||
|
|
3903b6ff0c | ||
|
|
41bf262482 | ||
|
|
645b658da0 | ||
|
|
6ee8f61fbe | ||
|
|
3c4c4231ce | ||
|
|
d0eef19eba | ||
|
|
6ca2eb3ad7 | ||
|
|
74aeb55733 | ||
|
|
3eb7965ca0 | ||
|
|
04f20070d1 | ||
|
|
88937fcb2f | ||
|
|
f80b85f10c | ||
|
|
32a2ec432d | ||
|
|
f4821d0d39 | ||
|
|
fdf2aa54ef | ||
|
|
275c032264 | ||
|
|
d88979fe19 | ||
|
|
e67bcffea7 | ||
|
|
005ded3c6f | ||
|
|
d624940e12 | ||
|
|
7763403b0e | ||
|
|
88c58244b9 | ||
|
|
0754c6ea20 | ||
|
|
7b1f04d121 | ||
|
|
d8a9bee244 | ||
|
|
ac0ea6bd3c | ||
|
|
45677c1e23 | ||
|
|
d9f4a9954a | ||
|
|
ec461a4456 | ||
|
|
559928e93b | ||
|
|
a526f7d5b8 | ||
|
|
749a2c2dec | ||
|
|
29a317dbb6 | ||
|
|
2f36de319a | ||
|
|
2005bce419 | ||
|
|
8a02d7729d | ||
|
|
1cdf301c14 | ||
|
|
9a86e5c476 | ||
|
|
32d3f4bd5f | ||
|
|
18689afc1a | ||
|
|
64d6da75c7 | ||
|
|
1e95e4b502 | ||
|
|
c63009a6db | ||
|
|
88f8718635 | ||
|
|
a081733a42 | ||
|
|
06ccfb0533 | ||
|
|
b18d75e3f7 | ||
|
|
3e7efaa048 | ||
|
|
a3fdfc81db | ||
|
|
f4c91df1df | ||
|
|
32e1ba8c0d | ||
|
|
1939376d72 | ||
|
|
25931d48a3 | ||
|
|
024c5e153a | ||
|
|
83f34b645d | ||
|
|
3f9f450e0d | ||
|
|
fd89b06641 | ||
|
|
f8dc996004 | ||
|
|
e6a964088b | ||
|
|
e3e767c7eb | ||
|
|
239c19eb12 | ||
|
|
7f37599a60 | ||
|
|
77c9a2c5ea | ||
|
|
fd7baae548 | ||
|
|
01fdf5ee16 | ||
|
|
e52f533c16 | ||
|
|
fbd77dc936 | ||
|
|
cdc6dd19e3 | ||
|
|
fd578a48a9 | ||
|
|
9956099516 | ||
|
|
f97b8fffed | ||
|
|
7b9e309724 | ||
|
|
1d33913d48 | ||
|
|
a48eaaed20 | ||
|
|
2741b8be53 | ||
|
|
4f906a265c | ||
|
|
0dff8d7af0 | ||
|
|
4f0d0d8167 | ||
|
|
d513060b21 | ||
|
|
d1a25ce4f3 | ||
|
|
51c98695b2 | ||
|
|
b448770ec2 | ||
|
|
5fe22a7980 | ||
|
|
38ae6b5af4 | ||
|
|
0bfe30d75d | ||
|
|
7be1d7d0be | ||
|
|
0d74c873f0 | ||
|
|
139aff2938 | ||
|
|
a3f733490c | ||
|
|
8a11f138d1 | ||
|
|
3405607917 | ||
|
|
7c99a6bd33 | ||
|
|
3fba8ce0e6 | ||
|
|
f3bde3c7fc | ||
|
|
21fee8ef33 | ||
|
|
0e217d6180 | ||
|
|
00a8ce75d1 | ||
|
|
8f3f00cd99 | ||
|
|
13bae2538a |
2
.github/workflows/gh-pages-releases.yml
vendored
2
.github/workflows/gh-pages-releases.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- run: git fetch --all
|
||||
- run: git switch github-pages
|
||||
- run: git config --global user.email "none@none.com"
|
||||
- run: git config --global user.name "nod-team"
|
||||
- run: git config --global user.name "nod-ai"
|
||||
- run: mv /tmp/index.html package-index/index.html
|
||||
- run: git add package-index/index.html
|
||||
|
||||
|
||||
148
.github/workflows/nightly.yml
vendored
148
.github/workflows/nightly.yml
vendored
@@ -9,13 +9,92 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
windows-build:
|
||||
runs-on: 7950X
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Compute version
|
||||
shell: powershell
|
||||
run: |
|
||||
$package_version = $(Get-Date -UFormat "%Y%m%d")+"."+${{ github.run_number }}
|
||||
$package_version_ = $(Get-Date -UFormat "%Y%m%d")+"_"+${{ github.run_number }}
|
||||
$tag_name=$package_version
|
||||
echo "package_version=$package_version" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
echo "package_version_=$package_version_" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
echo "tag_name=$tag_name" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ env.tag_name }}
|
||||
release_name: nod.ai SHARK ${{ env.tag_name }}
|
||||
body: |
|
||||
Automatic snapshot release of nod.ai SHARK.
|
||||
draft: true
|
||||
prerelease: false
|
||||
|
||||
- name: Build Package
|
||||
shell: powershell
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
|
||||
python process_skipfiles.py
|
||||
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
|
||||
|
||||
# GHA windows VM OOMs so disable for now
|
||||
#- name: Build and validate the SHARK Runtime package
|
||||
# shell: powershell
|
||||
# run: |
|
||||
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
- uses: actions/upload-artifact@v2
|
||||
with:
|
||||
path: dist/*
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ./dist/*
|
||||
|
||||
- name: Publish Release
|
||||
id: publish_release
|
||||
uses: eregon/publish-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
|
||||
linux-build:
|
||||
|
||||
runs-on: a100
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
python-version: ["3.11"]
|
||||
backend: [IREE, SHARK]
|
||||
|
||||
steps:
|
||||
@@ -32,40 +111,13 @@ jobs:
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Compute version
|
||||
run: |
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
tag_name="${package_version}"
|
||||
echo "package_version=${package_version}" >> $GITHUB_ENV
|
||||
echo "tag_name=${tag_name}" >> $GITHUB_ENV
|
||||
- name: Set Environment Variables
|
||||
run: |
|
||||
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ env.tag_name }}
|
||||
release_name: nod.ai SHARK ${{ env.tag_name }}
|
||||
body: |
|
||||
Automatic snapshot release of nod.ai SHARK.
|
||||
draft: true
|
||||
prerelease: false
|
||||
- name: Find Torch-MLIR Release
|
||||
run: |
|
||||
TM_HTML_URL="$(python3 -c "import urllib.request, json, sys; u=json.loads(urllib.request.urlopen('https://api.github.com/repos/llvm/torch-mlir/releases/latest').read().decode()).get('html_url', False); print(u) if u else sys.exit(1);")"
|
||||
TM_RELEASE_DIR=${TM_HTML_URL/"tag"/"expanded_assets"}
|
||||
echo "TM_RELEASE_DIR=${TM_RELEASE_DIR}" >> $GITHUB_ENV
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
echo "Torch-MLIR Release DIR is ${{ env.TM_RELEASE_DIR }}"
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/nod-ai/SHARK-Runtime/releases; fi
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
@@ -74,25 +126,26 @@ jobs:
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
|
||||
- name: Build and validate the IREE package
|
||||
if: ${{ matrix.backend == 'IREE' }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
|
||||
source iree.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/iree-org/iree/releases
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://iree-org.github.io/iree/pip-release-links.html
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" tank/test_models.py |
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
|
||||
tail -n 1 |
|
||||
tee -a pytest_results.txt
|
||||
if !(grep -Fxq " failed" pytest_results.txt)
|
||||
then
|
||||
export SHA=$(git log -1 --format='%h')
|
||||
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/$SHA
|
||||
gsutil -m cp -r gs://shark_tank/$SHA/* gs://shark_tank/latest/
|
||||
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
|
||||
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
|
||||
fi
|
||||
rm -rf ./wheelhouse/nodai*
|
||||
|
||||
@@ -104,29 +157,10 @@ jobs:
|
||||
source shark.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/nod-ai/SHARK-Runtime/releases
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" tank/test_models.py |
|
||||
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
|
||||
tail -n 1 |
|
||||
tee -a pytest_results.txt
|
||||
|
||||
- name: Upload Release Assets
|
||||
if: ${{ matrix.backend == 'SHARK' }}
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ${GITHUB_WORKSPACE}/wheelhouse/nodai_*.whl
|
||||
|
||||
- name: Publish Release
|
||||
if: ${{ matrix.backend == 'SHARK' }}
|
||||
id: publish_release
|
||||
uses: eregon/publish-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
|
||||
70
.github/workflows/test-models.yml
vendored
70
.github/workflows/test-models.yml
vendored
@@ -6,18 +6,32 @@ name: Validate Models on Shark Runtime
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'shark/examples/**'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'shark/examples/**'
|
||||
workflow_dispatch:
|
||||
|
||||
# Ensure that only a single job or workflow using the same
|
||||
# concurrency group will run at a time. This would cancel
|
||||
# any in-progress jobs in the same github workflow and github
|
||||
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-validate:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
os: [icelake, a100, MacStudio, ubuntu-latest]
|
||||
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
|
||||
suite: [cpu,cuda,vulkan]
|
||||
python-version: ["3.10"]
|
||||
python-version: ["3.11"]
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
suite: lint
|
||||
@@ -32,21 +46,25 @@ jobs:
|
||||
suite: cuda
|
||||
- os: MacStudio
|
||||
suite: cpu
|
||||
- os: MacStudio
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
suite: cuda
|
||||
- os: a100
|
||||
suite: cpu
|
||||
- os: 7950x
|
||||
suite: cpu
|
||||
- os: 7950x
|
||||
suite: cuda
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
if: matrix.os != '7950x'
|
||||
|
||||
- name: Set Environment Variables
|
||||
if: matrix.os != '7950x'
|
||||
run: |
|
||||
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
@@ -66,6 +84,9 @@ jobs:
|
||||
#cache-dependency-path: |
|
||||
# **/requirements-importer.txt
|
||||
# **/requirements.txt
|
||||
|
||||
- uses: actions/checkout@v2
|
||||
if: matrix.os == '7950x'
|
||||
|
||||
- name: Install dependencies
|
||||
if: matrix.suite == 'lint'
|
||||
@@ -88,9 +109,9 @@ jobs:
|
||||
if: matrix.suite == 'cpu'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cpu
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cpu
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
|
||||
|
||||
@@ -100,14 +121,41 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cuda
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cuda
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
|
||||
# Disabled due to black image bug
|
||||
# python build_tools/stable_diffusion_testing.py --device=cuda
|
||||
|
||||
- name: Validate Vulkan Models
|
||||
if: matrix.suite == 'vulkan'
|
||||
- name: Validate Vulkan Models (MacOS)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k vulkan
|
||||
export DYLD_LIBRARY_PATH=/usr/local/lib/
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan --update_tank
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k vulkan
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
|
||||
- name: Validate Vulkan Models (Windows)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
pytest --benchmark -k vulkan -s
|
||||
type bench_results.csv
|
||||
|
||||
- name: Validate Stable Diffusion Models (Windows)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
|
||||
17
.gitignore
vendored
17
.gitignore
vendored
@@ -31,7 +31,6 @@ MANIFEST
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
@@ -160,10 +159,26 @@ cython_debug/
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# vscode related
|
||||
.vscode
|
||||
|
||||
# Shark related artefacts
|
||||
*venv/
|
||||
shark_tmp/
|
||||
*.vmfb
|
||||
.use-iree
|
||||
tank/dict_configs.py
|
||||
|
||||
# ORT related artefacts
|
||||
cache_models/
|
||||
onnx_models/
|
||||
|
||||
# Generated images
|
||||
generated_imgs/
|
||||
|
||||
# Custom model related artefacts
|
||||
variants.json
|
||||
models/
|
||||
|
||||
# models folder
|
||||
apps/stable_diffusion/web/models/
|
||||
|
||||
466
README.md
466
README.md
@@ -1,29 +1,161 @@
|
||||
# SHARK
|
||||
|
||||
High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerators and Heterogeneous Clusters
|
||||
High Performance Machine Learning Distribution
|
||||
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
|
||||
|
||||
## Communication Channels
|
||||
|
||||
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary>Installation (Linux and macOS)</summary>
|
||||
<summary>Prerequisites - Drivers </summary>
|
||||
|
||||
#### Install your Windows hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
|
||||
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
#### Linux Drivers
|
||||
* MESA / RADV drivers wont work with FP16. Please use the latest AMGPU-PRO drivers (non-pro OSS drivers also wont work) or the latest NVidia Linux Drivers.
|
||||
|
||||
Other users please ensure you have your latest vendor drivers and Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home) and if you are using vulkan check `vulkaninfo` works in a terminal window
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
|
||||
|
||||
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
|
||||
|
||||
Download the stable release [539](https://github.com/nod-ai/SHARK/releases/download/20230216.539/shark_sd_20230216_539.exe) or if you are adventurous the latest .exe from [releases page](https://github.com/nod-ai/SHARK/releases).
|
||||
|
||||
Double click the .exe and you should have the [UI](http://localhost:8080/) in the browser.
|
||||
|
||||
If you have custom models put them in a `models/` directory where the .exe is.
|
||||
|
||||
Enjoy.
|
||||
|
||||
<details>
|
||||
<summary>More installation notes</summary>
|
||||
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
|
||||
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
|
||||
|
||||
## Running
|
||||
|
||||
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE)
|
||||
* The first run may take few minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
|
||||
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
|
||||
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.
|
||||
|
||||
## Stopping
|
||||
|
||||
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Advanced Installation (Only for developers)</summary>
|
||||
|
||||
## Advanced Installation (Windows, Linux and macOS) for developers
|
||||
|
||||
## Check out the code
|
||||
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
cd SHARK
|
||||
```
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)
|
||||
|
||||
* Install Git for Windows from [here](https://git-scm.com/download/win)
|
||||
|
||||
#### Allow the install script to run in Powershell
|
||||
```powershell
|
||||
set-executionpolicy remotesigned
|
||||
```
|
||||
|
||||
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
|
||||
```powershell
|
||||
./setup_venv.ps1 #You can re-run this script to get the latest version
|
||||
```
|
||||
|
||||
### Linux / macOS Users
|
||||
|
||||
```shell
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
```
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - WebUI
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
|
||||
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
|
||||
```
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
(shark.venv) > cd apps/stable_diffusion/web
|
||||
(shark.venv) > python index.py
|
||||
```
|
||||
|
||||
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
|
||||
|
||||
|
||||
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
|
||||
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - Commandline
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
python3.11 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
```
|
||||
|
||||
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
|
||||
</details>
|
||||
|
||||
The output on a AMD 7900XTX would look something like:
|
||||
|
||||
```shell
|
||||
Average step time: 47.19188690185547ms/it
|
||||
Clip Inference time (ms) = 109.531
|
||||
VAE Inference time (ms): 78.590
|
||||
|
||||
Total image generation time: 2.5788655281066895sec
|
||||
```
|
||||
|
||||
Here are some samples generated:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
|
||||
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Binary Installation</summary>
|
||||
|
||||
### Setup a new pip Virtual Environment
|
||||
|
||||
This step sets up a new VirtualEnv for Python
|
||||
|
||||
```shell
|
||||
python --version #Check you have 3.7->3.10 on Linux or 3.10 on macOS
|
||||
python --version #Check you have 3.11 on Linux, macOS or Windows Powershell
|
||||
python -m venv shark_venv
|
||||
source shark_venv/bin/activate
|
||||
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
|
||||
|
||||
# If you are using conda create and activate a new conda env
|
||||
|
||||
@@ -35,12 +167,17 @@ python -m pip install --upgrade pip
|
||||
|
||||
### Install SHARK
|
||||
|
||||
This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10
|
||||
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
|
||||
|
||||
```shell
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://github.com/nod-ai/shark-runtime/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
```
|
||||
If you are on an Intel macOS machine you need this [workaround](https://github.com/nod-ai/SHARK/issues/102) for an upstream issue.
|
||||
|
||||
### Run shark tank model tests.
|
||||
```shell
|
||||
pytest tank/test_models.py
|
||||
```
|
||||
See tank/README.md for a more detailed walkthrough of our pytest suite and CLI.
|
||||
|
||||
### Download and run Resnet50 sample
|
||||
|
||||
@@ -61,33 +198,31 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Source Installation</summary>
|
||||
<summary>Development, Testing and Benchmarks</summary>
|
||||
|
||||
## Check out the code
|
||||
If you want to use Python3.11 and with TF Import tools you can use the environment variables like:
|
||||
Set `USE_IREE=1` to use upstream IREE
|
||||
```
|
||||
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
|
||||
```
|
||||
|
||||
### Run any of the hundreds of SHARK tank models via the test framework
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
# Or a pytest
|
||||
pytest tank/test_models.py -k "MiniLM"
|
||||
```
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
```shell
|
||||
# Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
```
|
||||
For example if you want to use Python3.10 and upstream IREE with TF Import tools you can use the environment variables like:
|
||||
```
|
||||
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 USE_IREE=1 ./setup_venv.sh
|
||||
```
|
||||
|
||||
|
||||
### How to use your locally built IREE / Torch-MLIR with SHARK
|
||||
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
|
||||
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
|
||||
with Python bindings and set your PYTHONPATH as mentioned [here](https://google.github.io/iree/bindings/python/)
|
||||
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
|
||||
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
|
||||
for Torch-MLIR.
|
||||
|
||||
### How to use your locally built Torch-MLIR with SHARK
|
||||
How to use your locally built Torch-MLIR with SHARK:
|
||||
```shell
|
||||
1.) Run `./setup_venv.sh in SHARK` and activate `shark.venv` virtual env.
|
||||
2.) Run `pip uninstall torch-mlir`.
|
||||
@@ -102,82 +237,45 @@ for Torch-MLIR.
|
||||
```
|
||||
Now the SHARK will use your locally build Torch-MLIR repo.
|
||||
|
||||
### Run a demo script
|
||||
```shell
|
||||
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
# Or a pytest
|
||||
pytest tank/test_models.py -k "MiniLM"
|
||||
|
||||
## Benchmarking Dispatches
|
||||
|
||||
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
|
||||
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`
|
||||
|
||||
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
|
||||
```
|
||||
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
|
||||
```
|
||||
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
|
||||
|
||||
if you want to instead incorporate this into a python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `SharkInference`, and the benchmarks will be generated when compiled. E.G:
|
||||
|
||||
```
|
||||
shark_module = SharkInference(
|
||||
mlir_model,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
dispatch_benchmarks="all",
|
||||
dispatch_benchmarks_dir="results"
|
||||
)
|
||||
```
|
||||
|
||||
Output will include:
|
||||
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
|
||||
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
|
||||
- An .mlir file containing the dispatch benchmark
|
||||
- A compiled .vmfb file containing the dispatch benchmark
|
||||
- An .mlir file containing just the hal executable
|
||||
- A compiled .vmfb file of the hal executable
|
||||
- A .txt file containing benchmark output
|
||||
|
||||
|
||||
See tank/README.md for further instructions on how to run model tests and benchmarks from the SHARK tank.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Testing and Benchmarks</summary>
|
||||
|
||||
### Run all model tests on CPU/GPU/VULKAN/Metal
|
||||
```shell
|
||||
pytest tank/test_models.py
|
||||
|
||||
# If on Linux for multithreading on CPU (faster results):
|
||||
pytest tank/test_models.py -n auto
|
||||
```
|
||||
|
||||
### Running specific tests
|
||||
```shell
|
||||
|
||||
# Search for test cases by including a keyword that matches all or part of the test case's name;
|
||||
pytest tank/test_models.py -k "keyword"
|
||||
|
||||
# Test cases are named uniformly by format test_module_<model_name_underscores_only>_<torch/tf>_<static/dynamic>_<device>.
|
||||
|
||||
# Example: Test all models on nvidia gpu:
|
||||
pytest tank/test_models.py -k "cuda"
|
||||
|
||||
# Example: Test all tensorflow resnet models on Vulkan backend:
|
||||
pytest tank/test_models.py -k "resnet and tf and vulkan"
|
||||
|
||||
# Exclude a test case:
|
||||
pytest tank/test_models.py -k "not ..."
|
||||
|
||||
### Run benchmarks on SHARK tank pytests and generate bench_results.csv with results.
|
||||
|
||||
(the following requires source installation with `IMPORTER=1 ./setup_venv.sh`)
|
||||
|
||||
```shell
|
||||
pytest --benchmark tank/test_models.py
|
||||
|
||||
# Just do static GPU benchmarks for PyTorch tests:
|
||||
pytest --benchmark tank/test_models.py -k "pytorch and static and cuda"
|
||||
|
||||
```
|
||||
|
||||
### Benchmark Resnet50, MiniLM on CPU
|
||||
|
||||
(requires source installation with `IMPORTER=1 ./setup_venv.sh`)
|
||||
|
||||
```shell
|
||||
# We suggest running the following commands as root before running benchmarks on CPU:
|
||||
|
||||
cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list | awk -F, '{print $2}' | sort -n | uniq | ( while read X ; do echo $X ; echo 0 > /sys/devices/system/cpu/cpu$X/online ; done )
|
||||
echo 1 > /sys/devices/system/cpu/intel_pstate/no_turbo
|
||||
|
||||
# Benchmark canonical Resnet50 on CPU via pytest
|
||||
pytest --benchmark tank/test_models -k "resnet50 and tf_static_cpu"
|
||||
|
||||
# Benchmark canonical MiniLM on CPU via pytest
|
||||
pytest --benchmark tank/test_models -k "MiniLM and cpu"
|
||||
|
||||
# Benchmark MiniLM on CPU via transformer-benchmarks:
|
||||
git clone --recursive https://github.com/nod-ai/transformer-benchmarks.git
|
||||
cd transformer-benchmarks
|
||||
./perf-ci.sh -n
|
||||
# Check detail.csv for MLIR/IREE results.
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>API Reference</summary>
|
||||
|
||||
@@ -228,160 +326,26 @@ result = shark_module.forward((arg0, arg1))
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
## Supported and Validated Models
|
||||
|
||||
<details>
|
||||
<summary>PyTorch Models</summary>
|
||||
SHARK is maintained to support the latest innovations in ML Models:
|
||||
|
||||
### Huggingface PyTorch Models
|
||||
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| GPT2 | :green_heart: | :green_heart: | :green_heart: |
|
||||
| BLOOM | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Stable Diffusion | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Vision Transformer | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ResNet50 | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Albert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| BigBird | :green_heart: (AOT) | | | |
|
||||
| DistilBERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| GPT2 | :broken_heart: (AOT) | | | |
|
||||
| MobileBert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
For a complete list of the models supported in SHARK, please refer to [tank/README.md](https://github.com/nod-ai/SHARK/blob/main/tank/README.md).
|
||||
|
||||
### Torchvision Models
|
||||
## Communication Channels
|
||||
|
||||
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|--------------------|----------------------|----------|----------|-------------|
|
||||
| AlexNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DenseNet121 | :green_heart: (Script) | | | |
|
||||
| MNasNet1_0 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| MobileNetV2 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| MobileNetV3 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Unet | :broken_heart: (Script) | | | |
|
||||
| Resnet18 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnet50 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnet101 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnext50_32x4d | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ShuffleNet_v2 | :broken_heart: (Script) | | | |
|
||||
| SqueezeNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| EfficientNet | :green_heart: (Script) | | | |
|
||||
| Regnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnest | :broken_heart: (Script) | | | |
|
||||
| Vision Transformer | :green_heart: (Script) | | | |
|
||||
| VGG 16 | :green_heart: (Script) | :green_heart: | :green_heart: | |
|
||||
| Wide Resnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| RAFT | :broken_heart: (JIT) | | | |
|
||||
|
||||
For more information refer to [MODEL TRACKING SHEET](https://docs.google.com/spreadsheets/d/15PcjKeHZIrB5LfDyuw7DGEEE8XnQEX2aX8lm8qbxV8A/edit#gid=0)
|
||||
|
||||
### PyTorch Training Models
|
||||
|
||||
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>JAX Models</summary>
|
||||
|
||||
|
||||
### JAX Models
|
||||
|
||||
| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| DALL-E | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>TFLite Models</summary>
|
||||
|
||||
### TFLite Models
|
||||
|
||||
| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
| albert | :green_heart: | :green_heart: | | |
|
||||
| asr_conformer | :green_heart: | :green_heart: | | |
|
||||
| bird_classifier | :green_heart: | :green_heart: | | |
|
||||
| cartoon_gan | :green_heart: | :green_heart: | | |
|
||||
| craft_text | :green_heart: | :green_heart: | | |
|
||||
| deeplab_v3 | :green_heart: | :green_heart: | | |
|
||||
| densenet | :green_heart: | :green_heart: | | |
|
||||
| east_text_detector | :green_heart: | :green_heart: | | |
|
||||
| efficientnet_lite0_int8 | :green_heart: | :green_heart: | | |
|
||||
| efficientnet | :green_heart: | :green_heart: | | |
|
||||
| gpt2 | :green_heart: | :green_heart: | | |
|
||||
| image_stylization | :green_heart: | :green_heart: | | |
|
||||
| inception_v4 | :green_heart: | :green_heart: | | |
|
||||
| inception_v4_uint8 | :green_heart: | :green_heart: | | |
|
||||
| lightning_fp16 | :green_heart: | :green_heart: | | |
|
||||
| lightning_i8 | :green_heart: | :green_heart: | | |
|
||||
| lightning | :green_heart: | :green_heart: | | |
|
||||
| magenta | :green_heart: | :green_heart: | | |
|
||||
| midas | :green_heart: | :green_heart: | | |
|
||||
| mirnet | :green_heart: | :green_heart: | | |
|
||||
| mnasnet | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_edgetpu_s_float | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_edgetpu_s_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilebert | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_tf2_float | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_tf2_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_ssd_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v1 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v3-large | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v3-large_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v35-int8 | :green_heart: | :green_heart: | | |
|
||||
| nasnet | :green_heart: | :green_heart: | | |
|
||||
| person_detect | :green_heart: | :green_heart: | | |
|
||||
| posenet | :green_heart: | :green_heart: | | |
|
||||
| resnet_50_int8 | :green_heart: | :green_heart: | | |
|
||||
| rosetta | :green_heart: | :green_heart: | | |
|
||||
| spice | :green_heart: | :green_heart: | | |
|
||||
| squeezenet | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v1 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_fpnlite | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_fpnlite_uint8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2 | :green_heart: | :green_heart: | | |
|
||||
| ssd_spaghettinet_large | :green_heart: | :green_heart: | | |
|
||||
| ssd_spaghettinet_large_uint8 | :green_heart: | :green_heart: | | |
|
||||
| visual_wake_words_i8 | :green_heart: | :green_heart: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>TF Models</summary>
|
||||
|
||||
### Tensorflow Models (Inference)
|
||||
|
||||
| Hugging Face Models | tf-mhlo lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| albert-base-v2 | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| CamemBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ConvBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Deberta | | | | |
|
||||
| electra | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| funnel | | | | |
|
||||
| layoutlm | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| longformer | | | | |
|
||||
| mobile-bert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| remembert | | | | |
|
||||
| tapas | | | | |
|
||||
| flaubert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| xlm-roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| mpnet | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
</details>
|
||||
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
|
||||
|
||||
## Related Projects
|
||||
|
||||
|
||||
87
apps/stable_diffusion/profiling_with_iree.md
Normal file
87
apps/stable_diffusion/profiling_with_iree.md
Normal file
@@ -0,0 +1,87 @@
|
||||
Compile / Run Instructions:
|
||||
|
||||
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
|
||||
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
|
||||
|
||||
Compile Commands FP32/FP16:
|
||||
|
||||
```shell
|
||||
Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use –iree-input-type=mhlo for tf models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
CPU:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
```
|
||||
|
||||
|
||||
|
||||
Run / Benchmark Command (FP32 - NCHW):
|
||||
(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.)
|
||||
|
||||
```shell
|
||||
## Vulkan AMD:
|
||||
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
|
||||
## CUDA:
|
||||
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
|
||||
## CPU:
|
||||
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
|
||||
```
|
||||
|
||||
Run via vulkan_gui for RGP Profiling:
|
||||
|
||||
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
|
||||
```shell
|
||||
./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
```
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>Debug Commands</summary>
|
||||
|
||||
## Debug commands and other advanced usage follows.
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
|
||||
```
|
||||
|
||||
## dump all dispatch .spv and isa using amdllpc
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
|
||||
```
|
||||
|
||||
## Compile and save the .vmfb (using vulkan fp16 as an example):
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
|
||||
```
|
||||
|
||||
## Capture an RGP trace
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
|
||||
```
|
||||
|
||||
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
|
||||
|
||||
```shell
|
||||
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
|
||||
```
|
||||
|
||||
## Run the unet module with iree-benchmark-module (same config as above):
|
||||
```shell
|
||||
##if you want to use .npz inputs:
|
||||
unzip ~/.local/shark_tank/<your unet>/inputs.npz
|
||||
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy
|
||||
```
|
||||
|
||||
</details>
|
||||
2
apps/stable_diffusion/scripts/__init__.py
Normal file
2
apps/stable_diffusion/scripts/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from apps.stable_diffusion.scripts.txt2img import txt2img_inf
|
||||
from apps.stable_diffusion.scripts.img2img import img2img_inf
|
||||
264
apps/stable_diffusion/scripts/img2img.py
Normal file
264
apps/stable_diffusion/scripts/img2img.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Image2ImagePipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
|
||||
|
||||
img2img_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def img2img_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image: str,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
):
|
||||
global img2img_obj
|
||||
global config_obj
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.seed = seed
|
||||
args.steps = steps
|
||||
args.strength = strength
|
||||
args.scheduler = scheduler
|
||||
args.img_path = init_image
|
||||
image = Image.open(args.img_path).convert("RGB")
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
if image is None:
|
||||
return None, "An Initial Image is required"
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
)
|
||||
if not img2img_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = ""
|
||||
args.use_tuned = True
|
||||
args.import_mlir = True
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "runwayml/stable-diffusion-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
img2img_obj = Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
img2img_obj.scheduler = schedulers[scheduler]
|
||||
|
||||
start_time = time.time()
|
||||
img2img_obj.log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = img2img_obj.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
img2img_obj.log += "\n"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={steps}, strength={args.strength}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
text_output += img2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
|
||||
# When the models get uploaded, it should be default to False.
|
||||
args.import_mlir = True
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
if args.scheduler != "PNDM":
|
||||
if "Shark" in args.scheduler:
|
||||
print(
|
||||
f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
|
||||
)
|
||||
args.scheduler = "PNDM"
|
||||
else:
|
||||
sys.exit(
|
||||
"Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
|
||||
)
|
||||
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
image = Image.open(args.img_path).convert("RGB")
|
||||
seed = utils.sanitize_seed(args.seed)
|
||||
|
||||
# Adjust for height and width based on model
|
||||
|
||||
img2img_obj = Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = img2img_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
image,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.strength,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, strength={args.strength}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += img2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
253
apps/stable_diffusion/scripts/inpaint.py
Normal file
253
apps/stable_diffusion/scripts/inpaint.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
|
||||
|
||||
inpaint_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def inpaint_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
image: Image,
|
||||
mask_image: Image,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
):
|
||||
global inpaint_obj
|
||||
global config_obj
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
)
|
||||
if not inpaint_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = ""
|
||||
args.use_tuned = True
|
||||
args.import_mlir = False
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
inpaint_obj = InpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
)
|
||||
|
||||
inpaint_obj.scheduler = schedulers[scheduler]
|
||||
|
||||
start_time = time.time()
|
||||
inpaint_obj.log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = inpaint_obj.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
mask_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
inpaint_obj.log += "\n"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += inpaint_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
if args.mask_path is None:
|
||||
print("Flag --mask_path is required.")
|
||||
exit()
|
||||
if "inpaint" not in args.hf_model_id:
|
||||
print("Please use inpainting model with --hf_model_id.")
|
||||
exit()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
image = Image.open(args.img_path)
|
||||
mask_image = Image.open(args.mask_path)
|
||||
|
||||
inpaint_obj = InpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = inpaint_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
image,
|
||||
mask_image,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += inpaint_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
240
apps/stable_diffusion/scripts/telegram_bot.py
Normal file
240
apps/stable_diffusion/scripts/telegram_bot.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import logging
|
||||
import os
|
||||
from models.stable_diffusion.main import stable_diff_inf
|
||||
from models.stable_diffusion.utils import get_available_devices
|
||||
from dotenv import load_dotenv
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram import BotCommand
|
||||
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
|
||||
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
|
||||
from io import BytesIO
|
||||
import random
|
||||
|
||||
log = logging.getLogger("TG.Bot")
|
||||
logging.basicConfig()
|
||||
log.warning("Start")
|
||||
load_dotenv()
|
||||
os.environ["AMD_ENABLE_LLPC"] = "0"
|
||||
TG_TOKEN = os.getenv("TG_TOKEN")
|
||||
SELECTED_MODEL = "stablediffusion"
|
||||
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
|
||||
STEPS = 30
|
||||
NEGATIVE_PROMPT = (
|
||||
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
|
||||
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
|
||||
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
|
||||
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
|
||||
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
|
||||
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
|
||||
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
|
||||
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
|
||||
" legs, fused fingers, too many fingers"
|
||||
)
|
||||
GUIDANCE_SCALE = 6
|
||||
available_devices = get_available_devices()
|
||||
models_list = [
|
||||
"stablediffusion",
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]
|
||||
sheds_list = [
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
]
|
||||
|
||||
|
||||
def image_to_bytes(image):
|
||||
bio = BytesIO()
|
||||
bio.name = "image.jpeg"
|
||||
image.save(bio, "JPEG")
|
||||
bio.seek(0)
|
||||
return bio
|
||||
|
||||
|
||||
def get_try_again_markup():
|
||||
keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]]
|
||||
reply_markup = InlineKeyboardMarkup(keyboard)
|
||||
return reply_markup
|
||||
|
||||
|
||||
def generate_image(prompt):
|
||||
seed = random.randint(1, 10000)
|
||||
log.warning(SELECTED_MODEL)
|
||||
log.warning(STEPS)
|
||||
image, text = stable_diff_inf(
|
||||
prompt=prompt,
|
||||
negative_prompt=NEGATIVE_PROMPT,
|
||||
steps=STEPS,
|
||||
guidance_scale=GUIDANCE_SCALE,
|
||||
seed=seed,
|
||||
scheduler_key=SELECTED_SCHEDULER,
|
||||
variant=SELECTED_MODEL,
|
||||
device_key=available_devices[0],
|
||||
)
|
||||
|
||||
return image, seed
|
||||
|
||||
|
||||
async def generate_and_send_photo(
|
||||
update: Update, context: ContextTypes.DEFAULT_TYPE
|
||||
) -> None:
|
||||
progress_msg = await update.message.reply_text(
|
||||
"Generating image...", reply_to_message_id=update.message.message_id
|
||||
)
|
||||
im, seed = generate_image(prompt=update.message.text)
|
||||
await context.bot.delete_message(
|
||||
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
|
||||
)
|
||||
await context.bot.send_photo(
|
||||
update.effective_user.id,
|
||||
image_to_bytes(im),
|
||||
caption=f'"{update.message.text}" (Seed: {seed})',
|
||||
reply_markup=get_try_again_markup(),
|
||||
reply_to_message_id=update.message.message_id,
|
||||
)
|
||||
|
||||
|
||||
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
query = update.callback_query
|
||||
if query.data in models_list:
|
||||
global SELECTED_MODEL
|
||||
SELECTED_MODEL = query.data
|
||||
await query.answer()
|
||||
await query.edit_message_text(text=f"Selected model: {query.data}")
|
||||
return
|
||||
if query.data in sheds_list:
|
||||
global SELECTED_SCHEDULER
|
||||
SELECTED_SCHEDULER = query.data
|
||||
await query.answer()
|
||||
await query.edit_message_text(text=f"Selected scheduler: {query.data}")
|
||||
return
|
||||
replied_message = query.message.reply_to_message
|
||||
await query.answer()
|
||||
progress_msg = await query.message.reply_text(
|
||||
"Generating image...", reply_to_message_id=replied_message.message_id
|
||||
)
|
||||
|
||||
if query.data == "TRYAGAIN":
|
||||
prompt = replied_message.text
|
||||
im, seed = generate_image(prompt)
|
||||
|
||||
await context.bot.delete_message(
|
||||
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
|
||||
)
|
||||
await context.bot.send_photo(
|
||||
update.effective_user.id,
|
||||
image_to_bytes(im),
|
||||
caption=f'"{prompt}" (Seed: {seed})',
|
||||
reply_markup=get_try_again_markup(),
|
||||
reply_to_message_id=replied_message.message_id,
|
||||
)
|
||||
|
||||
|
||||
async def select_model_handler(update, context):
|
||||
text = "Select model"
|
||||
keyboard = []
|
||||
for model in models_list:
|
||||
keyboard.append(
|
||||
[
|
||||
InlineKeyboardButton(text=model, callback_data=model),
|
||||
]
|
||||
)
|
||||
markup = InlineKeyboardMarkup(keyboard)
|
||||
await update.message.reply_text(text=text, reply_markup=markup)
|
||||
|
||||
|
||||
async def select_scheduler_handler(update, context):
|
||||
text = "Select schedule"
|
||||
keyboard = []
|
||||
for shed in sheds_list:
|
||||
keyboard.append(
|
||||
[
|
||||
InlineKeyboardButton(text=shed, callback_data=shed),
|
||||
]
|
||||
)
|
||||
markup = InlineKeyboardMarkup(keyboard)
|
||||
await update.message.reply_text(text=text, reply_markup=markup)
|
||||
|
||||
|
||||
async def set_steps_handler(update, context):
|
||||
input_mex = update.message.text
|
||||
log.warning(input_mex)
|
||||
try:
|
||||
input_args = input_mex.split("/set_steps ")[1]
|
||||
global STEPS
|
||||
STEPS = int(input_args)
|
||||
except Exception:
|
||||
input_args = (
|
||||
"Invalid parameter for command. Correct command looks like\n"
|
||||
" /set_steps 30"
|
||||
)
|
||||
await update.message.reply_text(input_args)
|
||||
|
||||
|
||||
async def set_negative_prompt_handler(update, context):
|
||||
input_mex = update.message.text
|
||||
log.warning(input_mex)
|
||||
try:
|
||||
input_args = input_mex.split("/set_negative_prompt ")[1]
|
||||
global NEGATIVE_PROMPT
|
||||
NEGATIVE_PROMPT = input_args
|
||||
except Exception:
|
||||
input_args = (
|
||||
"Invalid parameter for command. Correct command looks like\n"
|
||||
" /set_negative_prompt ugly, bad art, mutated"
|
||||
)
|
||||
await update.message.reply_text(input_args)
|
||||
|
||||
|
||||
async def set_guidance_scale_handler(update, context):
|
||||
input_mex = update.message.text
|
||||
log.warning(input_mex)
|
||||
try:
|
||||
input_args = input_mex.split("/set_guidance_scale ")[1]
|
||||
global GUIDANCE_SCALE
|
||||
GUIDANCE_SCALE = int(input_args)
|
||||
except Exception:
|
||||
input_args = (
|
||||
"Invalid parameter for command. Correct command looks like\n"
|
||||
" /set_guidance_scale 7"
|
||||
)
|
||||
await update.message.reply_text(input_args)
|
||||
|
||||
|
||||
async def setup_bot_commands(application: Application) -> None:
|
||||
await application.bot.set_my_commands(
|
||||
[
|
||||
BotCommand("select_model", "to select model"),
|
||||
BotCommand("select_scheduler", "to select scheduler"),
|
||||
BotCommand("set_steps", "to set steps"),
|
||||
BotCommand("set_guidance_scale", "to set guidance scale"),
|
||||
BotCommand("set_negative_prompt", "to set negative prompt"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
app = (
|
||||
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
|
||||
)
|
||||
app.add_handler(CommandHandler("select_model", select_model_handler))
|
||||
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
|
||||
app.add_handler(CommandHandler("set_steps", set_steps_handler))
|
||||
app.add_handler(
|
||||
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
|
||||
)
|
||||
app.add_handler(
|
||||
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
|
||||
)
|
||||
app.add_handler(
|
||||
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
|
||||
)
|
||||
app.add_handler(CallbackQueryHandler(button))
|
||||
log.warning("Start bot")
|
||||
app.run_polling()
|
||||
240
apps/stable_diffusion/scripts/txt2img.py
Normal file
240
apps/stable_diffusion/scripts/txt2img.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import sys
|
||||
import torch
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImagePipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
|
||||
|
||||
txt2img_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def txt2img_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
):
|
||||
global txt2img_obj
|
||||
global config_obj
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
)
|
||||
if not txt2img_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = ""
|
||||
args.use_tuned = True
|
||||
args.import_mlir = False
|
||||
args.img_path = None
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
txt2img_obj.scheduler = schedulers[scheduler]
|
||||
|
||||
start_time = time.time()
|
||||
txt2img_obj.log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = txt2img_obj.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
txt2img_obj.log += "\n"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={steps}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
)
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
|
||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = txt2img_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
|
||||
text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
79
apps/stable_diffusion/shark_sd.spec
Normal file
79
apps/stable_diffusion/shark_sd.spec
Normal file
@@ -0,0 +1,79 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files('torch')
|
||||
datas += copy_metadata('torch')
|
||||
datas += copy_metadata('tqdm')
|
||||
datas += copy_metadata('regex')
|
||||
datas += copy_metadata('requests')
|
||||
datas += copy_metadata('packaging')
|
||||
datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torchvision')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('diffusers')
|
||||
datas += copy_metadata('transformers')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('shark')
|
||||
datas += [
|
||||
( 'src/utils/resources/prompts.json', 'resources' ),
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
( 'src/utils/resources/opt_flags.json', 'resources' ),
|
||||
( 'src/utils/resources/base_model.json', 'resources' ),
|
||||
( 'web/ui/css/*', 'ui/css' ),
|
||||
( 'web/ui/logos/*', 'logos' )
|
||||
]
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
|
||||
a = Analysis(
|
||||
['web/index.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='shark_sd',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
77
apps/stable_diffusion/shark_sd_cli.spec
Normal file
77
apps/stable_diffusion/shark_sd_cli.spec
Normal file
@@ -0,0 +1,77 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files('torch')
|
||||
datas += copy_metadata('torch')
|
||||
datas += copy_metadata('tqdm')
|
||||
datas += copy_metadata('regex')
|
||||
datas += copy_metadata('requests')
|
||||
datas += copy_metadata('packaging')
|
||||
datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torchvision')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('diffusers')
|
||||
datas += copy_metadata('transformers')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('shark')
|
||||
datas += [
|
||||
( 'src/utils/resources/prompts.json', 'resources' ),
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
( 'src/utils/resources/opt_flags.json', 'resources' ),
|
||||
( 'src/utils/resources/base_model.json', 'resources' ),
|
||||
]
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
|
||||
a = Analysis(
|
||||
['scripts/txt2img.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='shark_sd_cli',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
14
apps/stable_diffusion/src/__init__.py
Normal file
14
apps/stable_diffusion/src/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
args,
|
||||
set_init_device_flags,
|
||||
prompt_examples,
|
||||
get_available_devices,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines import (
|
||||
Text2ImagePipeline,
|
||||
InpaintPipeline,
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import get_schedulers
|
||||
12
apps/stable_diffusion/src/models/__init__.py
Normal file
12
apps/stable_diffusion/src/models/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from apps.stable_diffusion.src.models.model_wrappers import (
|
||||
SharkifyStableDiffusionModel,
|
||||
)
|
||||
from apps.stable_diffusion.src.models.opt_params import (
|
||||
get_vae_encode,
|
||||
get_vae,
|
||||
get_unet,
|
||||
get_clip,
|
||||
get_tokenizer,
|
||||
get_params,
|
||||
get_variant_version,
|
||||
)
|
||||
395
apps/stable_diffusion/src/models/model_wrappers.py
Normal file
395
apps/stable_diffusion/src/models/model_wrappers.py
Normal file
@@ -0,0 +1,395 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
import safetensors.torch
|
||||
import traceback
|
||||
import sys
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_opt_flags,
|
||||
base_models,
|
||||
args,
|
||||
fetch_or_delete_vmfbs,
|
||||
preprocessCKPT,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
)
|
||||
|
||||
|
||||
# These shapes are parameter dependent.
|
||||
def replace_shape_str(shape, max_len, width, height, batch_size):
|
||||
new_shape = []
|
||||
for i in range(len(shape)):
|
||||
if shape[i] == "max_len":
|
||||
new_shape.append(max_len)
|
||||
elif shape[i] == "height":
|
||||
new_shape.append(height)
|
||||
elif shape[i] == "width":
|
||||
new_shape.append(width)
|
||||
elif isinstance(shape[i], str):
|
||||
mul_val = int(shape[i].split("*")[0])
|
||||
if "batch_size" in shape[i]:
|
||||
new_shape.append(batch_size * mul_val)
|
||||
elif "height" in shape[i]:
|
||||
new_shape.append(height * mul_val)
|
||||
elif "width" in shape[i]:
|
||||
new_shape.append(width * mul_val)
|
||||
else:
|
||||
new_shape.append(shape[i])
|
||||
return new_shape
|
||||
|
||||
|
||||
# Get the input info for various models i.e. "unet", "clip", "vae", "vae_encode".
|
||||
def get_input_info(model_info, max_len, width, height, batch_size):
|
||||
dtype_config = {"f32": torch.float32, "i64": torch.int64}
|
||||
input_map = defaultdict(list)
|
||||
for k in model_info:
|
||||
for inp in model_info[k]:
|
||||
shape = model_info[k][inp]["shape"]
|
||||
dtype = dtype_config[model_info[k][inp]["dtype"]]
|
||||
tensor = None
|
||||
if isinstance(shape, list):
|
||||
clean_shape = replace_shape_str(
|
||||
shape, max_len, width, height, batch_size
|
||||
)
|
||||
if dtype == torch.int64:
|
||||
tensor = torch.randint(1, 3, tuple(clean_shape))
|
||||
else:
|
||||
tensor = torch.randn(*clean_shape).to(dtype)
|
||||
elif isinstance(shape, int):
|
||||
tensor = torch.tensor(shape).to(dtype)
|
||||
else:
|
||||
sys.exit("shape isn't specified correctly.")
|
||||
input_map[k].append(tensor)
|
||||
return input_map
|
||||
|
||||
|
||||
class SharkifyStableDiffusionModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
custom_weights: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
max_len: int = 64,
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
batch_size: int = 1,
|
||||
use_base_vae: bool = False,
|
||||
use_tuned: bool = False,
|
||||
low_cpu_mem_usage: bool = False
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
self.height = height // 8
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.custom_weights = custom_weights
|
||||
if custom_weights != "":
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
self.model_id = "stabilityai/stable-diffusion-2-1-base"
|
||||
self.custom_vae = custom_vae
|
||||
self.precision = precision
|
||||
self.base_vae = use_base_vae
|
||||
self.model_name = (
|
||||
str(batch_size)
|
||||
+ "_"
|
||||
+ str(max_len)
|
||||
+ "_"
|
||||
+ str(height)
|
||||
+ "_"
|
||||
+ str(width)
|
||||
+ "_"
|
||||
+ precision
|
||||
)
|
||||
self.use_tuned = use_tuned
|
||||
if use_tuned:
|
||||
self.model_name = self.model_name + "_tuned"
|
||||
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
|
||||
def get_extended_name_for_all_model(self):
|
||||
model_name = {}
|
||||
sub_model_list = ["clip", "unet", "vae", "vae_encode"]
|
||||
for model in sub_model_list:
|
||||
sub_model = model
|
||||
model_config = self.model_name
|
||||
if "vae" == model:
|
||||
if self.custom_vae != "":
|
||||
model_config = model_config + get_path_stem(self.custom_vae)
|
||||
if self.base_vae:
|
||||
sub_model = "base_vae"
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
return model_name
|
||||
|
||||
def check_params(self, max_len, width, height):
|
||||
if not (max_len >= 32 and max_len <= 77):
|
||||
sys.exit("please specify max_len in the range [32, 77].")
|
||||
if not (width % 8 == 0 and width >= 384):
|
||||
sys.exit("width should be greater than 384 and multiple of 8")
|
||||
if not (height % 8 == 0 and height >= 384):
|
||||
sys.exit("height should be greater than 384 and multiple of 8")
|
||||
|
||||
def get_vae_encode(self):
|
||||
class VaeEncodeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
latents = self.vae.encode(input).latent_dist.sample()
|
||||
return 0.18215 * latents
|
||||
|
||||
vae_encode = VaeEncodeModel()
|
||||
inputs = tuple(self.inputs["vae_encode"])
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
shark_vae_encode = compile_through_fx(
|
||||
vae_encode,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae_encode"],
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
)
|
||||
return shark_vae_encode
|
||||
|
||||
def get_vae(self):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.vae = None
|
||||
if custom_vae == "":
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
elif not isinstance(custom_vae, dict):
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.vae.load_state_dict(custom_vae)
|
||||
self.base_vae = base_vae
|
||||
|
||||
def forward(self, input):
|
||||
if not self.base_vae:
|
||||
input = 1 / 0.18215 * input
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
x = (x / 2 + 0.5).clamp(0, 1)
|
||||
if self.base_vae:
|
||||
return x
|
||||
x = x * 255.0
|
||||
return x.round()
|
||||
|
||||
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae"],
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
def get_unet(self):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(
|
||||
self, latent, timestep, text_embedding, guidance_scale
|
||||
):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latents = torch.cat([latent] * 2)
|
||||
unet_out = self.unet.forward(
|
||||
latents, timestep, text_embedding, return_dict=False
|
||||
)[0]
|
||||
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
input_mask = [True, True, True, False]
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name=self.model_name["unet"],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
def get_clip(self):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
tuple(self.inputs["clip"]),
|
||||
model_name=self.model_name["clip"],
|
||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
def process_custom_vae(self):
|
||||
custom_vae = self.custom_vae.lower()
|
||||
if not custom_vae.endswith((".ckpt", ".safetensors")):
|
||||
return self.custom_vae
|
||||
try:
|
||||
preprocessCKPT(self.custom_vae)
|
||||
return get_path_to_diffusers_checkpoint(self.custom_vae)
|
||||
except:
|
||||
print("Processing standalone Vae checkpoint")
|
||||
vae_checkpoint = None
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
if custom_vae.endswith(".ckpt"):
|
||||
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
|
||||
else:
|
||||
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
|
||||
if "state_dict" in vae_checkpoint:
|
||||
vae_checkpoint = vae_checkpoint["state_dict"]
|
||||
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict
|
||||
|
||||
|
||||
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
|
||||
# configiration.
|
||||
def compile_all(self, base_model_id, need_vae_encode):
|
||||
self.inputs = get_input_info(
|
||||
base_models[base_model_id],
|
||||
self.max_len,
|
||||
self.width,
|
||||
self.height,
|
||||
self.batch_size,
|
||||
)
|
||||
compiled_unet = self.get_unet()
|
||||
if self.custom_vae != "":
|
||||
print("Plugging in custom Vae")
|
||||
compiled_vae = self.get_vae()
|
||||
compiled_clip = self.get_clip()
|
||||
if need_vae_encode:
|
||||
compiled_vae_encode = self.get_vae_encode()
|
||||
return compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode
|
||||
|
||||
return compiled_clip, compiled_unet, compiled_vae
|
||||
|
||||
def __call__(self):
|
||||
# Step 1:
|
||||
# -- Fetch all vmfbs for the model, if present, else delete the lot.
|
||||
need_vae_encode = args.img_path is not None
|
||||
self.model_name = self.get_extended_name_for_all_model()
|
||||
vmfbs = fetch_or_delete_vmfbs(self.model_name, need_vae_encode, self.precision)
|
||||
if vmfbs[0]:
|
||||
# -- If all vmfbs are indeed present, we also try and fetch the base
|
||||
# model configuration for running SD with custom checkpoints.
|
||||
if self.custom_weights != "":
|
||||
args.hf_model_id = fetch_and_update_base_model_id(self.custom_weights)
|
||||
if args.hf_model_id == "":
|
||||
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
|
||||
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
|
||||
if not need_vae_encode:
|
||||
return vmfbs[:3]
|
||||
return vmfbs
|
||||
|
||||
# Step 2:
|
||||
# -- If vmfbs weren't found, we try to see if the base model configuration
|
||||
# for the required SD run is known to us and bypass the retry mechanism.
|
||||
model_to_run = ""
|
||||
if self.custom_weights != "":
|
||||
model_to_run = self.custom_weights
|
||||
assert self.custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
preprocessCKPT(self.custom_weights)
|
||||
else:
|
||||
model_to_run = args.hf_model_id
|
||||
# For custom Vae user can provide either the repo-id or a checkpoint file,
|
||||
# and for a checkpoint file we'd need to process it via Diffusers' script.
|
||||
self.custom_vae = self.process_custom_vae()
|
||||
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
|
||||
if base_model_fetched != "":
|
||||
print("Compiling all the models with the fetched base model configuration.")
|
||||
if args.ckpt_loc != "":
|
||||
args.hf_model_id = base_model_fetched
|
||||
return self.compile_all(base_model_fetched, need_vae_encode)
|
||||
|
||||
# Step 3:
|
||||
# -- This is the retry mechanism where the base model's configuration is not
|
||||
# known to us and figure that out by trial and error.
|
||||
print("Inferring base model configuration.")
|
||||
for model_id in base_models:
|
||||
try:
|
||||
if need_vae_encode:
|
||||
compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, need_vae_encode)
|
||||
else:
|
||||
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode)
|
||||
except Exception as e:
|
||||
print("Retrying with a different base model configuration")
|
||||
continue
|
||||
# -- Once a successful compilation has taken place we'd want to store
|
||||
# the base model's configuration inferred.
|
||||
fetch_and_update_base_model_id(model_to_run, model_id)
|
||||
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
|
||||
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
|
||||
# model and rely on retrying method to find the input configuration, we should also update
|
||||
# the knowledge of base model id accordingly into `args.hf_model_id`.
|
||||
if args.ckpt_loc != "":
|
||||
args.hf_model_id = model_id
|
||||
if need_vae_encode:
|
||||
return (
|
||||
compiled_clip,
|
||||
compiled_unet,
|
||||
compiled_vae,
|
||||
compiled_vae_encode,
|
||||
)
|
||||
return compiled_clip, compiled_unet, compiled_vae
|
||||
sys.exit(
|
||||
"Cannot compile the model. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
|
||||
)
|
||||
108
apps/stable_diffusion/src/models/opt_params.py
Normal file
108
apps/stable_diffusion/src/models/opt_params.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import sys
|
||||
from transformers import CLIPTokenizer
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
models_db,
|
||||
args,
|
||||
get_shark_model,
|
||||
get_opt_flags,
|
||||
)
|
||||
|
||||
|
||||
hf_model_variant_map = {
|
||||
"Linaqruf/anything-v3.0": ["anythingv3", "v1_4"],
|
||||
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"],
|
||||
"prompthero/openjourney": ["openjourney", "v1_4"],
|
||||
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"],
|
||||
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
|
||||
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
|
||||
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
|
||||
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
|
||||
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
|
||||
}
|
||||
|
||||
|
||||
def get_variant_version(hf_model_id):
|
||||
return hf_model_variant_map[hf_model_id]
|
||||
|
||||
|
||||
def get_params(bucket_key, model_key, model, is_tuned, precision):
|
||||
try:
|
||||
bucket = models_db[0][bucket_key]
|
||||
model_name = models_db[1][model_key]
|
||||
except KeyError:
|
||||
raise Exception(
|
||||
f"{bucket_key}/{model_key} is not present in the models database"
|
||||
)
|
||||
iree_flags = get_opt_flags(model, precision="fp16")
|
||||
return bucket, model_name, iree_flags
|
||||
|
||||
|
||||
def get_unet():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{variant}/{is_tuned}"
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "unet", is_tuned, args.precision
|
||||
)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae_encode():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{variant}/{is_tuned}"
|
||||
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "vae", is_tuned, args.precision
|
||||
)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
is_base = "/base" if args.use_base_vae else ""
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{variant}/{is_tuned}"
|
||||
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "vae", is_tuned, args.precision
|
||||
)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_clip():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
bucket_key = f"{variant}/untuned"
|
||||
model_key = (
|
||||
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
|
||||
)
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "clip", "untuned", "fp32"
|
||||
)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_tokenizer():
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.hf_model_id, subfolder="tokenizer"
|
||||
)
|
||||
return tokenizer
|
||||
9
apps/stable_diffusion/src/pipelines/__init__.py
Normal file
9
apps/stable_diffusion/src/pipelines/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
|
||||
Text2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
|
||||
InpaintPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
@@ -0,0 +1,169 @@
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from random import randint
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
|
||||
class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae_encode: SharkInference,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
self.vae_encode = vae_encode
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
strength,
|
||||
dtype,
|
||||
):
|
||||
# Pre process image -> get image encoded -> process latents
|
||||
|
||||
# TODO: process with variable HxW combos
|
||||
|
||||
# Pre process image
|
||||
image = image.resize((width, height))
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
|
||||
image_arr = 2 * (image_arr - 0.5)
|
||||
|
||||
# set scheduler steps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
init_timestep = min(
|
||||
int(num_inference_steps * strength), num_inference_steps
|
||||
)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
# timesteps reduced as per strength
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
# new number of steps to be used as per strength will be
|
||||
# num_inference_steps = num_inference_steps - t_start
|
||||
|
||||
# image encode
|
||||
latents = self.encode_image((image_arr,))
|
||||
latents = torch.from_numpy(latents).to(dtype)
|
||||
# add noise to data
|
||||
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
|
||||
latents = self.scheduler.add_noise(
|
||||
latents, noise, timesteps[0].repeat(1)
|
||||
)
|
||||
|
||||
return latents, timesteps
|
||||
|
||||
def encode_image(self, input_image):
|
||||
vae_encode_start = time.time()
|
||||
latents = self.vae_encode("forward", input_image)
|
||||
vae_inf_time = (time.time() - vae_encode_start) * 1000
|
||||
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
|
||||
|
||||
return latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get text embeddings from prompts
|
||||
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Prepare input image latent
|
||||
image_latents, final_timesteps = self.prepare_image_latents(
|
||||
image=image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
strength=strength,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=image_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=final_timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
return all_imgs
|
||||
@@ -0,0 +1,229 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
|
||||
class InpaintPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae_encode: SharkInference,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
self.vae_encode = vae_encode
|
||||
|
||||
def prepare_mask_and_masked_image(self, image, mask):
|
||||
# preprocess image
|
||||
if isinstance(image, (Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], Image.Image):
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
|
||||
mask = np.concatenate(
|
||||
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
|
||||
)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
):
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // 8, width // 8)
|
||||
)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
masked_image = masked_image.to(dtype)
|
||||
masked_image_latents = self.vae_encode("forward", (masked_image,))
|
||||
masked_image_latents = torch.from_numpy(masked_image_latents)
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(
|
||||
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
||||
)
|
||||
return mask, masked_image_latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
mask_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings from prompts
|
||||
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Preprocess mask and image
|
||||
mask, masked_image = self.prepare_mask_and_masked_image(
|
||||
image, mask_image
|
||||
)
|
||||
|
||||
# Prepare mask latent variables
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask=mask,
|
||||
masked_image=masked_image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
mask=mask,
|
||||
masked_image_latents=masked_image_latents,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
return all_imgs
|
||||
@@ -0,0 +1,137 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
|
||||
class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
||||
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings from prompts
|
||||
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
return all_imgs
|
||||
@@ -0,0 +1,261 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
import time
|
||||
from typing import Union
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
get_vae,
|
||||
get_clip,
|
||||
get_unet,
|
||||
get_tokenizer,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
|
||||
|
||||
class StableDiffusionPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
],
|
||||
):
|
||||
self.vae = vae
|
||||
self.text_encoder = text_encoder
|
||||
self.tokenizer = tokenizer
|
||||
self.unet = unet
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
self.log = ""
|
||||
|
||||
def encode_prompts(self, prompts, neg_prompts, max_length):
|
||||
# Tokenize text and get embeddings
|
||||
text_input = self.tokenizer(
|
||||
prompts,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Get unconditional embeddings as well
|
||||
uncond_input = self.tokenizer(
|
||||
neg_prompts,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
|
||||
|
||||
clip_inf_start = time.time()
|
||||
text_embeddings = self.text_encoder("forward", (text_input,))
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
|
||||
if use_base_vae:
|
||||
latents = 1 / 0.18215 * latents
|
||||
|
||||
latents_numpy = latents
|
||||
if cpu_scheduling:
|
||||
latents_numpy = latents.detach().numpy()
|
||||
|
||||
profile_device = start_profiling(file_path="vae.rdc")
|
||||
vae_start = time.time()
|
||||
images = self.vae("forward", (latents_numpy,))
|
||||
vae_inf_time = (time.time() - vae_start) * 1000
|
||||
end_profiling(profile_device)
|
||||
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
|
||||
|
||||
if use_base_vae:
|
||||
images = torch.from_numpy(images)
|
||||
images = (images.detach().cpu() * 255.0).numpy()
|
||||
images = images.round()
|
||||
|
||||
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
|
||||
pil_images = [Image.fromarray(image) for image in images.numpy()]
|
||||
return pil_images
|
||||
|
||||
def produce_img_latents(
|
||||
self,
|
||||
latents,
|
||||
text_embeddings,
|
||||
guidance_scale,
|
||||
total_timesteps,
|
||||
dtype,
|
||||
cpu_scheduling,
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
step_time_sum = 0
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, t)
|
||||
if mask is not None and masked_image_latents is not None:
|
||||
latent_model_input = torch.cat(
|
||||
[
|
||||
torch.from_numpy(np.asarray(latent_model_input)),
|
||||
mask,
|
||||
masked_image_latents,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype)
|
||||
if cpu_scheduling:
|
||||
latent_model_input = latent_model_input.detach().numpy()
|
||||
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
noise_pred = torch.from_numpy(noise_pred.to_host())
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents
|
||||
).prev_sample
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents)
|
||||
|
||||
latent_history.append(latents)
|
||||
step_time = (time.time() - step_start_time) * 1000
|
||||
# self.log += (
|
||||
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
if not return_all_latents:
|
||||
return latents
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
],
|
||||
import_mlir: bool,
|
||||
model_id: str,
|
||||
ckpt_loc: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
max_length: int,
|
||||
batch_size: int,
|
||||
height: int,
|
||||
width: int,
|
||||
use_base_vae: bool,
|
||||
use_tuned: bool,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
):
|
||||
if import_mlir:
|
||||
mlir_import = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
ckpt_loc,
|
||||
custom_vae,
|
||||
precision,
|
||||
max_len=max_length,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
use_base_vae=use_base_vae,
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
|
||||
clip, unet, vae, vae_encode = mlir_import()
|
||||
return cls(
|
||||
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
|
||||
)
|
||||
clip, unet, vae = mlir_import()
|
||||
return cls(vae, clip, get_tokenizer(), unet, scheduler)
|
||||
try:
|
||||
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
|
||||
return cls(
|
||||
get_vae_encode(),
|
||||
get_vae(),
|
||||
get_clip(),
|
||||
get_tokenizer(),
|
||||
get_unet(),
|
||||
scheduler,
|
||||
)
|
||||
return cls(
|
||||
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
|
||||
)
|
||||
except:
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
mlir_import = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
ckpt_loc,
|
||||
custom_vae,
|
||||
precision,
|
||||
max_len=max_length,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
use_base_vae=use_base_vae,
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
|
||||
clip, unet, vae, vae_encode = mlir_import()
|
||||
return cls(
|
||||
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
|
||||
)
|
||||
clip, unet, vae = mlir_import()
|
||||
return cls(vae, clip, get_tokenizer(), unet, scheduler)
|
||||
4
apps/stable_diffusion/src/schedulers/__init__.py
Normal file
4
apps/stable_diffusion/src/schedulers/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
56
apps/stable_diffusion/src/schedulers/sd_schedulers.py
Normal file
56
apps/stable_diffusion/src/schedulers/sd_schedulers.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
|
||||
|
||||
def get_schedulers(model_id):
|
||||
schedulers = dict()
|
||||
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistep"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"EulerAncestralDiscrete"
|
||||
] = EulerAncestralDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"SharkEulerDiscrete"
|
||||
] = SharkEulerDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["SharkEulerDiscrete"].compile()
|
||||
return schedulers
|
||||
156
apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py
Normal file
156
apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_shark_model,
|
||||
args,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
):
|
||||
super().__init__(
|
||||
num_train_timesteps,
|
||||
beta_start,
|
||||
beta_end,
|
||||
beta_schedule,
|
||||
trained_betas,
|
||||
prediction_type,
|
||||
)
|
||||
|
||||
def compile(self):
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
BATCH_SIZE = args.batch_size
|
||||
|
||||
model_input = {
|
||||
"euler": {
|
||||
"latent": torch.randn(
|
||||
BATCH_SIZE, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"output": torch.randn(
|
||||
BATCH_SIZE, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"dt": torch.tensor(1).to(torch.float32),
|
||||
},
|
||||
}
|
||||
|
||||
example_latent = model_input["euler"]["latent"]
|
||||
example_output = model_input["euler"]["output"]
|
||||
if args.precision == "fp16":
|
||||
example_latent = example_latent.half()
|
||||
example_output = example_output.half()
|
||||
example_sigma = model_input["euler"]["sigma"]
|
||||
example_dt = model_input["euler"]["dt"]
|
||||
|
||||
class ScalingModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, latent, sigma):
|
||||
return latent / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
class SchedulerStepModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma, latent, dt):
|
||||
pred_original_sample = latent - sigma * noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
return latent + derivative * dt
|
||||
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
self.scaling_model = compile_through_fx(
|
||||
scaling_model,
|
||||
(example_latent, example_sigma),
|
||||
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
step_model = SchedulerStepModel()
|
||||
self.step_model = compile_through_fx(
|
||||
step_model,
|
||||
(example_output, example_sigma, example_latent, example_dt),
|
||||
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
if args.import_mlir:
|
||||
_import(self)
|
||||
|
||||
else:
|
||||
try:
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_scale_model_input_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_step_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
except:
|
||||
print(
|
||||
"failed to download model, falling back and using import_mlir"
|
||||
)
|
||||
args.import_mlir = True
|
||||
_import(self)
|
||||
|
||||
def scale_model_input(self, sample, timestep):
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
return self.scaling_model(
|
||||
"forward",
|
||||
(
|
||||
sample,
|
||||
sigma,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
def step(self, noise_pred, timestep, latent):
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
dt = self.sigmas[step_index + 1] - sigma
|
||||
return self.step_model(
|
||||
"forward",
|
||||
(
|
||||
noise_pred,
|
||||
sigma,
|
||||
latent,
|
||||
dt,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
31
apps/stable_diffusion/src/utils/__init__.py
Normal file
31
apps/stable_diffusion/src/utils/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from apps.stable_diffusion.src.utils.profiler import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.resources import (
|
||||
prompt_examples,
|
||||
models_db,
|
||||
base_models,
|
||||
opt_flags,
|
||||
resource_path,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
get_shark_model,
|
||||
compile_through_fx,
|
||||
set_iree_runtime_flags,
|
||||
map_device_to_name_path,
|
||||
set_init_device_flags,
|
||||
get_available_devices,
|
||||
get_opt_flags,
|
||||
preprocessCKPT,
|
||||
fetch_or_delete_vmfbs,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
18
apps/stable_diffusion/src/utils/profiler.py
Normal file
18
apps/stable_diffusion/src/utils/profiler.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
|
||||
# Helper function to profile the vulkan device.
|
||||
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
|
||||
if args.vulkan_debug_utils and "vulkan" in args.device:
|
||||
import iree
|
||||
|
||||
print(f"Profiling and saving to {file_path}.")
|
||||
vulkan_device = iree.runtime.get_device(args.device)
|
||||
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
|
||||
return vulkan_device
|
||||
return None
|
||||
|
||||
|
||||
def end_profiling(device):
|
||||
if device:
|
||||
return device.end_profiling()
|
||||
37
apps/stable_diffusion/src/utils/resources.py
Normal file
37
apps/stable_diffusion/src/utils/resources.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def get_json_file(path):
|
||||
json_var = []
|
||||
loc_json = resource_path(path)
|
||||
if os.path.exists(loc_json):
|
||||
with open(loc_json, encoding="utf-8") as fopen:
|
||||
json_var = json.load(fopen)
|
||||
|
||||
if not json_var:
|
||||
print(f"Unable to fetch {path}")
|
||||
|
||||
return json_var
|
||||
|
||||
|
||||
# TODO: This shouldn't be called from here, every time the file imports
|
||||
# it will run all the global vars.
|
||||
prompt_examples = get_json_file("resources/prompts.json")
|
||||
models_db = get_json_file("resources/model_db.json")
|
||||
|
||||
# The base_model contains the input configuration for the different
|
||||
# models and also helps in providing information for the variants.
|
||||
base_models = get_json_file("resources/base_model.json")
|
||||
|
||||
# Contains optimization flags for different models.
|
||||
opt_flags = get_json_file("resources/opt_flags.json")
|
||||
226
apps/stable_diffusion/src/utils/resources/base_model.json
Normal file
226
apps/stable_diffusion/src/utils/resources/base_model.json
Normal file
@@ -0,0 +1,226 @@
|
||||
{
|
||||
"stabilityai/stable-diffusion-2-1": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
},
|
||||
"CompVis/stable-diffusion-v1-4": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
},
|
||||
"runwayml/stable-diffusion-inpainting": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
9,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-2-inpainting": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
9,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
23
apps/stable_diffusion/src/utils/resources/model_config.json
Normal file
23
apps/stable_diffusion/src/utils/resources/model_config.json
Normal file
@@ -0,0 +1,23 @@
|
||||
[
|
||||
{
|
||||
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
|
||||
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
|
||||
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
|
||||
"stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting",
|
||||
"stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting",
|
||||
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
|
||||
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
|
||||
"openjourney/v1_4":"prompthero/openjourney",
|
||||
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
|
||||
},
|
||||
{
|
||||
"stablediffusion/fp16":"fp16",
|
||||
"stablediffusion/fp32":"main",
|
||||
"anythingv3/fp16":"diffusers",
|
||||
"anythingv3/fp32":"diffusers",
|
||||
"analogdiffusion/fp16":"main",
|
||||
"analogdiffusion/fp32":"main",
|
||||
"openjourney/fp16":"main",
|
||||
"openjourney/fp32":"main"
|
||||
}
|
||||
]
|
||||
91
apps/stable_diffusion/src/utils/resources/model_db.json
Normal file
91
apps/stable_diffusion/src/utils/resources/model_db.json
Normal file
@@ -0,0 +1,91 @@
|
||||
[
|
||||
{
|
||||
"stablediffusion/untuned":"gs://shark_tank/sd_untuned",
|
||||
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
|
||||
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
|
||||
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
|
||||
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
|
||||
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
|
||||
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
|
||||
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
|
||||
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
|
||||
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
|
||||
"openjourney/tuned":"gs://shark_tank/sd_tuned",
|
||||
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
|
||||
},
|
||||
{
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
|
||||
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/inpaint_v1/unet/fp16/length_77/untuned":"unet_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v1/unet/fp32/length_77/untuned":"unet_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v1/vae_encode/fp16/length_77/untuned":"vae_encode_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v1/vae_encode/fp32/length_77/untuned":"vae_encode_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v1/vae/fp16/length_77/untuned":"vae_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v1/vae/fp32/length_77/untuned":"vae_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v1/clip/fp32/length_77/untuned":"clip_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v2/unet/fp16/length_77/untuned":"unet_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v2/vae_encode/fp16/length_77/untuned":"vae_encode_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v2/vae/fp16/length_77/untuned":"vae_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v2/clip/fp32/length_77/untuned":"clip_inpaint_fp32",
|
||||
"anythingv3/v1_4/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
|
||||
"anythingv3/v1_4/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
|
||||
"anythingv3/v1_4/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
|
||||
"anythingv3/v1_4/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
|
||||
"anythingv3/v1_4/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
|
||||
"anythingv3/v1_4/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
|
||||
"anythingv3/v1_4/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
|
||||
"anythingv3/v1_4/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
|
||||
"anythingv3/v1_4/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
|
||||
"anythingv3/v1_4/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
|
||||
"anythingv3/v1_4/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
|
||||
"analogdiffusion/v1_4/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
|
||||
"analogdiffusion/v1_4/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
|
||||
"analogdiffusion/v1_4/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
|
||||
"analogdiffusion/v1_4/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
|
||||
"analogdiffusion/v1_4/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
|
||||
"analogdiffusion/v1_4/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
|
||||
"analogdiffusion/v1_4/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
|
||||
"analogdiffusion/v1_4/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
|
||||
"analogdiffusion/v1_4/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
|
||||
"analogdiffusion/v1_4/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
|
||||
"analogdiffusion/v1_4/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
|
||||
"openjourney/v1_4/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
|
||||
"openjourney/v1_4/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
|
||||
"openjourney/v1_4/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
|
||||
"openjourney/v1_4/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
|
||||
"openjourney/v1_4/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
|
||||
"openjourney/v1_4/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
|
||||
"openjourney/v1_4/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
|
||||
"dreamlike/v1_4/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
|
||||
"dreamlike/v1_4/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
|
||||
"dreamlike/v1_4/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
|
||||
"dreamlike/v1_4/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
|
||||
"dreamlike/v1_4/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
|
||||
"dreamlike/v1_4/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
|
||||
"dreamlike/v1_4/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
|
||||
}
|
||||
]
|
||||
84
apps/stable_diffusion/src/utils/resources/opt_flags.json
Normal file
84
apps/stable_diffusion/src/utils/resources/opt_flags.json
Normal file
@@ -0,0 +1,84 @@
|
||||
{
|
||||
"unet": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": []
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": []
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [],
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [],
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
8
apps/stable_diffusion/src/utils/resources/prompts.json
Normal file
8
apps/stable_diffusion/src/utils/resources/prompts.json
Normal file
@@ -0,0 +1,8 @@
|
||||
[["A high tech solarpunk utopia in the Amazon rainforest"],
|
||||
["A pikachu fine dining with a view to the Eiffel Tower"],
|
||||
["A mecha robot in a favela in expressionist style"],
|
||||
["an insect robot preparing a delicious meal"],
|
||||
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
|
||||
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
|
||||
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
|
||||
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]
|
||||
236
apps/stable_diffusion/src/utils/sd_annotation.py
Normal file
236
apps/stable_diffusion/src/utils/sd_annotation.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import os
|
||||
import io
|
||||
from shark.model_annotation import model_annotation, create_context
|
||||
from shark.iree_utils._common import iree_target_map, run_cmd
|
||||
from shark.shark_downloader import (
|
||||
download_model,
|
||||
download_public_file,
|
||||
WORKDIR,
|
||||
)
|
||||
from shark.parser import shark_args
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
|
||||
def get_device():
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else args.device.split("://")[0]
|
||||
)
|
||||
return device
|
||||
|
||||
|
||||
def get_device_args():
|
||||
device = get_device()
|
||||
device_spec_args = []
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
gpu_flags = get_iree_gpu_args()
|
||||
for flag in gpu_flags:
|
||||
device_spec_args.append(flag)
|
||||
elif device == "vulkan":
|
||||
device_spec_args.append(
|
||||
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
|
||||
)
|
||||
return device, device_spec_args
|
||||
|
||||
|
||||
# Download the model (Unet or VAE fp16) from shark_tank
|
||||
def load_model_from_tank():
|
||||
from apps.stable_diffusion.src.models import (
|
||||
get_params,
|
||||
get_variant_version,
|
||||
)
|
||||
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
bucket_key = f"{variant}/untuned"
|
||||
if args.annotation_model == "unet":
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
|
||||
elif args.annotation_model == "vae":
|
||||
is_base = "/base" if args.use_base_vae else ""
|
||||
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, args.annotation_model, "untuned", args.precision
|
||||
)
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=bucket,
|
||||
frontend="torch",
|
||||
)
|
||||
return mlir_model, model_name
|
||||
|
||||
|
||||
# Download the tuned config files from shark_tank
|
||||
def load_winograd_configs():
|
||||
device = get_device()
|
||||
config_bucket = "gs://shark_tank/sd_tuned/configs/"
|
||||
config_name = f"{args.annotation_model}_winograd_{device}.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading Winograd config file from ", winograd_config_dir)
|
||||
download_public_file(full_gs_url, winograd_config_dir, True)
|
||||
return winograd_config_dir
|
||||
|
||||
|
||||
def load_lower_configs():
|
||||
from apps.stable_diffusion.src.models import get_variant_version
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
)
|
||||
|
||||
base_model_id = args.hf_model_id
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
if base_model_id == "runwayml/stable-diffusion-v1-5":
|
||||
base_model_id = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
variant, version = get_variant_version(base_model_id)
|
||||
|
||||
config_bucket = "gs://shark_tank/sd_tuned_configs/"
|
||||
|
||||
device, device_spec_args = get_device_args()
|
||||
spec = ""
|
||||
if device_spec_args:
|
||||
spec = device_spec_args[-1].split("=")[-1].strip()
|
||||
if device == "vulkan":
|
||||
spec = spec.split("-")[0]
|
||||
|
||||
if args.annotation_model == "vae":
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_{args.precision}_{device}.json"
|
||||
)
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
|
||||
else:
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
|
||||
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading lowering config file from ", lowering_config_dir)
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
return lowering_config_dir
|
||||
|
||||
|
||||
# Annotate the model with Winograd attribute on selected conv ops
|
||||
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
|
||||
with create_context() as ctx:
|
||||
winograd_model = model_annotation(
|
||||
ctx,
|
||||
input_contents=input_mlir,
|
||||
config_path=winograd_config_dir,
|
||||
search_op="conv",
|
||||
winograd=True,
|
||||
)
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
winograd_model.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if args.save_annotation:
|
||||
if model_name.split("_")[-1] != "tuned":
|
||||
out_file_path = (
|
||||
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
)
|
||||
else:
|
||||
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(winograd_model))
|
||||
f.close()
|
||||
|
||||
return bytecode
|
||||
|
||||
|
||||
def dump_after_mlir(input_mlir, use_winograd):
|
||||
import iree.compiler as ireec
|
||||
|
||||
device, device_spec_args = get_device_args()
|
||||
if use_winograd:
|
||||
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
else:
|
||||
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
|
||||
dump_module = ireec.compile_str(
|
||||
input_mlir,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=device_spec_args
|
||||
+ [
|
||||
preprocess_flag,
|
||||
"--compile-to=preprocessing",
|
||||
],
|
||||
)
|
||||
return dump_module
|
||||
|
||||
|
||||
# For Unet annotate the model with tuned lowering configs
|
||||
def annotate_with_lower_configs(
|
||||
input_mlir, lowering_config_dir, model_name, use_winograd
|
||||
):
|
||||
# Dump IR after padding/img2col/winograd passes
|
||||
dump_module = dump_after_mlir(input_mlir, use_winograd)
|
||||
print("Applying tuned configs on", model_name)
|
||||
|
||||
# Annotate the model with lowering configs in the config file
|
||||
with create_context() as ctx:
|
||||
tuned_model = model_annotation(
|
||||
ctx,
|
||||
input_contents=dump_module,
|
||||
config_path=lowering_config_dir,
|
||||
search_op="all",
|
||||
)
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
tuned_model.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if args.save_annotation:
|
||||
if model_name.split("_")[-1] != "tuned":
|
||||
out_file_path = (
|
||||
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
)
|
||||
else:
|
||||
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(tuned_model))
|
||||
f.close()
|
||||
|
||||
return bytecode
|
||||
|
||||
|
||||
def sd_model_annotation(mlir_model, model_name):
|
||||
device = get_device()
|
||||
if args.annotation_model == "unet" and device == "vulkan":
|
||||
use_winograd = True
|
||||
winograd_config_dir = load_winograd_configs()
|
||||
winograd_model = annotate_with_winograd(
|
||||
mlir_model, winograd_config_dir, model_name
|
||||
)
|
||||
lowering_config_dir = load_lower_configs()
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
winograd_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
elif args.annotation_model == "vae" and device == "vulkan":
|
||||
use_winograd = True
|
||||
winograd_config_dir = load_winograd_configs()
|
||||
tuned_model = annotate_with_winograd(
|
||||
mlir_model, winograd_config_dir, model_name
|
||||
)
|
||||
else:
|
||||
use_winograd = False
|
||||
lowering_config_dir = load_lower_configs()
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
mlir_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
return tuned_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlir_model, model_name = load_model_from_tank()
|
||||
sd_model_annotation(mlir_model, model_name)
|
||||
370
apps/stable_diffusion/src/utils/stable_args.py
Normal file
370
apps/stable_diffusion/src/utils/stable_args.py
Normal file
@@ -0,0 +1,370 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def path_expand(s):
|
||||
return Path(s).expanduser().resolve()
|
||||
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"-p",
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=["cyberpunk forest by Salvador Dali"],
|
||||
help="text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative_prompts",
|
||||
nargs="+",
|
||||
default=["trees, green"],
|
||||
help="text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--img_path",
|
||||
type=str,
|
||||
help="Path to the image input for img2img/inpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--mask_path",
|
||||
type=str,
|
||||
help="Path to the mask image input for inpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="the no. of steps to do the sampling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="the seed to use. -1 for a random one.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=range(1, 4),
|
||||
help="the number of inferences to be made in a single `batch_count`.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=512,
|
||||
help="the height of the output image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=512,
|
||||
help="the width of the output image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="the value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=64,
|
||||
help="max length of the tokenizer output, options are 64 and 77.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="the strength of change applied on the given input image for img2img",
|
||||
)
|
||||
##############################################################################
|
||||
### Model Config and Usage Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="vulkan", help="device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp16", help="precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="saves the compiled flatbuffer to the local directory",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_base_vae",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Do conversion from the VAE output to pixel space on cpu.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="SharkEulerDiscrete",
|
||||
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_img_format",
|
||||
type=str,
|
||||
default="png",
|
||||
help="specify the format in which output image is save. Supported options: jpg / png",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory path to save the output images and json",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--batch_count",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of batch to be generated with random seeds in single execution",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ckpt_loc",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to SD's .ckpt file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--custom_vae",
|
||||
type=str,
|
||||
default="",
|
||||
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hf_model_id",
|
||||
type=str,
|
||||
default="stabilityai/stable-diffusion-2-1-base",
|
||||
help="The repo-id of hugging face.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--low_cpu_mem_usage",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use the accelerate package to reduce cpu memory consumption",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree_vulkan_target_triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="4147483648",
|
||||
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for disabling vulkan validation layers when benchmarking",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Misc. Debug and Optimization flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--use_compiled_scheduler",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="use the default scheduler precompiled into the model if available",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--local_tank_cache",
|
||||
default="",
|
||||
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dump_isa",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks",
|
||||
default=None,
|
||||
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="temp_dispatch_benchmarks",
|
||||
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--enable_rgp",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for inserting debug frames between iterations for use with rgp.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hide_steps",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for hiding the details of iteration/sec for each step.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--warmup_count",
|
||||
type=int,
|
||||
default=0,
|
||||
help="flag setting warmup count for clip and vae [>= 0].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--clear_all",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_metadata_to_json",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for whether or not to save a generation information json file with the image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--write_metadata_to_png",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Web UI flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--progress_bar",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for removing the progress bar animation during image generation",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ckpt_dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
|
||||
)
|
||||
|
||||
|
||||
p.add_argument(
|
||||
"--share",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for generating a public URL",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=8080,
|
||||
help="flag for setting server port",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### SD model auto-annotation flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_output",
|
||||
type=path_expand,
|
||||
default="./",
|
||||
help="Directory to save the annotated mlir file",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_model",
|
||||
type=str,
|
||||
default="unet",
|
||||
help="Options are unet and vae.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_annotation",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Save annotated mlir file",
|
||||
)
|
||||
|
||||
args, unknown = p.parse_known_args()
|
||||
616
apps/stable_diffusion/src/utils/utils.py
Normal file
616
apps/stable_diffusion/src/utils/utils.py
Normal file
@@ -0,0 +1,616 @@
|
||||
import os
|
||||
import gc
|
||||
import json
|
||||
import re
|
||||
from PIL import PngImagePlugin
|
||||
from datetime import datetime as dt
|
||||
from csv import DictWriter
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
import sys
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
|
||||
|
||||
def get_extended_name(model_name):
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
return extended_name
|
||||
|
||||
|
||||
def get_vmfb_path_name(model_name):
|
||||
vmfb_path = os.path.join(os.getcwd(), model_name + ".vmfb")
|
||||
return vmfb_path
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
vmfb_path = get_vmfb_path_name(model_name)
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
else:
|
||||
if args.save_vmfb:
|
||||
print("Saving to {}".format(vmfb_path))
|
||||
else:
|
||||
print(
|
||||
"No vmfb found. Compiling and saving to {}".format(
|
||||
vmfb_path
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), model_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.parser import shark_args
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
if "cuda" in args.device:
|
||||
shark_args.enable_tf32 = True
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=tank_url,
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into a shark_module.
|
||||
def compile_through_fx(
|
||||
model,
|
||||
inputs,
|
||||
model_name,
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
use_tuned=False,
|
||||
extra_args=[],
|
||||
):
|
||||
from shark.parser import shark_args
|
||||
|
||||
if "cuda" in args.device:
|
||||
shark_args.enable_tf32 = True
|
||||
|
||||
mlir_module, func_name = import_with_fx(
|
||||
model, inputs, is_f16, f16_input_mask
|
||||
)
|
||||
|
||||
if use_tuned:
|
||||
if "vae" in model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
mlir_module = sd_model_annotation(mlir_module, model_name)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
|
||||
]
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
"""
|
||||
Inputs: driver_name
|
||||
Returns a list of all the available devices for a given driver sorted by
|
||||
the iree path names of the device as in --list_devices option in iree.
|
||||
"""
|
||||
from iree.runtime import get_driver
|
||||
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return device_list_src
|
||||
|
||||
|
||||
def get_device_mapping(driver, key_combination=3):
|
||||
"""This method ensures consistent device ordering when choosing
|
||||
specific devices for execution
|
||||
Args:
|
||||
driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Returns:
|
||||
dict: map to possible device names user can input mapped to desired combination of name/path.
|
||||
"""
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
driver = iree_device_map(driver)
|
||||
device_list = get_all_devices(driver)
|
||||
device_map = dict()
|
||||
|
||||
def get_output_value(dev_dict):
|
||||
if key_combination == 1:
|
||||
return f"{driver}://{dev_dict['path']}"
|
||||
if key_combination == 2:
|
||||
return dev_dict["name"]
|
||||
if key_combination == 3:
|
||||
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
|
||||
|
||||
# mapping driver name to default device (driver://0)
|
||||
device_map[f"{driver}"] = get_output_value(device_list[0])
|
||||
for i, device in enumerate(device_list):
|
||||
# mapping with index
|
||||
device_map[f"{driver}://{i}"] = get_output_value(device)
|
||||
# mapping with full path
|
||||
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
|
||||
return device_map
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user selected execution device
|
||||
Args:
|
||||
device (str): user
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Raises:
|
||||
ValueError:
|
||||
Returns:
|
||||
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
|
||||
"""
|
||||
driver = device.split("://")[0]
|
||||
device_map = get_device_mapping(driver, key_combination)
|
||||
try:
|
||||
device_mapping = device_map[device]
|
||||
except KeyError:
|
||||
raise ValueError(f"Device '{device}' is not a valid device.")
|
||||
return device_mapping
|
||||
|
||||
|
||||
def set_init_device_flags():
|
||||
if "vulkan" in args.device:
|
||||
# set runtime flags for vulkan.
|
||||
set_iree_runtime_flags()
|
||||
|
||||
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
|
||||
device_name, args.device = map_device_to_name_path(args.device)
|
||||
if not args.iree_vulkan_target_triple:
|
||||
triple = get_vulkan_target_triple(device_name)
|
||||
if triple is not None:
|
||||
args.iree_vulkan_target_triple = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
|
||||
)
|
||||
elif "cuda" in args.device:
|
||||
args.device = "cuda"
|
||||
elif "cpu" in args.device:
|
||||
args.device = "cpu"
|
||||
|
||||
# set max_length based on availability.
|
||||
if args.hf_model_id in [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
]:
|
||||
args.max_length = 77
|
||||
elif args.hf_model_id == "prompthero/openjourney":
|
||||
args.max_length = 64
|
||||
|
||||
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
|
||||
base_model_id = args.hf_model_id
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
|
||||
if (
|
||||
args.hf_model_id
|
||||
in [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
]
|
||||
or args.precision != "fp16"
|
||||
or args.height != 512
|
||||
or args.width != 512
|
||||
or args.batch_size != 1
|
||||
or ("vulkan" not in args.device and "cuda" not in args.device)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif args.ckpt_loc != "" and base_model_id not in [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
]:
|
||||
args.use_tuned = False
|
||||
|
||||
elif "vulkan" in args.device and not any(
|
||||
x in args.iree_vulkan_target_triple for x in ["rdna2", "rdna3"]
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
|
||||
args.use_tuned = False
|
||||
|
||||
elif args.use_base_vae and args.hf_model_id not in [
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]:
|
||||
args.use_tuned = False
|
||||
|
||||
if args.use_tuned:
|
||||
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
|
||||
else:
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
# set import_mlir to True for unuploaded models.
|
||||
if args.ckpt_loc != "":
|
||||
args.import_mlir = True
|
||||
|
||||
elif args.hf_model_id not in [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
]:
|
||||
args.import_mlir = True
|
||||
|
||||
elif args.height != 512 or args.width != 512 or args.batch_size != 1:
|
||||
args.import_mlir = True
|
||||
|
||||
elif args.use_tuned and args.hf_model_id in [
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
"prompthero/openjourney",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
]:
|
||||
args.import_mlir = True
|
||||
|
||||
elif (
|
||||
args.use_tuned
|
||||
and "vulkan" in args.device
|
||||
and "rdna2" in args.iree_vulkan_target_triple
|
||||
):
|
||||
args.import_mlir = True
|
||||
|
||||
elif (
|
||||
args.use_tuned
|
||||
and "cuda" in args.device
|
||||
and get_cuda_sm_cc() == "sm_89"
|
||||
):
|
||||
args.import_mlir = True
|
||||
|
||||
|
||||
# Utility to get list of devices available.
|
||||
def get_available_devices():
|
||||
def get_devices_by_name(driver_name):
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
driver_name = iree_device_map(driver_name)
|
||||
device_list_dict = get_all_devices(driver_name)
|
||||
print(f"{driver_name} devices are available.")
|
||||
except:
|
||||
print(f"{driver_name} devices are not available.")
|
||||
else:
|
||||
for i, device in enumerate(device_list_dict):
|
||||
device_list.append(f"{device['name']} => {driver_name}://{i}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
vulkan_devices = get_devices_by_name("vulkan")
|
||||
available_devices.extend(vulkan_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
available_devices.append("cpu")
|
||||
return available_devices
|
||||
|
||||
|
||||
def disk_space_check(path, lim=20):
|
||||
from shutil import disk_usage
|
||||
|
||||
du = disk_usage(path)
|
||||
free = du.free / (1024 * 1024 * 1024)
|
||||
if free <= lim:
|
||||
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
|
||||
|
||||
|
||||
def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags = []
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
"default_compilation_flags"
|
||||
]
|
||||
|
||||
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else args.device.split("://")[0]
|
||||
)
|
||||
if (
|
||||
device
|
||||
not in opt_flags[model][is_tuned][precision][
|
||||
"specified_compilation_flags"
|
||||
]
|
||||
):
|
||||
device = "default_device"
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
"specified_compilation_flags"
|
||||
][device]
|
||||
return iree_flags
|
||||
|
||||
|
||||
def get_path_stem(path):
|
||||
path = Path(path)
|
||||
return path.stem
|
||||
|
||||
|
||||
def get_path_to_diffusers_checkpoint(custom_weights):
|
||||
path = Path(custom_weights)
|
||||
diffusers_path = path.parent.absolute()
|
||||
diffusers_directory_name = path.stem
|
||||
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
|
||||
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
|
||||
path_to_diffusers = complete_path_to_diffusers.as_posix()
|
||||
return path_to_diffusers
|
||||
|
||||
|
||||
def preprocessCKPT(custom_weights):
|
||||
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
if next(Path(path_to_diffusers).iterdir(), None):
|
||||
print("Checkpoint already loaded at : ", path_to_diffusers)
|
||||
return
|
||||
else:
|
||||
print(
|
||||
"Diffusers' checkpoint will be identified here : ",
|
||||
path_to_diffusers,
|
||||
)
|
||||
from_safetensors = (
|
||||
True if custom_weights.lower().endswith(".safetensors") else False
|
||||
)
|
||||
# EMA weights usually yield higher quality images for inference but non-EMA weights have
|
||||
# been yielding better results in our case.
|
||||
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
|
||||
# weight extraction or not.
|
||||
extract_ema = False
|
||||
print(
|
||||
"Loading diffusers' pipeline from original stable diffusion checkpoint"
|
||||
)
|
||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
)
|
||||
pipe.save_pretrained(path_to_diffusers)
|
||||
print("Loading complete")
|
||||
|
||||
|
||||
def load_vmfb(vmfb_path, model, precision):
|
||||
model = "vae" if "base_vae" in model or "vae_encode" in model else model
|
||||
precision = "fp32" if "clip" in model else precision
|
||||
extra_args = get_opt_flags(model, precision)
|
||||
shark_module = SharkInference(mlir_module=None, device=args.device)
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
# This utility returns vmfbs of Clip, Unet, Vae and Vae_encode, in case all of them
|
||||
# are present; deletes them otherwise.
|
||||
def fetch_or_delete_vmfbs(
|
||||
extended_model_name, need_vae_encode, precision="fp32"
|
||||
):
|
||||
vmfb_path = [
|
||||
get_vmfb_path_name(extended_model_name[model])
|
||||
for model in extended_model_name
|
||||
]
|
||||
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
|
||||
all_vmfb_present = True
|
||||
compiled_models = []
|
||||
for i in range(3):
|
||||
all_vmfb_present = all_vmfb_present and vmfb_present[i]
|
||||
compiled_models.append(None)
|
||||
if need_vae_encode:
|
||||
all_vmfb_present = all_vmfb_present and vmfb_present[3]
|
||||
compiled_models.append(None)
|
||||
|
||||
# We need to delete vmfbs only if some of the models were compiled.
|
||||
if not all_vmfb_present:
|
||||
for i in range(len(compiled_models)):
|
||||
if vmfb_present[i]:
|
||||
os.remove(vmfb_path[i])
|
||||
print("Deleted: ", vmfb_path[i])
|
||||
else:
|
||||
model_name = [model for model in extended_model_name.keys()]
|
||||
for i in range(len(compiled_models)):
|
||||
compiled_models[i] = load_vmfb(
|
||||
vmfb_path[i], model_name[i], precision
|
||||
)
|
||||
return compiled_models
|
||||
|
||||
|
||||
# `fetch_and_update_base_model_id` is a resource utility function which
|
||||
# helps maintaining mapping of the model to run with its base model.
|
||||
# If `base_model` is "", then this function tries to fetch the base model
|
||||
# info for the `model_to_run`.
|
||||
def fetch_and_update_base_model_id(model_to_run, base_model=""):
|
||||
variants_path = os.path.join(os.getcwd(), "variants.json")
|
||||
data = {model_to_run: base_model}
|
||||
json_data = {}
|
||||
if os.path.exists(variants_path):
|
||||
with open(variants_path, "r", encoding="utf-8") as jsonFile:
|
||||
json_data = json.load(jsonFile)
|
||||
# Return with base_model's info if base_model is "".
|
||||
if base_model == "":
|
||||
if model_to_run in json_data:
|
||||
base_model = json_data[model_to_run]
|
||||
return base_model
|
||||
elif base_model == "":
|
||||
return base_model
|
||||
# Update JSON data to contain an entry mapping model_to_run with base_model.
|
||||
json_data.update(data)
|
||||
with open(variants_path, "w", encoding="utf-8") as jsonFile:
|
||||
json.dump(json_data, jsonFile)
|
||||
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the supported range (including -1)
|
||||
def sanitize_seed(seed):
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
return seed
|
||||
|
||||
|
||||
# clear all the cached objects to recompile cleanly.
|
||||
def clear_all():
|
||||
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
|
||||
from glob import glob
|
||||
import shutil
|
||||
|
||||
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
|
||||
for vmfb in vmfbs:
|
||||
if os.path.exists(vmfb):
|
||||
os.remove(vmfb)
|
||||
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
|
||||
# TODO: Remove this once we have better weight updation logic.
|
||||
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
|
||||
for yaml in inference_yaml:
|
||||
if os.path.exists(yaml):
|
||||
os.remove(yaml)
|
||||
home = os.path.expanduser("~")
|
||||
if os.name == "nt": # Windows
|
||||
appdata = os.getenv("LOCALAPPDATA")
|
||||
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
|
||||
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
|
||||
elif os.name == "unix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed):
|
||||
output_path = args.output_dir if args.output_dir else Path.cwd()
|
||||
generated_imgs_path = Path(
|
||||
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
|
||||
)
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
|
||||
out_img_name = (
|
||||
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
img_model = args.hf_model_id
|
||||
if args.ckpt_loc:
|
||||
img_model = os.path.basename(args.ckpt_loc)
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
else:
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if args.write_metadata_to_png:
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
if args.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {args.output_img_format} is not supported yet."
|
||||
"Image saved as png instead. Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
new_entry = {
|
||||
"VARIANT": img_model,
|
||||
"SCHEDULER": args.scheduler,
|
||||
"PROMPT": args.prompts[0],
|
||||
"NEG_PROMPT": args.negative_prompts[0],
|
||||
"SEED": img_seed,
|
||||
"CFG_SCALE": args.guidance_scale,
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
"HEIGHT": args.height,
|
||||
"WIDTH": args.width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
}
|
||||
|
||||
with open(csv_path, "a") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
if args.save_metadata_to_json:
|
||||
del new_entry["OUTPUT"]
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
15
apps/stable_diffusion/stable_diffusion_telegram_bot.md
Normal file
15
apps/stable_diffusion/stable_diffusion_telegram_bot.md
Normal file
@@ -0,0 +1,15 @@
|
||||
You need to pre-create your bot (https://core.telegram.org/bots#how-do-i-create-a-bot)
|
||||
Then create in the directory web file .env
|
||||
In it the record:
|
||||
TG_TOKEN="your_token"
|
||||
specifying your bot's token from previous step.
|
||||
Then run telegram_bot.py with the same parameters that you use when running index.py, for example:
|
||||
python telegram_bot.py --max_length=77 --vulkan_large_heap_block_size=0 --use_base_vae --local_tank_cache h:\shark\TEMP
|
||||
|
||||
Bot commands:
|
||||
/select_model
|
||||
/select_scheduler
|
||||
/set_steps "integer number of steps"
|
||||
/set_guidance_scale "integer number"
|
||||
/set_negative_prompt "negative text"
|
||||
Any other text triggers the creation of an image based on it.
|
||||
44
apps/stable_diffusion/web/index.py
Normal file
44
apps/stable_diffusion/web/index.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
|
||||
import gradio as gr
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
clear_gradio_tmp_imgs_folder,
|
||||
)
|
||||
|
||||
# clear all gradio tmp images from the last session
|
||||
clear_gradio_tmp_imgs_folder()
|
||||
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
from apps.stable_diffusion.web.ui import txt2img_web, img2img_web
|
||||
|
||||
sd_web = gr.TabbedInterface(
|
||||
[txt2img_web, img2img_web],
|
||||
["Text-to-Image", "Image-to-Image"],
|
||||
css=dark_theme,
|
||||
)
|
||||
|
||||
sd_web.queue()
|
||||
sd_web.launch(
|
||||
share=args.share,
|
||||
inbrowser=True,
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
)
|
||||
2
apps/stable_diffusion/web/ui/__init__.py
Normal file
2
apps/stable_diffusion/web/ui/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_web
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import img2img_web
|
||||
215
apps/stable_diffusion/web/ui/css/sd_dark_theme.css
Normal file
215
apps/stable_diffusion/web/ui/css/sd_dark_theme.css
Normal file
@@ -0,0 +1,215 @@
|
||||
|
||||
/* Overwrite the Gradio default theme with their .dark theme declarations */
|
||||
|
||||
:root {
|
||||
--color-focus-primary: var(--color-grey-700);
|
||||
--color-focus-secondary: var(--color-grey-600);
|
||||
--color-focus-ring: rgb(55 65 81);
|
||||
--color-background-primary: var(--color-grey-950);
|
||||
--color-background-secondary: var(--color-grey-900);
|
||||
--color-background-tertiary: var(--color-grey-800);
|
||||
--color-text-body: var(--color-grey-100);
|
||||
--color-text-label: var(--color-grey-200);
|
||||
--color-text-placeholder: var(--color-grey);
|
||||
--color-text-subdued: var(--color-grey-400);
|
||||
--color-text-link-base: var(--color-blue-500);
|
||||
--color-text-link-hover: var(--color-blue-400);
|
||||
--color-text-link-visited: var(--color-blue-600);
|
||||
--color-text-link-active: var(--color-blue-500);
|
||||
--color-text-code-background: var(--color-grey-800);
|
||||
--color-text-code-border: color.border-primary;
|
||||
--color-border-primary: var(--color-grey-700);
|
||||
--color-border-secondary: var(--color-grey-600);
|
||||
--color-border-highlight: var(--color-accent-base);
|
||||
--color-accent-base: var(--color-orange-500);
|
||||
--color-accent-light: var(--color-orange-300);
|
||||
--color-accent-dark: var(--color-orange-700);
|
||||
--color-functional-error-base: var(--color-red-400);
|
||||
--color-functional-error-subdued: var(--color-red-300);
|
||||
--color-functional-error-background: var(--color-background-primary);
|
||||
--color-functional-info-base: var(--color-yellow);
|
||||
--color-functional-info-subdued: var(--color-yellow-300);
|
||||
--color-functional-success-base: var(--color-green);
|
||||
--color-functional-success-subdued: var(--color-green-300);
|
||||
--shadow-spread: 2px;
|
||||
--api-background: linear-gradient(to bottom, rgba(255, 216, 180, .05), transparent);
|
||||
--api-pill-background: var(--color-orange-400);
|
||||
--api-pill-border: var(--color-orange-600);
|
||||
--api-pill-text: var(--color-orange-900);
|
||||
--block-border-color: var(--color-border-primary);
|
||||
--block-background: var(--color-background-tertiary);
|
||||
--uploadable-border-color-hover: var(--color-border-primary);
|
||||
--uploadable-border-color-loaded: var(--color-functional-success);
|
||||
--uploadable-text-color: var(--color-text-subdued);
|
||||
--block_label-border-color: var(--color-border-primary);
|
||||
--block_label-icon-color: var(--color-text-label);
|
||||
--block_label-shadow: var(--shadow-drop);
|
||||
--block_label-background: var(--color-background-secondary);
|
||||
--icon_button-icon-color-base: var(--color-text-label);
|
||||
--icon_button-icon-color-hover: var(--color-text-label);
|
||||
--icon_button-background-base: var(--color-background-primary);
|
||||
--icon_button-background-hover: var(--color-background-primary);
|
||||
--icon_button-border-color-base: var(--color-background-primary);
|
||||
--icon_button-border-color-hover: var(--color-border-secondary);
|
||||
--input-text-color: var(--color-text-body);
|
||||
--input-border-color-base: var(--color-border-primary);
|
||||
--input-border-color-hover: var(--color-border-primary);
|
||||
--input-border-color-focus: var(--color-border-primary);
|
||||
--input-background-base: var(--color-background-tertiary);
|
||||
--input-background-hover: var(--color-background-tertiary);
|
||||
--input-background-focus: var(--color-background-tertiary);
|
||||
--input-shadow: var(--shadow-inset);
|
||||
--checkbox-border-color-base: var(--color-border-primary);
|
||||
--checkbox-border-color-hover: var(--color-focus-primary);
|
||||
--checkbox-border-color-focus: var(--color-blue-500);
|
||||
--checkbox-background-base: var(--color-background-primary);
|
||||
--checkbox-background-hover: var(--color-background-primary);
|
||||
--checkbox-background-focus: var(--color-background-primary);
|
||||
--checkbox-background-selected: var(--color-blue-600);
|
||||
--checkbox-label-border-color-base: var(--color-border-primary);
|
||||
--checkbox-label-border-color-hover: var(--color-border-primary);
|
||||
--checkbox-label-border-color-focus: var(--color-border-secondary);
|
||||
--checkbox-label-background-base: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
|
||||
--checkbox-label-background-hover: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
|
||||
--checkbox-label-background-focus: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
|
||||
--form-seperator-color: var(--color-border-primary);
|
||||
--button-primary-border-color-base: var(--color-orange-600);
|
||||
--button-primary-border-color-hover: var(--color-orange-600);
|
||||
--button-primary-border-color-focus: var(--color-orange-600);
|
||||
--button-primary-text-color-base: white;
|
||||
--button-primary-text-color-hover: white;
|
||||
--button-primary-text-color-focus: white;
|
||||
--button-primary-background-base: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-700));
|
||||
--button-primary-background-hover: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
|
||||
--button-primary-background-focus: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
|
||||
--button-secondary-border-color-base: var(--color-grey-600);
|
||||
--button-secondary-border-color-hover: var(--color-grey-600);
|
||||
--button-secondary-border-color-focus: var(--color-grey-600);
|
||||
--button-secondary-text-color-base: white;
|
||||
--button-secondary-text-color-hover: white;
|
||||
--button-secondary-text-color-focus: white;
|
||||
--button-secondary-background-base: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-700));
|
||||
--button-secondary-background-hover: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
|
||||
--button-secondary-background-focus: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
|
||||
--button-cancel-border-color-base: var(--color-red-600);
|
||||
--button-cancel-border-color-hover: var(--color-red-600);
|
||||
--button-cancel-border-color-focus: var(--color-red-600);
|
||||
--button-cancel-text-color-base: white;
|
||||
--button-cancel-text-color-hover: white;
|
||||
--button-cancel-text-color-focus: white;
|
||||
--button-cancel-background-base: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-700));
|
||||
--button-cancel-background-focus: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
|
||||
--button-cancel-background-hover: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
|
||||
--button-plain-border-color-base: var(--color-grey-600);
|
||||
--button-plain-border-color-hover: var(--color-grey-500);
|
||||
--button-plain-border-color-focus: var(--color-grey-500);
|
||||
--button-plain-text-color-base: var(--color-text-body);
|
||||
--button-plain-text-color-hover: var(--color-text-body);
|
||||
--button-plain-text-color-focus: var(--color-text-body);
|
||||
--button-plain-background-base: var(--color-grey-700);
|
||||
--button-plain-background-hover: var(--color-grey-700);
|
||||
--button-plain-background-focus: var(--color-grey-700);
|
||||
--gallery-label-background-base: var(--color-grey-50);
|
||||
--gallery-label-background-hover: var(--color-grey-50);
|
||||
--gallery-label-border-color-base: var(--color-border-primary);
|
||||
--gallery-label-border-color-hover: var(--color-border-primary);
|
||||
--gallery-thumb-background-base: var(--color-grey-900);
|
||||
--gallery-thumb-background-hover: var(--color-grey-900);
|
||||
--gallery-thumb-border-color-base: var(--color-border-primary);
|
||||
--gallery-thumb-border-color-hover: var(--color-accent-base);
|
||||
--gallery-thumb-border-color-focus: var(--color-blue-500);
|
||||
--gallery-thumb-border-color-selected: var(--color-accent-base);
|
||||
--chatbot-border-border-color-base: transparent;
|
||||
--chatbot-border-border-color-latest: transparent;
|
||||
--chatbot-user-background-base: ;
|
||||
--chatbot-user-background-latest: ;
|
||||
--chatbot-user-text-color-base: white;
|
||||
--chatbot-user-text-color-latest: white;
|
||||
--chatbot-bot-background-base: ;
|
||||
--chatbot-bot-background-latest: ;
|
||||
--chatbot-bot-text-color-base: white;
|
||||
--chatbot-bot-text-color-latest: white;
|
||||
--label-gradient-from: var(--color-orange-400);
|
||||
--label-gradient-to: var(--color-orange-600);
|
||||
--table-odd-background: var(--color-grey-900);
|
||||
--table-even-background: var(--color-grey-950);
|
||||
--table-background-edit: transparent;
|
||||
--dataset-gallery-background-base: var(--color-background-primary);
|
||||
--dataset-gallery-background-hover: var(--color-grey-800);
|
||||
--dataset-dataframe-border-base: var(--color-border-primary);
|
||||
--dataset-dataframe-border-hover: var(--color-border-secondary);
|
||||
--dataset-table-background-base: transparent;
|
||||
--dataset-table-background-hover: var(--color-grey-700);
|
||||
--dataset-table-border-base: var(--color-grey-800);
|
||||
--dataset-table-border-hover: var(--color-grey-800);
|
||||
}
|
||||
|
||||
/* SHARK theme */
|
||||
body {
|
||||
background-color: var(--color-background-primary);
|
||||
}
|
||||
|
||||
/* display in full width for desktop devices */
|
||||
@media (min-width: 1536px)
|
||||
{
|
||||
.gradio-container {
|
||||
max-width: var(--size-full) !important;
|
||||
}
|
||||
}
|
||||
|
||||
.gradio-container .contain {
|
||||
padding: 0 var(--size-4) !important;
|
||||
}
|
||||
|
||||
.container {
|
||||
background-color: black !important;
|
||||
padding-top: var(--size-5) !important;
|
||||
}
|
||||
|
||||
#ui_title {
|
||||
padding: var(--size-2) 0 0 var(--size-1);
|
||||
}
|
||||
|
||||
#top_logo {
|
||||
background-color: transparent;
|
||||
border-radius: 0 !important;
|
||||
border: 0;
|
||||
}
|
||||
|
||||
#demo_title_outer {
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
#prompt_box_outer div:first-child {
|
||||
border-radius: 0 !important
|
||||
}
|
||||
|
||||
#prompt_box textarea, #negative_prompt_box textarea {
|
||||
background-color: var(--color-background-primary) !important;
|
||||
}
|
||||
|
||||
#prompt_examples {
|
||||
margin: 0 !important;
|
||||
}
|
||||
|
||||
#prompt_examples svg {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
#ui_body {
|
||||
background-color: var(--color-background-secondary) !important;
|
||||
padding: var(--size-2) !important;
|
||||
border-radius: 0.5em !important;
|
||||
}
|
||||
|
||||
#img_result+div {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
footer {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
#gallery + div {
|
||||
border-radius: 0 !important;
|
||||
}
|
||||
231
apps/stable_diffusion/web/ui/img2img_ui.py
Normal file
231
apps/stable_diffusion/web/ui/img2img_ui.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import img2img_inf
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
init_image = gr.Image(label="Input Image", type="filepath")
|
||||
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value="PNDM",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"DPMSolverMultistep",
|
||||
"EulerAncestralDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.strength,
|
||||
step=0.1,
|
||||
label="Strength",
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2])
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=img2img_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
init_image,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
BIN
apps/stable_diffusion/web/ui/logos/nod-logo.png
Normal file
BIN
apps/stable_diffusion/web/ui/logos/nod-logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 10 KiB |
229
apps/stable_diffusion/web/ui/txt2img_ui.py
Normal file
229
apps/stable_diffusion/web/ui/txt2img_ui.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import txt2img_inf
|
||||
from apps.stable_diffusion.src import prompt_examples, args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"KDPM2Discrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2])
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=txt2img_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
15
apps/stable_diffusion/web/ui/utils.py
Normal file
15
apps/stable_diffusion/web/ui/utils.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import os
|
||||
import sys
|
||||
from apps.stable_diffusion.src import get_available_devices
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
available_devices = get_available_devices()
|
||||
31
apps/stable_diffusion/web/utils/gradio_configs.py
Normal file
31
apps/stable_diffusion/web/utils/gradio_configs.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import os
|
||||
import tempfile
|
||||
import gradio
|
||||
from os import listdir
|
||||
|
||||
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
|
||||
|
||||
# Clear all gradio tmp images
|
||||
def clear_gradio_tmp_imgs_folder():
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
return
|
||||
for fileName in listdir(gradio_tmp_imgs_folder):
|
||||
# Delete tmp png files
|
||||
if fileName.startswith("tmp") and fileName.endswith(".png"):
|
||||
os.remove(gradio_tmp_imgs_folder + fileName)
|
||||
|
||||
|
||||
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
os.mkdir(gradio_tmp_imgs_folder)
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
|
||||
)
|
||||
pil_image.save(file_obj)
|
||||
return file_obj
|
||||
|
||||
|
||||
# Register save_pil_to_file override
|
||||
gradio.processing_utils.save_pil_to_file = save_pil_to_file
|
||||
@@ -42,7 +42,7 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=tf_bert_input)
|
||||
@tf.function(input_signature=tf_bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
51
build_tools/image_comparison.py
Normal file
51
build_tools/image_comparison.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import argparse
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
import requests
|
||||
import shutil
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("-n", "--newfile")
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--golden_url",
|
||||
default="https://storage.googleapis.com/shark_tank/testdata/cyberpunk_fores_42_0_230119_021148.png",
|
||||
)
|
||||
|
||||
|
||||
def get_image(url, local_filename):
|
||||
res = requests.get(url, stream=True)
|
||||
if res.status_code == 200:
|
||||
with open(local_filename, "wb") as f:
|
||||
shutil.copyfileobj(res.raw, f)
|
||||
|
||||
|
||||
def compare_images(new_filename, golden_filename):
|
||||
new = np.array(Image.open(new_filename)) / 255.0
|
||||
golden = np.array(Image.open(golden_filename)) / 255.0
|
||||
diff = np.abs(new - golden)
|
||||
mean = np.mean(diff)
|
||||
if mean > 0.1:
|
||||
if os.name != "nt":
|
||||
subprocess.run(
|
||||
[
|
||||
"gsutil",
|
||||
"cp",
|
||||
new_filename,
|
||||
"gs://shark_tank/testdata/builder/",
|
||||
]
|
||||
)
|
||||
raise SystemExit("new and golden not close")
|
||||
else:
|
||||
print("SUCCESS")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
tempfile_name = os.path.join(os.getcwd(), "golden.png")
|
||||
get_image(args.golden_url, tempfile_name)
|
||||
compare_images(args.newfile, tempfile_name)
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
IMPORTER=1 ./setup_venv.sh
|
||||
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
|
||||
source $GITHUB_WORKSPACE/shark.venv/bin/activate
|
||||
python generate_sharktank.py --upload=False --ci_tank_dir=True
|
||||
python generate_sharktank.py
|
||||
|
||||
143
build_tools/stable_diffusion_testing.py
Normal file
143
build_tools/stable_diffusion_testing.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import os
|
||||
from sys import executable
|
||||
import subprocess
|
||||
from apps.stable_diffusion.src.utils.resources import (
|
||||
get_json_file,
|
||||
)
|
||||
from datetime import datetime as dt
|
||||
from shark.shark_downloader import download_public_file
|
||||
from image_comparison import compare_images
|
||||
import argparse
|
||||
from glob import glob
|
||||
import shutil
|
||||
import requests
|
||||
|
||||
model_config_dicts = get_json_file(
|
||||
os.path.join(
|
||||
os.getcwd(),
|
||||
"apps/stable_diffusion/src/utils/resources/model_config.json",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_inpaint_inputs():
|
||||
os.mkdir("./test_images/inputs")
|
||||
img_url = (
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve"
|
||||
"/main/stable_diffusion_inpaint/input_bench_image.png"
|
||||
)
|
||||
mask_url = (
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve"
|
||||
"/main/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
)
|
||||
img = requests.get(img_url)
|
||||
mask = requests.get(mask_url)
|
||||
open("./test_images/inputs/image.png", "wb").write(img.content)
|
||||
open("./test_images/inputs/mask.png", "wb").write(mask.content)
|
||||
|
||||
|
||||
def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
# Get golden values from tank
|
||||
shutil.rmtree("./test_images", ignore_errors=True)
|
||||
os.mkdir("./test_images")
|
||||
os.mkdir("./test_images/golden")
|
||||
get_inpaint_inputs()
|
||||
hf_model_names = model_config_dicts[0].values()
|
||||
tuned_options = ["--no-use_tuned", "--use_tuned"]
|
||||
import_options = ["--import_mlir", "--no-import_mlir"]
|
||||
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
|
||||
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
if os.name == "nt":
|
||||
prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
|
||||
inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
|
||||
if beta:
|
||||
extra_flags.append("--beta_models=True")
|
||||
for import_opt in import_options:
|
||||
for model_name in hf_model_names:
|
||||
if model_name == "Linaqruf/anything-v3.0":
|
||||
continue
|
||||
for use_tune in tuned_options:
|
||||
command = (
|
||||
[
|
||||
executable, # executable is the python from the venv used to run this
|
||||
"apps/stable_diffusion/scripts/txt2img.py",
|
||||
"--device=" + device,
|
||||
prompt_text,
|
||||
"--negative_prompts=" + '""',
|
||||
"--seed=42",
|
||||
import_opt,
|
||||
"--output_dir="
|
||||
+ os.path.join(os.getcwd(), "test_images", model_name),
|
||||
"--hf_model_id=" + model_name,
|
||||
use_tune,
|
||||
]
|
||||
if "inpainting" not in model_name
|
||||
else [
|
||||
"python",
|
||||
"apps/stable_diffusion/scripts/inpaint.py",
|
||||
"--device=" + device,
|
||||
inpaint_prompt_text,
|
||||
"--negative_prompts=" + '""',
|
||||
"--img_path=./test_images/inputs/image.png",
|
||||
"--mask_path=./test_images/inputs/mask.png",
|
||||
"--seed=42",
|
||||
"--import_mlir",
|
||||
"--output_dir="
|
||||
+ os.path.join(os.getcwd(), "test_images", model_name),
|
||||
"--hf_model_id=" + model_name,
|
||||
use_tune,
|
||||
]
|
||||
)
|
||||
command += extra_flags
|
||||
if os.name == "nt":
|
||||
command = " ".join(command)
|
||||
generated_image = not subprocess.call(
|
||||
command, stdout=subprocess.DEVNULL
|
||||
)
|
||||
if os.name != "nt":
|
||||
command = " ".join(command)
|
||||
if generated_image:
|
||||
print(command)
|
||||
print("Successfully generated image")
|
||||
os.makedirs(
|
||||
"./test_images/golden/" + model_name, exist_ok=True
|
||||
)
|
||||
download_public_file(
|
||||
"gs://shark_tank/testdata/golden/" + model_name,
|
||||
"./test_images/golden/" + model_name,
|
||||
)
|
||||
test_file_path = os.path.join(
|
||||
os.getcwd(),
|
||||
"test_images",
|
||||
model_name,
|
||||
"generated_imgs",
|
||||
dt.now().strftime("%Y%m%d"),
|
||||
"*.png",
|
||||
)
|
||||
test_file = glob(test_file_path)[0]
|
||||
|
||||
golden_path = (
|
||||
"./test_images/golden/" + model_name + "/*.png"
|
||||
)
|
||||
golden_file = glob(golden_path)[0]
|
||||
compare_images(test_file, golden_file)
|
||||
else:
|
||||
print(command)
|
||||
print("failed to generate image for this configuration")
|
||||
if "2_1_base" in model_name:
|
||||
print("failed a known successful model.")
|
||||
exit(1)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("-d", "--device", default="vulkan")
|
||||
parser.add_argument(
|
||||
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
test_loop(args.device, args.beta, [])
|
||||
16
conftest.py
16
conftest.py
@@ -36,6 +36,12 @@ def pytest_addoption(parser):
|
||||
default="False",
|
||||
help="Enables uploading of reproduction artifacts upon test case failure during iree-compile or validation. Must be passed with --ci_sha option ",
|
||||
)
|
||||
parser.addoption(
|
||||
"--update_tank",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Update local shark tank with latest artifacts.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--ci_sha",
|
||||
action="store",
|
||||
@@ -54,3 +60,13 @@ def pytest_addoption(parser):
|
||||
default="gs://shark_tank/latest",
|
||||
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
|
||||
)
|
||||
parser.addoption(
|
||||
"--benchmark_dispatches",
|
||||
default=None,
|
||||
help="Benchmark individual dispatch kernels produced by IREE compiler. Use 'All' for all, or specific dispatches e.g. '0 1 2 10'",
|
||||
)
|
||||
parser.addoption(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="./temp_dispatch_benchmarks",
|
||||
help="Directory in which dispatch benchmarks are saved.",
|
||||
)
|
||||
|
||||
3
cpp/.gitignore
vendored
Normal file
3
cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
*.mlir
|
||||
*.vmfb
|
||||
*.ini
|
||||
@@ -54,5 +54,29 @@ python -m pip install tensorflow
|
||||
|
||||
*Run the vulkan_gui*
|
||||
```bash
|
||||
./build/vulkan_gui/iree-samples-vulkan-gui
|
||||
./build/vulkan_gui/iree-samples-resnet-vulkan-gui
|
||||
```
|
||||
|
||||
## Other models
|
||||
A tool for benchmarking other models is built and can be invoked with a command like the following
|
||||
```bash
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=path/to/.vmfb --function_input=...
|
||||
```
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
|
||||
|
||||
def load_and_preprocess_image(fname: str):
|
||||
|
||||
@@ -40,45 +40,77 @@ set(IMGUI_DIR ${CMAKE_BINARY_DIR}/_deps/imgui-src)
|
||||
message("Looking for Imgui in ${IMGUI_DIR}")
|
||||
include_directories(${IMGUI_DIR} ${IMGUI_DIR}/backends ..)
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "iree-samples-vulkan-gui")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
vulkan_inference_gui.cc
|
||||
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
|
||||
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
|
||||
"${IMGUI_DIR}/imgui.cpp"
|
||||
"${IMGUI_DIR}/imgui_draw.cpp"
|
||||
"${IMGUI_DIR}/imgui_demo.cpp"
|
||||
"${IMGUI_DIR}/imgui_tables.cpp"
|
||||
"${IMGUI_DIR}/imgui_widgets.cpp"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "iree-samples-vulkan-gui")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
SDL2::SDL2
|
||||
Vulkan::Vulkan
|
||||
iree_runtime_runtime
|
||||
iree_base_internal_main
|
||||
iree_hal_drivers_vulkan_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_vm_cc
|
||||
|
||||
function(iree_vulkan_sample)
|
||||
|
||||
cmake_parse_arguments(
|
||||
_RULE
|
||||
""
|
||||
"NAME"
|
||||
"SRCS"
|
||||
${ARGN}
|
||||
)
|
||||
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "${_RULE_NAME}")
|
||||
set(SRCS "${_RULE_SRCS}")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
${SRCS}
|
||||
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
|
||||
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
|
||||
"${IMGUI_DIR}/imgui.cpp"
|
||||
"${IMGUI_DIR}/imgui_draw.cpp"
|
||||
"${IMGUI_DIR}/imgui_demo.cpp"
|
||||
"${IMGUI_DIR}/imgui_tables.cpp"
|
||||
"${IMGUI_DIR}/imgui_widgets.cpp"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_NAME}")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
SDL2::SDL2
|
||||
Vulkan::Vulkan
|
||||
iree_runtime_runtime
|
||||
iree_base_internal_main
|
||||
iree_hal_drivers_vulkan_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_vm_cc
|
||||
iree_tooling_vm_util_cc
|
||||
iree_tooling_context_util
|
||||
)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
|
||||
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
|
||||
else()
|
||||
set(_GUI_LINKOPTS "")
|
||||
endif()
|
||||
|
||||
target_link_options(${_NAME}
|
||||
PRIVATE
|
||||
${_GUI_LINKOPTS}
|
||||
)
|
||||
endfunction()
|
||||
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-samples-resnet-vulkan-gui
|
||||
|
||||
SRCS
|
||||
vulkan_resnet_inference_gui.cc
|
||||
)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
|
||||
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
|
||||
else()
|
||||
set(_GUI_LINKOPTS "")
|
||||
endif()
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-vulkan-gui
|
||||
|
||||
target_link_options(${_NAME}
|
||||
PRIVATE
|
||||
${_GUI_LINKOPTS}
|
||||
SRCS
|
||||
vulkan_inference_gui.cc
|
||||
)
|
||||
|
||||
message(STATUS "Configured vulkan_gui sample successfully")
|
||||
|
||||
@@ -18,6 +18,12 @@
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <array>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "iree/hal/drivers/vulkan/api.h"
|
||||
|
||||
@@ -30,6 +36,15 @@
|
||||
#include "iree/vm/bytecode_module.h"
|
||||
#include "iree/vm/ref_cc.h"
|
||||
|
||||
// iree-run-module
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/base/status_cc.h"
|
||||
#include "iree/base/tracing.h"
|
||||
#include "iree/modules/hal/types.h"
|
||||
#include "iree/tooling/comparison.h"
|
||||
#include "iree/tooling/context_util.h"
|
||||
#include "iree/tooling/vm_util_cc.h"
|
||||
|
||||
// Other dependencies (helpers, etc.)
|
||||
#include "iree/base/internal/main.h"
|
||||
|
||||
@@ -38,6 +53,49 @@
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
IREE_FLAG(string, entry_function, "",
|
||||
"Name of a function contained in the module specified by module_file "
|
||||
"to run.");
|
||||
|
||||
// TODO(benvanik): move --function_input= flag into a util.
|
||||
static iree_status_t parse_function_io(iree_string_view_t flag_name,
|
||||
void* storage,
|
||||
iree_string_view_t value) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
list->push_back(std::string(value.data, value.size));
|
||||
return iree_ok_status();
|
||||
}
|
||||
static void print_function_io(iree_string_view_t flag_name, void* storage,
|
||||
FILE* file) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
if (list->empty()) {
|
||||
fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
|
||||
} else {
|
||||
for (size_t i = 0; i < list->size(); ++i) {
|
||||
fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
|
||||
list->at(i).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
static std::vector<std::string> FLAG_function_inputs;
|
||||
IREE_FLAG_CALLBACK(
|
||||
parse_function_io, print_function_io, &FLAG_function_inputs, function_input,
|
||||
"An input (a) value or (b) buffer of the format:\n"
|
||||
" (a) scalar value\n"
|
||||
" value\n"
|
||||
" e.g.: --function_input=\"3.14\"\n"
|
||||
" (b) buffer:\n"
|
||||
" [shape]xtype=[value]\n"
|
||||
" e.g.: --function_input=\"2x2xi32=1 2 3 4\"\n"
|
||||
"Optionally, brackets may be used to separate the element values:\n"
|
||||
" 2x2xi32=[[1 2][3 4]]\n"
|
||||
"Raw binary files can be read to provide buffer contents:\n"
|
||||
" 2x2xi32=@some/file.bin\n"
|
||||
"numpy npy files (from numpy.save) can be read to provide 1+ values:\n"
|
||||
" @some.npy\n"
|
||||
"Each occurrence of the flag indicates an input in the order they were\n"
|
||||
"specified on the command line.");
|
||||
|
||||
typedef struct iree_file_toc_t {
|
||||
const char* name; // the file's original name
|
||||
char* data; // beginning of the file
|
||||
@@ -87,225 +145,6 @@ static void check_vk_result(VkResult err) {
|
||||
abort();
|
||||
}
|
||||
|
||||
// Helper function to find Vulkan memory type bits. See ImGui_ImplVulkan_MemoryType() in imgui_impl_vulkan.cpp
|
||||
uint32_t findMemoryType(uint32_t type_filter, VkMemoryPropertyFlags properties)
|
||||
{
|
||||
VkPhysicalDeviceMemoryProperties mem_properties;
|
||||
vkGetPhysicalDeviceMemoryProperties(g_PhysicalDevice, &mem_properties);
|
||||
|
||||
for (uint32_t i = 0; i < mem_properties.memoryTypeCount; i++)
|
||||
{
|
||||
if ((type_filter & (1 << i)) && (mem_properties.memoryTypes[i].propertyFlags & properties) == properties)
|
||||
{
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
return 0xFFFFFFFF; // Unable to find memoryType
|
||||
}
|
||||
|
||||
// Helper function to load an image with common settings and return a VkDescriptorSet as a sort of Vulkan pointer
|
||||
bool LoadTextureFromFile(const char* filename, VkDescriptorSet* img_ds, int* image_width, int* image_height)
|
||||
{
|
||||
// Specifying 4 channels forces stb to load the image in RGBA which is an easy format for Vulkan
|
||||
int image_channels = 4;
|
||||
unsigned char* image_data = stbi_load(filename, image_width, image_height, 0, image_channels);
|
||||
|
||||
if (image_data == NULL)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Calculate allocation size (in number of bytes)
|
||||
size_t image_size = (*image_width)*(*image_height)*image_channels;
|
||||
|
||||
VkResult err;
|
||||
|
||||
// Create the Vulkan image.
|
||||
VkImage texture_image;
|
||||
VkDeviceMemory texture_image_memory;
|
||||
{
|
||||
VkImageCreateInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO;
|
||||
info.imageType = VK_IMAGE_TYPE_2D;
|
||||
info.format = VK_FORMAT_R8G8B8A8_UNORM;
|
||||
info.extent.width = *image_width;
|
||||
info.extent.height = *image_height;
|
||||
info.extent.depth = 1;
|
||||
info.mipLevels = 1;
|
||||
info.arrayLayers = 1;
|
||||
info.samples = VK_SAMPLE_COUNT_1_BIT;
|
||||
info.tiling = VK_IMAGE_TILING_OPTIMAL;
|
||||
info.usage = VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT;
|
||||
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
|
||||
info.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED;
|
||||
err = vkCreateImage(g_Device, &info, g_Allocator, &texture_image);
|
||||
check_vk_result(err);
|
||||
VkMemoryRequirements req;
|
||||
vkGetImageMemoryRequirements(g_Device, texture_image, &req);
|
||||
VkMemoryAllocateInfo alloc_info = {};
|
||||
alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
|
||||
alloc_info.allocationSize = req.size;
|
||||
alloc_info.memoryTypeIndex = findMemoryType(req.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
|
||||
err = vkAllocateMemory(g_Device, &alloc_info, g_Allocator, &texture_image_memory);
|
||||
check_vk_result(err);
|
||||
err = vkBindImageMemory(g_Device, texture_image, texture_image_memory, 0);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Create the Image View
|
||||
VkImageView image_view;
|
||||
{
|
||||
VkImageViewCreateInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO;
|
||||
info.image = texture_image;
|
||||
info.viewType = VK_IMAGE_VIEW_TYPE_2D;
|
||||
info.format = VK_FORMAT_R8G8B8A8_UNORM;
|
||||
info.subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
info.subresourceRange.levelCount = 1;
|
||||
info.subresourceRange.layerCount = 1;
|
||||
err = vkCreateImageView(g_Device, &info, g_Allocator, &image_view);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Create Sampler
|
||||
VkSampler sampler;
|
||||
{
|
||||
VkSamplerCreateInfo sampler_info{};
|
||||
sampler_info.sType = VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO;
|
||||
sampler_info.magFilter = VK_FILTER_LINEAR;
|
||||
sampler_info.minFilter = VK_FILTER_LINEAR;
|
||||
sampler_info.mipmapMode = VK_SAMPLER_MIPMAP_MODE_LINEAR;
|
||||
sampler_info.addressModeU = VK_SAMPLER_ADDRESS_MODE_REPEAT; // outside image bounds just use border color
|
||||
sampler_info.addressModeV = VK_SAMPLER_ADDRESS_MODE_REPEAT;
|
||||
sampler_info.addressModeW = VK_SAMPLER_ADDRESS_MODE_REPEAT;
|
||||
sampler_info.minLod = -1000;
|
||||
sampler_info.maxLod = 1000;
|
||||
sampler_info.maxAnisotropy = 1.0f;
|
||||
err = vkCreateSampler(g_Device, &sampler_info, g_Allocator, &sampler);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Create Descriptor Set using ImGUI's implementation
|
||||
*img_ds = ImGui_ImplVulkan_AddTexture(sampler, image_view, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL);
|
||||
|
||||
// Create Upload Buffer
|
||||
VkBuffer upload_buffer;
|
||||
VkDeviceMemory upload_buffer_memory;
|
||||
{
|
||||
VkBufferCreateInfo buffer_info = {};
|
||||
buffer_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
|
||||
buffer_info.size = image_size;
|
||||
buffer_info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
|
||||
buffer_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
|
||||
err = vkCreateBuffer(g_Device, &buffer_info, g_Allocator, &upload_buffer);
|
||||
check_vk_result(err);
|
||||
VkMemoryRequirements req;
|
||||
vkGetBufferMemoryRequirements(g_Device, upload_buffer, &req);
|
||||
VkMemoryAllocateInfo alloc_info = {};
|
||||
alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
|
||||
alloc_info.allocationSize = req.size;
|
||||
alloc_info.memoryTypeIndex = findMemoryType(req.memoryTypeBits, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
|
||||
err = vkAllocateMemory(g_Device, &alloc_info, g_Allocator, &upload_buffer_memory);
|
||||
check_vk_result(err);
|
||||
err = vkBindBufferMemory(g_Device, upload_buffer, upload_buffer_memory, 0);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Upload to Buffer:
|
||||
{
|
||||
void* map = NULL;
|
||||
err = vkMapMemory(g_Device, upload_buffer_memory, 0, image_size, 0, &map);
|
||||
check_vk_result(err);
|
||||
memcpy(map, image_data, image_size);
|
||||
VkMappedMemoryRange range[1] = {};
|
||||
range[0].sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
|
||||
range[0].memory = upload_buffer_memory;
|
||||
range[0].size = image_size;
|
||||
err = vkFlushMappedMemoryRanges(g_Device, 1, range);
|
||||
check_vk_result(err);
|
||||
vkUnmapMemory(g_Device, upload_buffer_memory);
|
||||
}
|
||||
|
||||
// Release image memory using stb
|
||||
stbi_image_free(image_data);
|
||||
|
||||
// Create a command buffer that will perform following steps when hit in the command queue.
|
||||
// TODO: this works in the example, but may need input if this is an acceptable way to access the pool/create the command buffer.
|
||||
VkCommandPool command_pool = g_MainWindowData.Frames[g_MainWindowData.FrameIndex].CommandPool;
|
||||
VkCommandBuffer command_buffer;
|
||||
{
|
||||
VkCommandBufferAllocateInfo alloc_info{};
|
||||
alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
|
||||
alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
|
||||
alloc_info.commandPool = command_pool;
|
||||
alloc_info.commandBufferCount = 1;
|
||||
|
||||
err = vkAllocateCommandBuffers(g_Device, &alloc_info, &command_buffer);
|
||||
check_vk_result(err);
|
||||
|
||||
VkCommandBufferBeginInfo begin_info = {};
|
||||
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
err = vkBeginCommandBuffer(command_buffer, &begin_info);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Copy to Image
|
||||
{
|
||||
VkImageMemoryBarrier copy_barrier[1] = {};
|
||||
copy_barrier[0].sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
|
||||
copy_barrier[0].dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
|
||||
copy_barrier[0].oldLayout = VK_IMAGE_LAYOUT_UNDEFINED;
|
||||
copy_barrier[0].newLayout = VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL;
|
||||
copy_barrier[0].srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
copy_barrier[0].dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
copy_barrier[0].image = texture_image;
|
||||
copy_barrier[0].subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
copy_barrier[0].subresourceRange.levelCount = 1;
|
||||
copy_barrier[0].subresourceRange.layerCount = 1;
|
||||
vkCmdPipelineBarrier(command_buffer, VK_PIPELINE_STAGE_HOST_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 0, NULL, 0, NULL, 1, copy_barrier);
|
||||
|
||||
VkBufferImageCopy region = {};
|
||||
region.imageSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
region.imageSubresource.layerCount = 1;
|
||||
region.imageExtent.width = *image_width;
|
||||
region.imageExtent.height = *image_height;
|
||||
region.imageExtent.depth = 1;
|
||||
vkCmdCopyBufferToImage(command_buffer, upload_buffer, texture_image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1, ®ion);
|
||||
|
||||
VkImageMemoryBarrier use_barrier[1] = {};
|
||||
use_barrier[0].sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
|
||||
use_barrier[0].srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
|
||||
use_barrier[0].dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
|
||||
use_barrier[0].oldLayout = VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL;
|
||||
use_barrier[0].newLayout = VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL;
|
||||
use_barrier[0].srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
use_barrier[0].dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
use_barrier[0].image = texture_image;
|
||||
use_barrier[0].subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
use_barrier[0].subresourceRange.levelCount = 1;
|
||||
use_barrier[0].subresourceRange.layerCount = 1;
|
||||
vkCmdPipelineBarrier(command_buffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, 0, 0, NULL, 0, NULL, 1, use_barrier);
|
||||
}
|
||||
|
||||
// End command buffer
|
||||
{
|
||||
VkSubmitInfo end_info = {};
|
||||
end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
end_info.commandBufferCount = 1;
|
||||
end_info.pCommandBuffers = &command_buffer;
|
||||
err = vkEndCommandBuffer(command_buffer);
|
||||
check_vk_result(err);
|
||||
err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
|
||||
check_vk_result(err);
|
||||
err = vkDeviceWaitIdle(g_Device);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan layers used for the given IREE
|
||||
// |extensibility_set| and |features|.
|
||||
std::vector<const char*> GetIreeLayers(
|
||||
@@ -723,7 +562,16 @@ namespace iree {
|
||||
|
||||
extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
fprintf(stdout, "starting yo\n");
|
||||
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
|
||||
if (argc > 1) {
|
||||
// Avoid iree-run-module spinning endlessly on stdin if the user uses single
|
||||
// dashes for flags.
|
||||
printf(
|
||||
"[ERROR] unexpected positional argument (expected none)."
|
||||
" Did you use pass a flag with a single dash ('-')?"
|
||||
" Use '--' instead.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Create a window.
|
||||
@@ -835,8 +683,6 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
// Demo state.
|
||||
bool show_iree_window = true;
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Setup IREE.
|
||||
|
||||
@@ -900,69 +746,44 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
|
||||
// Load bytecode module
|
||||
iree_file_toc_t module_file_toc;
|
||||
const char network_model[] = "resnet50_tf.vmfb";
|
||||
fprintf(stdout, "Loading: %s\n", network_model);
|
||||
if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
|
||||
{
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
fprintf(stdout, "module size: %zu\n", module_file_toc.size);
|
||||
|
||||
static float input_res50[224*224*3];
|
||||
static float output_res50[1000];
|
||||
|
||||
char filename[] = "dog_imagenet.jpg";
|
||||
fprintf(stdout, "loading: %s\n", filename);
|
||||
int x,y,n;
|
||||
//unsigned char *image_raw = stbi_load(filename, &x, &y, &n, 3);
|
||||
stbi_load(filename, &x, &y, &n, 3);
|
||||
fprintf(stdout, "res: %i x %i x %i\n", x, y, n);
|
||||
|
||||
/* Preprocessing needs to go here. For now use a buffer preprocessed in python.
|
||||
|
||||
//convert image into floating point format
|
||||
for(int i=0;i<224*224*3;i++)
|
||||
{
|
||||
input_res50[i]= ((float)image_raw[i])/255.0f;
|
||||
}*/
|
||||
|
||||
std::ifstream fin("dog.bin", std::ifstream::in | std::ifstream::binary);
|
||||
fin.read((char*)input_res50, 224*224*3*sizeof(float));
|
||||
|
||||
// load image again so imgui can display it
|
||||
int my_image_width = 0;
|
||||
int my_image_height = 0;
|
||||
VkDescriptorSet my_image_texture = 0;
|
||||
bool ret = LoadTextureFromFile(filename, &my_image_texture, &my_image_width, &my_image_height);
|
||||
fprintf(stdout, "creating vulkan image: %s\n", ret ?"OK":"FAIL");
|
||||
IM_ASSERT(ret);
|
||||
//iree_file_toc_t module_file_toc;
|
||||
//const char network_model[] = "resnet50_tf.vmfb";
|
||||
//fprintf(stdout, "Loading: %s\n", network_model);
|
||||
//if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
|
||||
//{
|
||||
// abort();
|
||||
// return 1;
|
||||
//}
|
||||
//fprintf(stdout, "module size: %zu\n", module_file_toc.size);
|
||||
|
||||
iree_vm_module_t* bytecode_module = nullptr;
|
||||
IREE_CHECK_OK(iree_vm_bytecode_module_create(
|
||||
iree_instance,
|
||||
iree_const_byte_span_t{
|
||||
reinterpret_cast<const uint8_t*>(module_file_toc.data),
|
||||
module_file_toc.size},
|
||||
iree_allocator_null(), iree_allocator_system(), &bytecode_module));
|
||||
// Query for details about what is in the loaded module.
|
||||
iree_vm_module_signature_t bytecode_module_signature =
|
||||
iree_vm_module_signature(bytecode_module);
|
||||
fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
|
||||
bytecode_module_signature.export_function_count);
|
||||
for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
|
||||
iree_vm_function_t function;
|
||||
IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
|
||||
bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
|
||||
auto function_name = iree_vm_function_name(&function);
|
||||
auto function_signature = iree_vm_function_signature(&function);
|
||||
iree_status_t module_status = iree_tooling_load_module_from_flags(
|
||||
iree_instance, iree_allocator_system(), &bytecode_module);
|
||||
if (!iree_status_is_ok(module_status))
|
||||
return -1;
|
||||
//IREE_CHECK_OK(iree_vm_bytecode_module_create(
|
||||
// iree_instance,
|
||||
// iree_const_byte_span_t{
|
||||
// reinterpret_cast<const uint8_t*>(module_file_toc.data),
|
||||
// module_file_toc.size},
|
||||
// iree_allocator_null(), iree_allocator_system(), &bytecode_module));
|
||||
//// Query for details about what is in the loaded module.
|
||||
//iree_vm_module_signature_t bytecode_module_signature =
|
||||
// iree_vm_module_signature(bytecode_module);
|
||||
//fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
|
||||
// bytecode_module_signature.export_function_count);
|
||||
//for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
|
||||
// iree_vm_function_t function;
|
||||
// IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
|
||||
// bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
|
||||
// auto function_name = iree_vm_function_name(&function);
|
||||
// auto function_signature = iree_vm_function_signature(&function);
|
||||
|
||||
fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
|
||||
(int)function_name.size, function_name.data,
|
||||
(int)function_signature.calling_convention.size,
|
||||
function_signature.calling_convention.data);
|
||||
}
|
||||
// fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
|
||||
// (int)function_name.size, function_name.data,
|
||||
// (int)function_signature.calling_convention.size,
|
||||
// function_signature.calling_convention.data);
|
||||
//}
|
||||
|
||||
// Allocate a context that will hold the module state across invocations.
|
||||
iree_vm_context_t* iree_context = nullptr;
|
||||
@@ -988,33 +809,42 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
// Write inputs into mappable buffers.
|
||||
iree_hal_allocator_t* allocator =
|
||||
iree_hal_device_allocator(iree_vk_device);
|
||||
iree_hal_memory_type_t input_memory_type =
|
||||
static_cast<iree_hal_memory_type_t>(
|
||||
IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
|
||||
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
|
||||
iree_hal_buffer_usage_t input_buffer_usage =
|
||||
static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
|
||||
iree_hal_buffer_params_t buffer_params;
|
||||
buffer_params.type = input_memory_type;
|
||||
buffer_params.usage = input_buffer_usage;
|
||||
buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
|
||||
//iree_hal_memory_type_t input_memory_type =
|
||||
// static_cast<iree_hal_memory_type_t>(
|
||||
// IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
|
||||
// IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
|
||||
//iree_hal_buffer_usage_t input_buffer_usage =
|
||||
// static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
|
||||
//iree_hal_buffer_params_t buffer_params;
|
||||
//buffer_params.type = input_memory_type;
|
||||
//buffer_params.usage = input_buffer_usage;
|
||||
//buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
|
||||
|
||||
// Wrap input buffers in buffer views.
|
||||
|
||||
iree_hal_buffer_view_t* input0_buffer_view = nullptr;
|
||||
constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
|
||||
IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
|
||||
allocator,
|
||||
/*shape_rank=*/4, /*shape=*/input_buffer_shape,
|
||||
IREE_HAL_ELEMENT_TYPE_FLOAT_32,
|
||||
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
|
||||
iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
|
||||
&input0_buffer_view));
|
||||
|
||||
vm::ref<iree_vm_list_t> inputs;
|
||||
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
|
||||
auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
|
||||
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
|
||||
iree_status_t input_status = ParseToVariantList(
|
||||
allocator,
|
||||
iree::span<const std::string>{FLAG_function_inputs.data(),
|
||||
FLAG_function_inputs.size()},
|
||||
iree_allocator_system(), &inputs);
|
||||
if (!iree_status_is_ok(input_status))
|
||||
return -1;
|
||||
//vm::ref<iree_vm_list_t> inputs;
|
||||
//IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
|
||||
|
||||
//iree_hal_buffer_view_t* input0_buffer_view = nullptr;
|
||||
//constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
|
||||
//IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
|
||||
// allocator,
|
||||
// /*shape_rank=*/4, /*shape=*/input_buffer_shape,
|
||||
// IREE_HAL_ELEMENT_TYPE_FLOAT_32,
|
||||
// IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
|
||||
// iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
|
||||
// &input0_buffer_view));
|
||||
|
||||
//auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
|
||||
//IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
|
||||
|
||||
// Prepare outputs list to accept results from the invocation.
|
||||
|
||||
@@ -1023,6 +853,7 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, kOutputCount * sizeof(float), iree_allocator_system(), &outputs));
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Main loop.
|
||||
bool done = false;
|
||||
while (!done) {
|
||||
@@ -1076,46 +907,11 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
/*policy=*/nullptr, inputs.get(),
|
||||
outputs.get(), iree_allocator_system()));
|
||||
|
||||
// Read back the results.
|
||||
auto* output_buffer_view = reinterpret_cast<iree_hal_buffer_view_t*>(
|
||||
iree_vm_list_get_ref_deref(outputs.get(),
|
||||
0,
|
||||
iree_hal_buffer_view_get_descriptor()));
|
||||
IREE_CHECK_OK(iree_hal_device_transfer_d2h(
|
||||
iree_vk_device,
|
||||
iree_hal_buffer_view_buffer(output_buffer_view),
|
||||
0,
|
||||
output_res50, sizeof(output_res50),
|
||||
IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
|
||||
|
||||
// we want to run continuously so we can use tools like RenderDoc, RGP, etc...
|
||||
dirty = true;
|
||||
}
|
||||
|
||||
// find maxarg from results
|
||||
float max = 0.0f;
|
||||
int max_idx = -1;
|
||||
for(int i=0;i<1000;i++)
|
||||
{
|
||||
if (output_res50[i] > max)
|
||||
{
|
||||
max = output_res50[i];
|
||||
max_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
ImGui::Text("pointer = %p", my_image_texture);
|
||||
ImGui::Text("size = %d x %d", my_image_width, my_image_height);
|
||||
ImGui::Image((ImTextureID)my_image_texture, ImVec2(my_image_width, my_image_height));
|
||||
|
||||
// Display the latest computation output.
|
||||
ImGui::Text("Max idx = [%i]", max_idx);
|
||||
ImGui::Text("Max value = [%f]", max);
|
||||
|
||||
ImGui::Text("Resnet50 categories:");
|
||||
ImGui::PlotHistogram("Histogram", output_res50, IM_ARRAYSIZE(output_res50), 0, NULL, 0.0f, 1.0f, ImVec2(0,80));
|
||||
ImGui::Separator();
|
||||
|
||||
// Framerate counter.
|
||||
ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
|
||||
1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
|
||||
@@ -1137,6 +933,7 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
iree_vm_module_release(bytecode_module);
|
||||
iree_vm_context_release(iree_context);
|
||||
iree_hal_device_release(iree_vk_device);
|
||||
iree_hal_allocator_release(allocator);
|
||||
iree_hal_driver_release(iree_vk_driver);
|
||||
iree_hal_vulkan_syms_release(iree_vk_syms);
|
||||
iree_vm_instance_release(iree_instance);
|
||||
|
||||
1160
cpp/vulkan_gui/vulkan_resnet_inference_gui.cc
Normal file
1160
cpp/vulkan_gui/vulkan_resnet_inference_gui.cc
Normal file
File diff suppressed because it is too large
Load Diff
27
dataset/README.md
Normal file
27
dataset/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# Dataset annotation tool
|
||||
|
||||
SHARK annotator for adding or modifying prompts of dataset images
|
||||
|
||||
## Set up
|
||||
|
||||
Activate SHARK Python virtual environment and install additional packages
|
||||
```shell
|
||||
source ../shark.venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Run annotator
|
||||
|
||||
```shell
|
||||
python annotation_tool.py
|
||||
```
|
||||
|
||||
<img width="1280" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214521137-7ef6ae10-7cd8-46e6-b270-b6c0445157f1.png">
|
||||
|
||||
* Select a dataset from `Dataset` dropdown list
|
||||
* Select an image from `Image` dropdown list
|
||||
* Image and the existing prompt will be loaded
|
||||
* Select a prompt from `Prompt` dropdown list to modify or "Add new" to add a prompt
|
||||
* Click `Save` to save changes, click `Delete` to delete prompt
|
||||
* Click `Back` or `Next` to switch image, you could also select other images from `Image`
|
||||
* Click `Finish` when finishing annotation or before switching dataset
|
||||
247
dataset/annotation_tool.py
Normal file
247
dataset/annotation_tool.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import gradio as gr
|
||||
import json
|
||||
import jsonlines
|
||||
import os
|
||||
from args import args
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
from utils import get_datasets
|
||||
|
||||
|
||||
shark_root = Path(__file__).parent.parent
|
||||
demo_css = shark_root.joinpath("web/demo.css").resolve()
|
||||
nodlogo_loc = shark_root.joinpath(
|
||||
"web/models/stable_diffusion/logos/nod-logo.png"
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=100)
|
||||
|
||||
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
|
||||
prompt_data = dict()
|
||||
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
# TODO: add multiselect dataset, there is a gradio version conflict
|
||||
dataset = gr.Dropdown(label="Dataset", choices=datasets)
|
||||
image_name = gr.Dropdown(label="Image", choices=[])
|
||||
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
# TODO: add ability to search image by typing
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
image = gr.Image(type="filepath").style(height=512)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
prompts = gr.Dropdown(
|
||||
label="Prompts",
|
||||
choices=[],
|
||||
)
|
||||
prompt = gr.Textbox(
|
||||
label="Editor",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Row():
|
||||
save = gr.Button("Save")
|
||||
delete = gr.Button("Delete")
|
||||
with gr.Row():
|
||||
back_image = gr.Button("Back")
|
||||
next_image = gr.Button("Next")
|
||||
finish = gr.Button("Finish")
|
||||
|
||||
def filter_datasets(dataset):
|
||||
if dataset is None:
|
||||
return gr.Dropdown.update(value=None, choices=[])
|
||||
|
||||
# create the dataset dir if doesn't exist and download prompt file
|
||||
dataset_path = str(shark_root) + "/dataset/" + dataset
|
||||
if not os.path.exists(dataset_path):
|
||||
os.mkdir(dataset_path)
|
||||
|
||||
# read prompt jsonlines file
|
||||
prompt_data.clear()
|
||||
if dataset in ds_w_prompts:
|
||||
prompt_gs_path = args.gs_url + "/" + dataset + "/metadata.jsonl"
|
||||
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
|
||||
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
|
||||
for line in reader.iter(type=dict, skip_invalid=True):
|
||||
prompt_data[line["file_name"]] = (
|
||||
[line["text"]]
|
||||
if type(line["text"]) is str
|
||||
else line["text"]
|
||||
)
|
||||
|
||||
return gr.Dropdown.update(choices=images[dataset])
|
||||
|
||||
dataset.change(fn=filter_datasets, inputs=dataset, outputs=image_name)
|
||||
|
||||
def display_image(dataset, image_name):
|
||||
if dataset is None or image_name is None:
|
||||
return gr.Image.update(value=None), gr.Dropdown.update(value=None)
|
||||
|
||||
# download and load the image
|
||||
img_gs_path = args.gs_url + "/" + dataset + "/" + image_name
|
||||
img_sub_path = "/".join(image_name.split("/")[:-1])
|
||||
img_dst_path = (
|
||||
str(shark_root) + "/dataset/" + dataset + "/" + img_sub_path + "/"
|
||||
)
|
||||
if not os.path.exists(img_dst_path):
|
||||
os.mkdir(img_dst_path)
|
||||
os.system(f'gsutil cp "{img_gs_path}" "{img_dst_path}"')
|
||||
img = Image.open(img_dst_path + image_name.split("/")[-1])
|
||||
|
||||
if image_name not in prompt_data.keys():
|
||||
prompt_data[image_name] = []
|
||||
prompt_choices = ["Add new"]
|
||||
prompt_choices += prompt_data[image_name]
|
||||
return gr.Image.update(value=img), gr.Dropdown.update(
|
||||
choices=prompt_choices
|
||||
)
|
||||
|
||||
image_name.change(
|
||||
fn=display_image,
|
||||
inputs=[dataset, image_name],
|
||||
outputs=[image, prompts],
|
||||
)
|
||||
|
||||
def edit_prompt(prompts):
|
||||
if prompts == "Add new":
|
||||
return gr.Textbox.update(value=None)
|
||||
|
||||
return gr.Textbox.update(value=prompts)
|
||||
|
||||
prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt)
|
||||
|
||||
def save_prompt(dataset, image_name, prompts, prompt):
|
||||
if (
|
||||
dataset is None
|
||||
or image_name is None
|
||||
or prompts is None
|
||||
or prompt is None
|
||||
):
|
||||
return
|
||||
|
||||
if prompts == "Add new":
|
||||
prompt_data[image_name].append(prompt)
|
||||
else:
|
||||
idx = prompt_data[image_name].index(prompts)
|
||||
prompt_data[image_name][idx] = prompt
|
||||
|
||||
prompt_path = (
|
||||
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
|
||||
)
|
||||
# write prompt jsonlines file
|
||||
with open(prompt_path, "w") as f:
|
||||
for key, value in prompt_data.items():
|
||||
if not value:
|
||||
continue
|
||||
v = value if len(value) > 1 else value[0]
|
||||
f.write(json.dumps({"file_name": key, "text": v}))
|
||||
f.write("\n")
|
||||
|
||||
prompt_choices = ["Add new"]
|
||||
prompt_choices += prompt_data[image_name]
|
||||
return gr.Dropdown.update(choices=prompt_choices, value=None)
|
||||
|
||||
save.click(
|
||||
fn=save_prompt,
|
||||
inputs=[dataset, image_name, prompts, prompt],
|
||||
outputs=prompts,
|
||||
)
|
||||
|
||||
def delete_prompt(dataset, image_name, prompts):
|
||||
if dataset is None or image_name is None or prompts is None:
|
||||
return
|
||||
if prompts == "Add new":
|
||||
return
|
||||
|
||||
prompt_data[image_name].remove(prompts)
|
||||
prompt_path = (
|
||||
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
|
||||
)
|
||||
# write prompt jsonlines file
|
||||
with open(prompt_path, "w") as f:
|
||||
for key, value in prompt_data.items():
|
||||
if not value:
|
||||
continue
|
||||
v = value if len(value) > 1 else value[0]
|
||||
f.write(json.dumps({"file_name": key, "text": v}))
|
||||
f.write("\n")
|
||||
|
||||
prompt_choices = ["Add new"]
|
||||
prompt_choices += prompt_data[image_name]
|
||||
return gr.Dropdown.update(choices=prompt_choices, value=None)
|
||||
|
||||
delete.click(
|
||||
fn=delete_prompt,
|
||||
inputs=[dataset, image_name, prompts],
|
||||
outputs=prompts,
|
||||
)
|
||||
|
||||
def get_back_image(dataset, image_name):
|
||||
if dataset is None or image_name is None:
|
||||
return
|
||||
|
||||
# remove local image
|
||||
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
|
||||
os.system(f'rm "{img_path}"')
|
||||
# get the index for the back image
|
||||
idx = images[dataset].index(image_name)
|
||||
if idx == 0:
|
||||
return gr.Dropdown.update(value=None)
|
||||
|
||||
return gr.Dropdown.update(value=images[dataset][idx - 1])
|
||||
|
||||
back_image.click(
|
||||
fn=get_back_image, inputs=[dataset, image_name], outputs=image_name
|
||||
)
|
||||
|
||||
def get_next_image(dataset, image_name):
|
||||
if dataset is None or image_name is None:
|
||||
return
|
||||
|
||||
# remove local image
|
||||
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
|
||||
os.system(f'rm "{img_path}"')
|
||||
# get the index for the next image
|
||||
idx = images[dataset].index(image_name)
|
||||
if idx == len(images[dataset]) - 1:
|
||||
return gr.Dropdown.update(value=None)
|
||||
|
||||
return gr.Dropdown.update(value=images[dataset][idx + 1])
|
||||
|
||||
next_image.click(
|
||||
fn=get_next_image, inputs=[dataset, image_name], outputs=image_name
|
||||
)
|
||||
|
||||
def finish_annotation(dataset):
|
||||
if dataset is None:
|
||||
return
|
||||
|
||||
# upload prompt and remove local data
|
||||
dataset_path = str(shark_root) + "/dataset/" + dataset
|
||||
dataset_gs_path = args.gs_url + "/" + dataset + "/"
|
||||
os.system(
|
||||
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
|
||||
)
|
||||
os.system(f'rm -rf "{dataset_path}"')
|
||||
|
||||
return gr.Dropdown.update(value=None)
|
||||
|
||||
finish.click(fn=finish_annotation, inputs=dataset, outputs=dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
shark_web.launch(
|
||||
share=args.share,
|
||||
inbrowser=True,
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
)
|
||||
34
dataset/args.py
Normal file
34
dataset/args.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Dataset Annotator flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--gs_url",
|
||||
type=str,
|
||||
required=True,
|
||||
help="URL to datasets in GS bucket",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--share",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for generating a public URL",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=8080,
|
||||
help="flag for setting server port",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
|
||||
args = p.parse_args()
|
||||
3
dataset/requirements.txt
Normal file
3
dataset/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
# SHARK Annotator
|
||||
gradio==3.15.0
|
||||
jsonlines
|
||||
29
dataset/utils.py
Normal file
29
dataset/utils.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from google.cloud import storage
|
||||
|
||||
|
||||
def get_datasets(gs_url):
|
||||
datasets = set()
|
||||
images = dict()
|
||||
ds_w_prompts = []
|
||||
|
||||
storage_client = storage.Client()
|
||||
bucket_name = gs_url.split("/")[2]
|
||||
source_blob_name = "/".join(gs_url.split("/")[3:])
|
||||
blobs = storage_client.list_blobs(bucket_name, prefix=source_blob_name)
|
||||
|
||||
for blob in blobs:
|
||||
dataset_name = blob.name.split("/")[1]
|
||||
if dataset_name == "":
|
||||
continue
|
||||
datasets.add(dataset_name)
|
||||
if dataset_name not in images.keys():
|
||||
images[dataset_name] = []
|
||||
|
||||
# check if image or jsonl
|
||||
file_sub_path = "/".join(blob.name.split("/")[2:])
|
||||
if "/" in file_sub_path:
|
||||
images[dataset_name] += [file_sub_path]
|
||||
elif "metadata.jsonl" in file_sub_path:
|
||||
ds_w_prompts.append(dataset_name)
|
||||
|
||||
return list(datasets), images, ds_w_prompts
|
||||
@@ -2,33 +2,26 @@
|
||||
"""SHARK Tank"""
|
||||
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
|
||||
# will generate local shark tank folder like this:
|
||||
# HOME
|
||||
# /.local
|
||||
# /shark_tank
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
# /SHARK
|
||||
# /gen_shark_tank
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
#
|
||||
|
||||
import os
|
||||
import csv
|
||||
import argparse
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
import tensorflow as tf
|
||||
import subprocess as sp
|
||||
import hashlib
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
visible_default = tf.config.list_physical_devices("GPU")
|
||||
try:
|
||||
tf.config.set_visible_devices([], "GPU")
|
||||
visible_devices = tf.config.get_visible_devices()
|
||||
for device in visible_devices:
|
||||
assert device.device_type != "GPU"
|
||||
except:
|
||||
# Invalid device or cannot modify virtual devices once initialized.
|
||||
pass
|
||||
from apps.stable_diffusion.src.models import (
|
||||
model_wrappers as mw,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.stable_args import (
|
||||
args,
|
||||
)
|
||||
|
||||
|
||||
def create_hash(file_name):
|
||||
@@ -41,9 +34,12 @@ def create_hash(file_name):
|
||||
|
||||
|
||||
def save_torch_model(torch_model_list):
|
||||
from tank.model_utils import get_hf_model
|
||||
from tank.model_utils import get_vision_model
|
||||
from tank.model_utils import get_hf_img_cls_model
|
||||
from tank.model_utils import (
|
||||
get_hf_model,
|
||||
get_vision_model,
|
||||
get_hf_img_cls_model,
|
||||
get_fp16_model,
|
||||
)
|
||||
|
||||
with open(torch_model_list) as csvfile:
|
||||
torch_reader = csv.reader(csvfile, delimiter=",")
|
||||
@@ -59,13 +55,39 @@ def save_torch_model(torch_model_list):
|
||||
|
||||
model = None
|
||||
input = None
|
||||
if model_type == "stable_diffusion":
|
||||
args.use_tuned = False
|
||||
args.import_mlir = True
|
||||
args.use_tuned = False
|
||||
args.local_tank_cache = WORKDIR
|
||||
|
||||
precision_values = ["fp16"]
|
||||
seq_lengths = [64, 77]
|
||||
for precision_value in precision_values:
|
||||
args.precision = precision_value
|
||||
for length in seq_lengths:
|
||||
model = mw.SharkifyStableDiffusionModel(
|
||||
model_id=torch_model_name,
|
||||
custom_weights="",
|
||||
precision=precision_value,
|
||||
max_len=length,
|
||||
width=512,
|
||||
height=512,
|
||||
use_base_vae=False,
|
||||
debug=True,
|
||||
sharktank_dir=WORKDIR,
|
||||
generate_vmfb=False,
|
||||
)
|
||||
model()
|
||||
continue
|
||||
if model_type == "vision":
|
||||
model, input, _ = get_vision_model(torch_model_name)
|
||||
elif model_type == "hf":
|
||||
model, input, _ = get_hf_model(torch_model_name)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(torch_model_name)
|
||||
|
||||
elif model_type == "fp16":
|
||||
model, input, _ = get_fp16_model(torch_model_name)
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
WORKDIR, str(torch_model_name) + "_torch"
|
||||
@@ -106,6 +128,17 @@ def save_tf_model(tf_model_list):
|
||||
get_keras_model,
|
||||
get_TFhf_model,
|
||||
)
|
||||
import tensorflow as tf
|
||||
|
||||
visible_default = tf.config.list_physical_devices("GPU")
|
||||
try:
|
||||
tf.config.set_visible_devices([], "GPU")
|
||||
visible_devices = tf.config.get_visible_devices()
|
||||
for device in visible_devices:
|
||||
assert device.device_type != "GPU"
|
||||
except:
|
||||
# Invalid device or cannot modify virtual devices once initialized.
|
||||
pass
|
||||
|
||||
with open(tf_model_list) as csvfile:
|
||||
tf_reader = csv.reader(csvfile, delimiter=",")
|
||||
@@ -129,13 +162,13 @@ def save_tf_model(tf_model_list):
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
input,
|
||||
inputs=input,
|
||||
frontend="tf",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
@@ -201,51 +234,48 @@ def is_valid_file(arg):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--torch_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/pytorch/torch_model_list.csv",
|
||||
help="""Contains the file with torch_model name and args.
|
||||
Please see: https://github.com/nod-ai/SHARK/blob/main/tank/pytorch/torch_model_list.csv""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tf_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/tf/tf_model_list.csv",
|
||||
help="Contains the file with tf model name and args.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tflite_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/tflite/tflite_model_list.csv",
|
||||
help="Contains the file with tf model name and args.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ci_tank_dir",
|
||||
type=bool,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument("--upload", type=bool, default=False)
|
||||
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument(
|
||||
# "--torch_model_csv",
|
||||
# type=lambda x: is_valid_file(x),
|
||||
# default="./tank/torch_model_list.csv",
|
||||
# help="""Contains the file with torch_model name and args.
|
||||
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--tf_model_csv",
|
||||
# type=lambda x: is_valid_file(x),
|
||||
# default="./tank/tf_model_list.csv",
|
||||
# help="Contains the file with tf model name and args.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--tflite_model_csv",
|
||||
# type=lambda x: is_valid_file(x),
|
||||
# default="./tank/tflite/tflite_model_list.csv",
|
||||
# help="Contains the file with tf model name and args.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--ci_tank_dir",
|
||||
# type=bool,
|
||||
# default=False,
|
||||
# )
|
||||
# parser.add_argument("--upload", type=bool, default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
# old_args = parser.parse_args()
|
||||
|
||||
home = str(Path.home())
|
||||
if args.ci_tank_dir == True:
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
else:
|
||||
WORKDIR = os.path.join(home, ".local/shark_tank/")
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "tf_model_list.csv"
|
||||
)
|
||||
tflite_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
|
||||
)
|
||||
|
||||
if args.torch_model_csv:
|
||||
save_torch_model(args.torch_model_csv)
|
||||
|
||||
if args.tf_model_csv:
|
||||
save_tf_model(args.tf_model_csv)
|
||||
|
||||
if args.tflite_model_csv:
|
||||
save_tflite_model(args.tflite_model_csv)
|
||||
|
||||
if args.upload:
|
||||
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
|
||||
print("uploading files to gs://shark_tank/" + git_hash)
|
||||
os.system(f"gsutil cp -r {WORKDIR}* gs://shark_tank/" + git_hash)
|
||||
save_torch_model(torch_model_csv)
|
||||
save_tf_model(tf_model_csv)
|
||||
save_tflite_model(tflite_model_csv)
|
||||
|
||||
34
process_skipfiles.py
Normal file
34
process_skipfiles.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# This script will toggle the comment/uncommenting aspect for dealing
|
||||
# with __file__ AttributeError arising in case of a few modules in
|
||||
# `torch/_dynamo/skipfiles.py` (within shark.venv)
|
||||
|
||||
from distutils.sysconfig import get_python_lib
|
||||
import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
|
||||
|
||||
modules_to_comment = ["abc,", "os,", "posixpath,", "_collections_abc,"]
|
||||
startMonitoring = 0
|
||||
for line in fileinput.input(path_to_skipfiles, inplace=True):
|
||||
if "SKIP_DIRS = " in line:
|
||||
startMonitoring = 1
|
||||
print(line, end="")
|
||||
elif startMonitoring in [1, 2]:
|
||||
if "]" in line:
|
||||
startMonitoring += 1
|
||||
print(line, end="")
|
||||
else:
|
||||
flag = True
|
||||
for module in modules_to_comment:
|
||||
if module in line:
|
||||
if not line.startswith("#"):
|
||||
print(f"#{line}", end="")
|
||||
else:
|
||||
print(f"{line[1:]}", end="")
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
print(line, end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
@@ -4,9 +4,9 @@ requires = [
|
||||
"wheel",
|
||||
"packaging",
|
||||
|
||||
"numpy==1.22.4",
|
||||
"torch-mlir>=20220428.420",
|
||||
"iree-compiler>=20220427.13",
|
||||
"iree-runtime>=20220427.13",
|
||||
"numpy>=1.22.4",
|
||||
"torch-mlir>=20221021.633",
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
[pytest]
|
||||
addopts = --verbose -p no:warnings
|
||||
norecursedirs = inference tank/tflite
|
||||
norecursedirs = inference tank/tflite examples benchmarks shark
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/
|
||||
--pre
|
||||
|
||||
numpy
|
||||
@@ -28,6 +28,7 @@ Pillow
|
||||
|
||||
# web dependecies.
|
||||
gradio
|
||||
altair
|
||||
|
||||
# Testing and support.
|
||||
#lit
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
|
||||
numpy==1.22.4
|
||||
torch
|
||||
numpy>1.22.4
|
||||
torchvision
|
||||
pytorch-triton
|
||||
tabulate
|
||||
|
||||
tqdm
|
||||
|
||||
@@ -14,7 +15,8 @@ iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow
|
||||
tf-nightly
|
||||
keras>=2.10
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
@@ -34,6 +36,7 @@ sacremoses
|
||||
|
||||
# web dependecies.
|
||||
gradio
|
||||
altair
|
||||
scipy
|
||||
|
||||
#ONNX and ORT for benchmarking
|
||||
|
||||
@@ -5,10 +5,25 @@ wheel
|
||||
tqdm
|
||||
|
||||
# SHARK Downloader
|
||||
gsutil
|
||||
google-cloud-storage
|
||||
|
||||
# Testing
|
||||
pytest
|
||||
pytest-xdist
|
||||
pytest-forked
|
||||
Pillow
|
||||
parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
transformers
|
||||
diffusers
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
altair
|
||||
omegaconf
|
||||
safetensors
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
pyinstaller
|
||||
|
||||
12
setup.py
12
setup.py
@@ -2,16 +2,17 @@ from setuptools import find_packages
|
||||
from setuptools import setup
|
||||
|
||||
import os
|
||||
import glob
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.4"
|
||||
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
|
||||
backend_deps = []
|
||||
if "NO_BACKEND" in os.environ.keys():
|
||||
backend_deps = [
|
||||
"iree-compiler>=20220427.13",
|
||||
"iree-runtime>=20220427.13",
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
|
||||
setup(
|
||||
@@ -33,11 +34,12 @@ setup(
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
packages=find_packages(exclude=("examples")),
|
||||
python_requires=">=3.7",
|
||||
python_requires=">=3.9",
|
||||
data_files=glob.glob("apps/stable_diffusion/resources/**"),
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"PyYAML",
|
||||
"torch-mlir>=20220428.420",
|
||||
"torch-mlir>=20221021.633",
|
||||
]
|
||||
+ backend_deps,
|
||||
)
|
||||
|
||||
97
setup_venv.ps1
Normal file
97
setup_venv.ps1
Normal file
@@ -0,0 +1,97 @@
|
||||
<#
|
||||
.SYNOPSIS
|
||||
A script to update and install the SHARK runtime and its dependencies.
|
||||
|
||||
.DESCRIPTION
|
||||
This script updates and installs the SHARK runtime and its dependencies.
|
||||
It checks the Python version installed and installs any required build
|
||||
dependencies into a Python virtual environment.
|
||||
If that environment does not exist, it creates it.
|
||||
|
||||
.PARAMETER update-src
|
||||
git pulls latest version
|
||||
|
||||
.PARAMETER force
|
||||
removes and recreates venv to force update of all dependencies
|
||||
|
||||
.EXAMPLE
|
||||
.\setup_venv.ps1 --force
|
||||
|
||||
.EXAMPLE
|
||||
.\setup_venv.ps1 --update-src
|
||||
|
||||
.INPUTS
|
||||
None
|
||||
|
||||
.OUTPUTS
|
||||
None
|
||||
|
||||
#>
|
||||
|
||||
param([string]$arguments)
|
||||
|
||||
if ($arguments -eq "--update-src"){
|
||||
git pull
|
||||
}
|
||||
|
||||
if ($arguments -eq "--force"){
|
||||
if (Test-Path env:VIRTUAL_ENV) {
|
||||
Write-Host "deactivating..."
|
||||
Deactivate
|
||||
}
|
||||
|
||||
if (Test-Path .\shark.venv\) {
|
||||
Write-Host "removing and recreating venv..."
|
||||
Remove-Item .\shark.venv -Force -Recurse
|
||||
if (Test-Path .\shark.venv\) {
|
||||
Write-Host 'could not remove .\shark-venv - please try running ".\setup_venv.ps1 --force" again!'
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# redirect stderr into stdout
|
||||
$p = &{python -V} 2>&1
|
||||
# check if an ErrorRecord was returned
|
||||
$version = if($p -is [System.Management.Automation.ErrorRecord])
|
||||
{
|
||||
# grab the version string from the error message
|
||||
$p.Exception.Message
|
||||
}
|
||||
else
|
||||
{
|
||||
# otherwise return complete Python list
|
||||
$ErrorActionPreference = 'SilentlyContinue'
|
||||
$PyVer = py --list
|
||||
}
|
||||
|
||||
# deactivate any activated venvs
|
||||
if ($PyVer -like "*venv*")
|
||||
{
|
||||
deactivate # make sure we don't update the wrong venv
|
||||
$PyVer = py --list # update list
|
||||
}
|
||||
|
||||
Write-Host "Python versions found are"
|
||||
Write-Host ($PyVer | Out-String) # formatted output with line breaks
|
||||
if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is unavailable
|
||||
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
|
||||
{
|
||||
Write-Host "Please install Python 3.11 and try again"
|
||||
break
|
||||
}
|
||||
|
||||
Write-Host "Installing Build Dependencies"
|
||||
# make sure we really use 3.11 from list, even if it's not the default.
|
||||
if (!($PyVer.length -ne 0)) {py -3.11 -m venv .\shark.venv\}
|
||||
else {python -m venv .\shark.venv\}
|
||||
.\shark.venv\Scripts\activate
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
Write-Host "Build and installation completed successfully"
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
@@ -42,7 +42,7 @@ Green=`tput setaf 2`
|
||||
Yellow=`tput setaf 3`
|
||||
|
||||
# Assume no binary torch-mlir.
|
||||
# Currently available for macOS m1&intel (3.10) and Linux(3.7,3.8,3.9,3.10)
|
||||
# Currently available for macOS m1&intel (3.11) and Linux(3.8,3.10,3.11)
|
||||
torch_mlir_bin=false
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "${Yellow}Apple macOS detected"
|
||||
@@ -60,12 +60,12 @@ if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
fi
|
||||
echo "${Yellow}Run the following commands to setup your SSL certs for your Python version if you see SSL errors with tests"
|
||||
echo "${Yellow}/Applications/Python\ 3.XX/Install\ Certificates.command"
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.10" ]; then
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.11" ]; then
|
||||
torch_mlir_bin=true
|
||||
fi
|
||||
elif [[ $(uname -s) = 'Linux' ]]; then
|
||||
echo "${Yellow}Linux detected"
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.7" ] || [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.9" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] ; then
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] || [ "$PYTHON_VERSION_X_Y" == "3.11" ] ; then
|
||||
torch_mlir_bin=true
|
||||
fi
|
||||
else
|
||||
@@ -76,50 +76,64 @@ fi
|
||||
$PYTHON -m pip install --upgrade pip || die "Could not upgrade pip"
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
|
||||
if [ "$torch_mlir_bin" = true ]; then
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch-mlir"
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
echo "Could not install torch-mlir" >&2
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch-mlir"
|
||||
else
|
||||
echo "Could not install torch-mlir" >&2
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
|
||||
echo "${Yello}Python 3.10 supported on macOS and 3.7,3.8,3.9 and 3.10 on Linux"
|
||||
echo "${Yello}Python 3.11 supported on macOS and 3.8,3.10 and 3.11 on Linux"
|
||||
echo "${Red}Please build torch-mlir from source in your environment"
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "${USE_IREE}" ]]; then
|
||||
RUNTIME="nod-ai/SHARK-Runtime"
|
||||
rm .use-iree
|
||||
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
|
||||
else
|
||||
RUNTIME="google/iree"
|
||||
touch ./.use-iree
|
||||
RUNTIME="https://openxla.github.io/iree/pip-release-links.html"
|
||||
fi
|
||||
if [[ -z "${NO_BACKEND}" ]]; then
|
||||
echo "Installing ${RUNTIME}..."
|
||||
$PYTHON -m pip install --find-links https://github.com/${RUNTIME}/releases iree-compiler iree-runtime
|
||||
$PYTHON -m pip install --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
|
||||
else
|
||||
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
|
||||
fi
|
||||
|
||||
if [[ ! -z "${IMPORTER}" ]]; then
|
||||
echo "${Yellow}Installing importer tools.."
|
||||
if [[ $(uname -s) = 'Linux' ]]; then
|
||||
echo "${Yellow}Linux detected.. installing Linux importer tools"
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer.txt" -f https://github.com/${RUNTIME}/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
#Always get the importer tools from upstream IREE
|
||||
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://openxla.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
elif [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "${Yellow}macOS detected.. installing macOS importer tools"
|
||||
#Conda seems to have some problems installing these packages and hope they get resolved upstream.
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer-macos.txt" -f https://github.com/${RUNTIME}/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer-macos.txt" -f ${RUNTIME} --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
fi
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://github.com/${RUNTIME}/releases
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
T_VER=$($PYTHON -m pip show torch | grep Version)
|
||||
TORCH_VERSION=${T_VER:9:17}
|
||||
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu116
|
||||
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu116."
|
||||
echo "Successfully Installed torch + cu117."
|
||||
else
|
||||
echo "Could not install torch + cu116." >&2
|
||||
echo "Could not install torch + cu117." >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torchdynamo
|
||||
import torch
|
||||
import torch_mlir
|
||||
import torch._dynamo as torchdynamo
|
||||
from shark.sharkdynamo.utils import make_shark_compiler
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,9 @@
|
||||
" from torchdynamo.optimizations.backends import create_backend\n",
|
||||
" from torchdynamo.optimizations.subgraph import SubGraph\n",
|
||||
"except ModuleNotFoundError:\n",
|
||||
" print(\"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\")\n",
|
||||
" print(\n",
|
||||
" \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
|
||||
" )\n",
|
||||
" exit()\n",
|
||||
"\n",
|
||||
"# torch-mlir imports for compiling\n",
|
||||
@@ -97,7 +99,9 @@
|
||||
"\n",
|
||||
" for node in fx_g.graph.nodes:\n",
|
||||
" if node.op == \"output\":\n",
|
||||
" assert len(node.args) == 1, \"Output node must have a single argument\"\n",
|
||||
" assert (\n",
|
||||
" len(node.args) == 1\n",
|
||||
" ), \"Output node must have a single argument\"\n",
|
||||
" node_arg = node.args[0]\n",
|
||||
" if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
|
||||
" node.args = (node_arg[0],)\n",
|
||||
@@ -116,8 +120,12 @@
|
||||
" if len(args) == 1 and isinstance(args[0], list):\n",
|
||||
" args = args[0]\n",
|
||||
"\n",
|
||||
" linalg_module = compile(ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS)\n",
|
||||
" callable, _ = get_iree_compiled_module(linalg_module, \"cuda\", func_name=\"forward\")\n",
|
||||
" linalg_module = compile(\n",
|
||||
" ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
|
||||
" )\n",
|
||||
" callable, _ = get_iree_compiled_module(\n",
|
||||
" linalg_module, \"cuda\", func_name=\"forward\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(*inputs):\n",
|
||||
" return callable(*inputs)\n",
|
||||
@@ -212,6 +220,7 @@
|
||||
" assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
|
||||
" return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@torchdynamo.optimize(\"torch_mlir\")\n",
|
||||
"def toy_example2(*args):\n",
|
||||
" a, b = args\n",
|
||||
|
||||
@@ -22,7 +22,7 @@ class CLIPModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, pixel_values=z
|
||||
)
|
||||
|
||||
@tf.function(input_signature=clip_vit_inputs)
|
||||
@tf.function(input_signature=clip_vit_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, pixel_values):
|
||||
return self.m.predict(
|
||||
input_ids, attention_mask, pixel_values
|
||||
|
||||
15
shark/examples/shark_inference/ESRGAN/README.md
Normal file
15
shark/examples/shark_inference/ESRGAN/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
## Running ESRGAN
|
||||
|
||||
```
|
||||
1. pip install numpy opencv-python
|
||||
2. mkdir InputImages
|
||||
(this is where all the input images will reside in)
|
||||
3. mkdir OutputImages
|
||||
(this is where the model will generate all the images)
|
||||
4. mkdir models
|
||||
(save the .pth checkpoint file here)
|
||||
5. python esrgan.py
|
||||
```
|
||||
|
||||
- Download [RRDB_ESRGAN_x4.pth](https://drive.google.com/drive/u/0/folders/17VYV_SoZZesU6mbxz2dMAIccSSlqLecY) and place it in the `models` directory as mentioned above in step 4.
|
||||
- Credits : [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
239
shark/examples/shark_inference/ESRGAN/esrgan.py
Normal file
239
shark/examples/shark_inference/ESRGAN/esrgan.py
Normal file
@@ -0,0 +1,239 @@
|
||||
from ast import arg
|
||||
import os.path as osp
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from shark.shark_inference import SharkInference
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
"""Residual in Residual Dense Block"""
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(
|
||||
self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
fea = self.lrelu(
|
||||
self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument(
|
||||
"--mlir_loc",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location of the model's mlir file",
|
||||
)
|
||||
args = p.parse_args()
|
||||
###################################################
|
||||
|
||||
|
||||
def inference(input_m):
|
||||
return model(input_m)
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc == None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
module = load_mlir(mlir_loc)
|
||||
if module == None:
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
print("Torchscript graph generated successfully")
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
inputs,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mlir_model = str(module)
|
||||
func_name = "forward"
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
model_path = "models/RRDB_ESRGAN_x4.pth" # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
|
||||
# device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu
|
||||
device = torch.device("cpu")
|
||||
|
||||
test_img_folder = "InputImages/*"
|
||||
|
||||
model = RRDBNet(3, 3, 64, 23, gc=32)
|
||||
model.load_state_dict(torch.load(model_path), strict=True)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
|
||||
print("Model path {:s}. \nTesting...".format(model_path))
|
||||
|
||||
if __name__ == "__main__":
|
||||
idx = 0
|
||||
for path in glob.glob(test_img_folder):
|
||||
idx += 1
|
||||
base = osp.splitext(osp.basename(path))[0]
|
||||
print(idx, base)
|
||||
# read images
|
||||
img = cv2.imread(path, cv2.IMREAD_COLOR)
|
||||
img = img * 1.0 / 255
|
||||
img = torch.from_numpy(
|
||||
np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))
|
||||
).float()
|
||||
img_LR = img.unsqueeze(0)
|
||||
img_LR = img_LR.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
shark_module = compile_through_fx(inference, img_LR)
|
||||
shark_output = shark_module.forward((img_LR,))
|
||||
shark_output = torch.from_numpy(shark_output)
|
||||
shark_output = (
|
||||
shark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
)
|
||||
esrgan_output = (
|
||||
model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
)
|
||||
# SHARK OUTPUT
|
||||
shark_output = np.transpose(shark_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
shark_output = (shark_output * 255.0).round()
|
||||
cv2.imwrite(
|
||||
"OutputImages/{:s}_rlt_shark_output.png".format(base), shark_output
|
||||
)
|
||||
print("Generated SHARK's output")
|
||||
# ESRGAN OUTPUT
|
||||
esrgan_output = np.transpose(esrgan_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
esrgan_output = (esrgan_output * 255.0).round()
|
||||
cv2.imwrite(
|
||||
"OutputImages/{:s}_rlt_esrgan_output.png".format(base),
|
||||
esrgan_output,
|
||||
)
|
||||
print("Generated ESRGAN's output")
|
||||
@@ -28,7 +28,7 @@ class AlbertModule(tf.Module):
|
||||
self.m = TFAutoModelForMaskedLM.from_pretrained("albert-base-v2")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs)
|
||||
@tf.function(input_signature=t5_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("bloom")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"bloom", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
|
||||
|
||||
@@ -19,7 +19,7 @@ class GPT2Module(tf.Module):
|
||||
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
|
||||
|
||||
@tf.function(input_signature=gpt2_inputs)
|
||||
@tf.function(input_signature=gpt2_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class BertModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased",
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class BertModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
@@ -66,10 +66,12 @@ labels = load_labels()
|
||||
|
||||
|
||||
## Can pass any img or input to the forward module.
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"resnet50", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
|
||||
# shark_module.compile()
|
||||
shark_module.compile()
|
||||
path = shark_module.save_module()
|
||||
shark_module.load_module(path)
|
||||
result = shark_module.forward((img.detach().numpy(),))
|
||||
|
||||
421
shark/examples/shark_inference/sharded_bloom.py
Normal file
421
shark/examples/shark_inference/sharded_bloom.py
Normal file
@@ -0,0 +1,421 @@
|
||||
####################################################################################
|
||||
# Please make sure you have transformers 4.21.2 installed before running this demo
|
||||
#
|
||||
# -p --model_path: the directory in which you want to store the bloom files.
|
||||
# -dl --device_list: the list of device indices you want to use. if you want to only use the first device, or you are running on cpu leave this blank.
|
||||
# Otherwise, please give this argument in this format: "[0, 1, 2]"
|
||||
# -de --device: the device you want to run bloom on. E.G. cpu, cuda
|
||||
# -c, --recompile: set to true if you want to recompile to vmfb.
|
||||
# -d, --download: set to true if you want to redownload the mlir files
|
||||
# -t --token_count: the number of tokens you want to generate
|
||||
# -pr --prompt: the prompt you want to feed to the model
|
||||
#####################################################################################
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
import torch_mlir
|
||||
from torch_mlir import TensorPlaceholder
|
||||
import re
|
||||
from transformers.models.bloom.configuration_bloom import BloomConfig
|
||||
import json
|
||||
import sys
|
||||
import argparse
|
||||
from cuda.cudart import cudaSetDevice
|
||||
import json
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
|
||||
from transformers.models.bloom.modeling_bloom import (
|
||||
BloomBlock,
|
||||
build_alibi_tensor,
|
||||
)
|
||||
|
||||
IS_CUDA = False
|
||||
|
||||
|
||||
class ShardedBloom:
|
||||
def __init__(self, src_folder):
|
||||
f = open(f"{src_folder}/config.json")
|
||||
config = json.load(f)
|
||||
f.close()
|
||||
|
||||
self.layers_initialized = False
|
||||
|
||||
self.src_folder = src_folder
|
||||
self.n_embed = config["n_embed"]
|
||||
self.vocab_size = config["vocab_size"]
|
||||
self.n_layer = config["n_layer"]
|
||||
self.n_head = config["num_attention_heads"]
|
||||
|
||||
def _init_layer(self, layer_name, device, replace, device_idx):
|
||||
if replace or not os.path.exists(
|
||||
f"{self.src_folder}/{layer_name}.vmfb"
|
||||
):
|
||||
f_ = open(f"{self.src_folder}/{layer_name}.mlir")
|
||||
module = f_.read()
|
||||
f_.close()
|
||||
module = bytes(module, "utf-8")
|
||||
shark_module = SharkInference(
|
||||
module,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
shark_module.save_module(
|
||||
module_name=f"{self.src_folder}/{layer_name}",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
"",
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
return shark_module
|
||||
|
||||
def init_layers(self, device, replace=False, device_idx=[0]):
|
||||
if device_idx is not None:
|
||||
n_devices = len(device_idx)
|
||||
|
||||
self.word_embeddings_module = self._init_layer(
|
||||
"word_embeddings",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[0 % n_devices],
|
||||
)
|
||||
self.word_embeddings_layernorm_module = self._init_layer(
|
||||
"word_embeddings_layernorm",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[1 % n_devices],
|
||||
)
|
||||
self.ln_f_module = self._init_layer(
|
||||
"ln_f",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[2 % n_devices],
|
||||
)
|
||||
self.lm_head_module = self._init_layer(
|
||||
"lm_head",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[3 % n_devices],
|
||||
)
|
||||
self.block_modules = [
|
||||
self._init_layer(
|
||||
f"bloom_block_{i}",
|
||||
device,
|
||||
replace,
|
||||
device_idx
|
||||
if device_idx is None
|
||||
else device_idx[(i + 4) % n_devices],
|
||||
)
|
||||
for i in range(self.n_layer)
|
||||
]
|
||||
|
||||
self.layers_initialized = True
|
||||
|
||||
def load_layers(self):
|
||||
assert self.layers_initialized
|
||||
|
||||
self.word_embeddings_module.load_module(
|
||||
f"{self.src_folder}/word_embeddings.vmfb"
|
||||
)
|
||||
self.word_embeddings_layernorm_module.load_module(
|
||||
f"{self.src_folder}/word_embeddings_layernorm.vmfb"
|
||||
)
|
||||
for block_module, i in zip(self.block_modules, range(self.n_layer)):
|
||||
block_module.load_module(f"{self.src_folder}/bloom_block_{i}.vmfb")
|
||||
self.ln_f_module.load_module(f"{self.src_folder}/ln_f.vmfb")
|
||||
self.lm_head_module.load_module(f"{self.src_folder}/lm_head.vmfb")
|
||||
|
||||
def forward_pass(self, input_ids, device):
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.word_embeddings_module.device_idx)
|
||||
|
||||
input_embeds = self.word_embeddings_module(
|
||||
inputs=(input_ids,), function_name="forward"
|
||||
)
|
||||
|
||||
input_embeds = torch.tensor(input_embeds).float()
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.word_embeddings_layernorm_module.device_idx)
|
||||
hidden_states = self.word_embeddings_layernorm_module(
|
||||
inputs=(input_embeds,), function_name="forward"
|
||||
)
|
||||
|
||||
hidden_states = torch.tensor(hidden_states).float()
|
||||
|
||||
attention_mask = torch.ones(
|
||||
[hidden_states.shape[0], len(input_ids[0])]
|
||||
)
|
||||
alibi = build_alibi_tensor(
|
||||
attention_mask,
|
||||
self.n_head,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
)
|
||||
|
||||
causal_mask = _prepare_attn_mask(
|
||||
attention_mask, input_ids.size(), input_embeds, 0
|
||||
)
|
||||
causal_mask = torch.tensor(causal_mask).float()
|
||||
|
||||
presents = ()
|
||||
all_hidden_states = tuple(hidden_states)
|
||||
|
||||
for block_module, i in zip(self.block_modules, range(self.n_layer)):
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(block_module.device_idx)
|
||||
|
||||
output = block_module(
|
||||
inputs=(
|
||||
hidden_states.detach().numpy(),
|
||||
alibi.detach().numpy(),
|
||||
causal_mask.detach().numpy(),
|
||||
),
|
||||
function_name="forward",
|
||||
)
|
||||
hidden_states = torch.tensor(output[0]).float()
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
presents = presents + (
|
||||
tuple(
|
||||
(
|
||||
output[1],
|
||||
output[2],
|
||||
)
|
||||
),
|
||||
)
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.ln_f_module.device_idx)
|
||||
|
||||
hidden_states = self.ln_f_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.lm_head_module.device_idx)
|
||||
|
||||
logits = self.lm_head_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
logits = torch.tensor(logits).float()
|
||||
|
||||
return torch.argmax(logits[:, -1, :], dim=-1)
|
||||
|
||||
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
past_key_values_length: int = 0,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
|
||||
mask.masked_fill_(intermediate_mask, 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
target_length, past_key_values_length, dtype=dtype
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
expanded_mask = mask[None, None, :, :].expand(
|
||||
batch_size, 1, target_length, target_length + past_key_values_length
|
||||
)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
batch_size, source_length = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else source_length
|
||||
|
||||
expanded_mask = (
|
||||
mask[:, None, None, :]
|
||||
.expand(batch_size, 1, tgt_len, source_length)
|
||||
.to(dtype)
|
||||
)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _prepare_attn_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
past_key_values_length=past_key_values_length,
|
||||
).to(attention_mask.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
|
||||
def download_560m(destination_folder):
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_0.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_1.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_2.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_3.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_4.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_5.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_6.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_7.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_8.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_9.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_10.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_11.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_12.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_13.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_14.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_15.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_16.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_17.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_18.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_19.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_20.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_21.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_22.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_23.mlir", destination_folder
|
||||
)
|
||||
download_public_file("https://bloom-560m/config.json", destination_folder)
|
||||
download_public_file("https://bloom-560m/lm_head.mlir", destination_folder)
|
||||
download_public_file("https://bloom-560m/ln_f.mlir", destination_folder)
|
||||
download_public_file(
|
||||
"https://bloom-560m/word_embeddings.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/word_embeddings_layernorm.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/tokenizer.json", destination_folder
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(prog="Bloom-560m")
|
||||
parser.add_argument("-p", "--model_path")
|
||||
parser.add_argument("-dl", "--device_list", default=None)
|
||||
parser.add_argument("-de", "--device", default="cpu")
|
||||
parser.add_argument("-c", "--recompile", default=False, type=bool)
|
||||
parser.add_argument("-d", "--download", default=False, type=bool)
|
||||
parser.add_argument("-t", "--token_count", default=10, type=int)
|
||||
parser.add_argument(
|
||||
"-pr",
|
||||
"--prompt",
|
||||
default="The SQL command to extract all the users whose name starts with A is: ",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.device_list is not None:
|
||||
args.device_list = json.loads(args.device_list)
|
||||
|
||||
if args.device == "cuda" and args.device_list is not None:
|
||||
IS_CUDA = True
|
||||
if args.download:
|
||||
download_560m(args.model_path)
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
input_ids = tokenizer.encode(args.prompt, return_tensors="pt")
|
||||
|
||||
shardedbloom = ShardedBloom(args.model_path)
|
||||
shardedbloom.init_layers(
|
||||
device=args.device, replace=args.recompile, device_idx=args.device_list
|
||||
)
|
||||
shardedbloom.load_layers()
|
||||
|
||||
for _ in range(args.token_count):
|
||||
next_token = shardedbloom.forward_pass(
|
||||
torch.tensor(input_ids), device=args.device
|
||||
)
|
||||
input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
|
||||
|
||||
print(tokenizer.decode(input_ids.squeeze()))
|
||||
@@ -151,7 +151,6 @@ class DLRM_Net(nn.Module):
|
||||
and (ln_top is not None)
|
||||
and (arch_interaction_op is not None)
|
||||
):
|
||||
|
||||
# save arguments
|
||||
self.output_d = 0
|
||||
self.arch_interaction_op = arch_interaction_op
|
||||
@@ -216,7 +215,6 @@ class DLRM_Net(nn.Module):
|
||||
return ly
|
||||
|
||||
def interact_features(self, x, ly):
|
||||
|
||||
if self.arch_interaction_op == "dot":
|
||||
# concatenate dense and sparse features
|
||||
(batch_size, d) = x.shape
|
||||
|
||||
@@ -99,7 +99,6 @@ class SparseArchShark(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, *batched_inputs):
|
||||
|
||||
concatenated_list = []
|
||||
input_enum, embedding_enum = 0, 0
|
||||
|
||||
@@ -121,7 +120,6 @@ class SparseArchShark(nn.Module):
|
||||
|
||||
|
||||
def test_sparse_arch() -> None:
|
||||
|
||||
D = 3
|
||||
eb1_config = EmbeddingBagConfig(
|
||||
name="t1",
|
||||
@@ -211,7 +209,6 @@ class DLRMShark(nn.Module):
|
||||
def forward(
|
||||
self, dense_features: torch.Tensor, *sparse_features
|
||||
) -> torch.Tensor:
|
||||
|
||||
embedded_dense = self.dense_arch(dense_features)
|
||||
embedded_sparse = self.sparse_arch(*sparse_features)
|
||||
concatenated_dense = self.inter_arch(
|
||||
|
||||
@@ -1,268 +0,0 @@
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
from tqdm.auto import tqdm
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import numpy as np
|
||||
|
||||
# pip install diffusers
|
||||
# pip install scipy
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a photograph of an astronaut riding a horse",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument("--steps", type=int, default=10, help="the device to use")
|
||||
p.add_argument("--mlir_loc", type=str, default=None, help="the device to use")
|
||||
p.add_argument("--vae_loc", type=str, default=None, help="the device to use")
|
||||
args = p.parse_args()
|
||||
|
||||
#####################################################
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc == None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
|
||||
module = load_mlir(mlir_loc)
|
||||
if mlir_loc == None:
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(*inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
inputs,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
# 1. Load the autoencoder model which will be used to decode the latents into image space.
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.vae.decode(input, return_dict=False)[0]
|
||||
|
||||
vae = VaeModel()
|
||||
vae_input = torch.rand(1, 4, 64, 64)
|
||||
shark_vae = compile_through_fx(vae, (vae_input,), args.vae_loc)
|
||||
|
||||
# Wrap the unet model to return tuples.
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
# 3. The UNet model for generating the latents.
|
||||
unet = UnetModel()
|
||||
latent_model_input = torch.rand([2, 4, 64, 64])
|
||||
text_embeddings = torch.rand([2, 77, 768])
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings),
|
||||
args.mlir_loc,
|
||||
)
|
||||
|
||||
# torch.jit.script(unet)
|
||||
|
||||
scheduler = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
prompt = [args.prompt]
|
||||
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
|
||||
num_inference_steps = args.steps # Number of denoising steps
|
||||
|
||||
guidance_scale = 7.5 # Scale for classifier-free guidance
|
||||
|
||||
generator = torch.manual_seed(
|
||||
42
|
||||
) # Seed generator to create the inital latent noise
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_embeddings = text_encoder(text_input.input_ids)[0]
|
||||
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
[""] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = text_encoder(uncond_input.input_ids)[0]
|
||||
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, unet.in_channels, height // 8, width // 8),
|
||||
generator=generator,
|
||||
)
|
||||
# latents = latents.to(torch_device)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
latents = latents * scheduler.sigmas[0]
|
||||
# print(latents, latents.shape)
|
||||
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
|
||||
print(f"i = {i} t = {t}")
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
sigma = scheduler.sigmas[i]
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
# with torch.no_grad():
|
||||
# noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
|
||||
|
||||
latent_model_input_numpy = latent_model_input.detach().numpy()
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
|
||||
noise_pred = shark_unet.forward(
|
||||
(
|
||||
latent_model_input_numpy,
|
||||
np.array([t]).astype(np.float32),
|
||||
text_embeddings_numpy,
|
||||
)
|
||||
)
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
|
||||
# print("Latents shape : ", latents.shape)
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents.detach().numpy()
|
||||
image = shark_vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
pil_images[0].save("astro.jpg")
|
||||
@@ -1,313 +0,0 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from keras_cv.models.generative.stable_diffusion.clip_tokenizer import (
|
||||
SimpleTokenizer,
|
||||
)
|
||||
from keras_cv.models.generative.stable_diffusion.constants import (
|
||||
_ALPHAS_CUMPROD,
|
||||
)
|
||||
from keras_cv.models.generative.stable_diffusion.constants import (
|
||||
_UNCONDITIONAL_TOKENS,
|
||||
)
|
||||
from keras_cv.models.generative.stable_diffusion.decoder import Decoder
|
||||
from keras_cv.models.generative.stable_diffusion.text_encoder import (
|
||||
TextEncoder,
|
||||
)
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
from PIL import Image
|
||||
|
||||
# pip install "git+https://github.com/keras-team/keras-cv.git"
|
||||
# pip install tensorflow_dataset
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a photograph of an astronaut riding a horse",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument(
|
||||
"--steps", type=int, default=10, help="the number of steps to use"
|
||||
)
|
||||
p.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="the file to save the resulting image to. (default to <input prompt>.jpg)",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
#####################################################
|
||||
|
||||
MAX_PROMPT_LENGTH = 77
|
||||
|
||||
|
||||
class SharkStableDiffusion:
|
||||
"""Shark implementation of Stable Diffusion based on model from keras_cv.
|
||||
Stable Diffusion is a powerful image generation model that can be used,
|
||||
among other things, to generate pictures according to a short text description
|
||||
(called a "prompt").
|
||||
Arguments:
|
||||
device: Device to use with SHARK. Default: cpu
|
||||
jit_compile: Whether to compile the underlying models to XLA.
|
||||
This can lead to a significant speedup on some systems. Default: False.
|
||||
References:
|
||||
- [About Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement)
|
||||
- [Original implementation](https://github.com/CompVis/stable-diffusion)
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", jit_compile=True):
|
||||
self.img_height = 512
|
||||
self.img_width = 512
|
||||
self.tokenizer = SimpleTokenizer()
|
||||
|
||||
# Create models
|
||||
self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_tf_model(
|
||||
"stable_diff", tank_url="gs://shark_tank/quinn"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
self.diffusion_model = shark_module
|
||||
self.decoder = Decoder(self.img_height, self.img_width)
|
||||
if jit_compile:
|
||||
self.text_encoder.compile(jit_compile=True)
|
||||
self.decoder.compile(jit_compile=True)
|
||||
|
||||
print(
|
||||
"By using this model checkpoint, you acknowledge that its usage is "
|
||||
"subject to the terms of the CreativeML Open RAIL-M license at "
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE"
|
||||
)
|
||||
# Load weights
|
||||
text_encoder_weights_fpath = keras.utils.get_file(
|
||||
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
|
||||
file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
|
||||
)
|
||||
decoder_weights_fpath = keras.utils.get_file(
|
||||
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
|
||||
file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
|
||||
)
|
||||
self.text_encoder.load_weights(text_encoder_weights_fpath)
|
||||
self.decoder.load_weights(decoder_weights_fpath)
|
||||
|
||||
def text_to_image(
|
||||
self,
|
||||
prompt,
|
||||
batch_size=1,
|
||||
num_steps=25,
|
||||
unconditional_guidance_scale=7.5,
|
||||
seed=None,
|
||||
):
|
||||
encoded_text = self.encode_text(prompt)
|
||||
|
||||
return self.generate_image(
|
||||
encoded_text,
|
||||
batch_size=batch_size,
|
||||
num_steps=num_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
def encode_text(self, prompt):
|
||||
"""Encodes a prompt into a latent text encoding.
|
||||
The encoding produced by this method should be used as the
|
||||
`encoded_text` parameter of `StableDiffusion.generate_image`. Encoding
|
||||
text separately from generating an image can be used to arbitrarily
|
||||
modify the text encoding priot to image generation, e.g. for walking
|
||||
between two prompts.
|
||||
Args:
|
||||
prompt: a string to encode, must be 77 tokens or shorter.
|
||||
Example:
|
||||
```python
|
||||
from keras_cv.models import StableDiffusion
|
||||
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
|
||||
encoded_text = model.encode_text("Tacos at dawn")
|
||||
img = model.generate_image(encoded_text)
|
||||
```
|
||||
"""
|
||||
# Tokenize prompt (i.e. starting context)
|
||||
inputs = self.tokenizer.encode(prompt)
|
||||
if len(inputs) > MAX_PROMPT_LENGTH:
|
||||
raise ValueError(
|
||||
f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)"
|
||||
)
|
||||
phrase = inputs + [49407] * (MAX_PROMPT_LENGTH - len(inputs))
|
||||
phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
|
||||
|
||||
context = self.text_encoder.predict_on_batch(
|
||||
[phrase, self._get_pos_ids()]
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
encoded_text,
|
||||
batch_size=1,
|
||||
num_steps=25,
|
||||
unconditional_guidance_scale=7.5,
|
||||
diffusion_noise=None,
|
||||
seed=None,
|
||||
):
|
||||
"""Generates an image based on encoded text.
|
||||
The encoding passed to this method should be derived from
|
||||
`StableDiffusion.encode_text`.
|
||||
Args:
|
||||
encoded_text: Tensor of shape (`batch_size`, 77, 768), or a Tensor
|
||||
of shape (77, 768). When the batch axis is omitted, the same encoded
|
||||
text will be used to produce every generated image.
|
||||
batch_size: number of images to generate. Default: 1.
|
||||
num_steps: number of diffusion steps (controls image quality).
|
||||
Default: 25.
|
||||
unconditional_guidance_scale: float controling how closely the image
|
||||
should adhere to the prompt. Larger values result in more
|
||||
closely adhering to the prompt, but will make the image noisier.
|
||||
Default: 7.5.
|
||||
diffusion_noise: Tensor of shape (`batch_size`, img_height // 8,
|
||||
img_width // 8, 4), or a Tensor of shape (img_height // 8,
|
||||
img_width // 8, 4). Optional custom noise to seed the diffusion
|
||||
process. When the batch axis is omitted, the same noise will be
|
||||
used to seed diffusion for every generated image.
|
||||
seed: integer which is used to seed the random generation of
|
||||
diffusion noise, only to be specified if `diffusion_noise` is
|
||||
None.
|
||||
Example:
|
||||
```python
|
||||
from keras_cv.models import StableDiffusion
|
||||
batch_size = 8
|
||||
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
|
||||
e_tacos = model.encode_text("Tacos at dawn")
|
||||
e_watermelons = model.encode_text("Watermelons at dusk")
|
||||
e_interpolated = tf.linspace(e_tacos, e_watermelons, batch_size)
|
||||
images = model.generate_image(e_interpolated, batch_size=batch_size)
|
||||
```
|
||||
"""
|
||||
if diffusion_noise is not None and seed is not None:
|
||||
raise ValueError(
|
||||
"`diffusion_noise` and `seed` should not both be passed to "
|
||||
"`generate_image`. `seed` is only used to generate diffusion "
|
||||
"noise when it's not already user-specified."
|
||||
)
|
||||
|
||||
encoded_text = tf.squeeze(encoded_text)
|
||||
if encoded_text.shape.rank == 2:
|
||||
encoded_text = tf.repeat(
|
||||
tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
|
||||
)
|
||||
|
||||
context = encoded_text
|
||||
unconditional_context = tf.repeat(
|
||||
self._get_unconditional_context(), batch_size, axis=0
|
||||
)
|
||||
context = tf.concat([context, unconditional_context], 0)
|
||||
|
||||
if diffusion_noise is not None:
|
||||
diffusion_noise = tf.squeeze(diffusion_noise)
|
||||
if diffusion_noise.shape.rank == 3:
|
||||
diffusion_noise = tf.repeat(
|
||||
tf.expand_dims(diffusion_noise, axis=0), batch_size, axis=0
|
||||
)
|
||||
latent = diffusion_noise
|
||||
else:
|
||||
latent = self._get_initial_diffusion_noise(batch_size, seed)
|
||||
|
||||
# Iterative reverse diffusion stage
|
||||
timesteps = tf.range(1, 1000, 1000 // num_steps)
|
||||
alphas, alphas_prev = self._get_initial_alphas(timesteps)
|
||||
progbar = keras.utils.Progbar(len(timesteps))
|
||||
iteration = 0
|
||||
for index, timestep in list(enumerate(timesteps))[::-1]:
|
||||
latent_prev = latent # Set aside the previous latent vector
|
||||
t_emb = self._get_timestep_embedding(timestep, batch_size)
|
||||
|
||||
# Prepare the latent and unconditional latent to be run with a single forward call
|
||||
latent = tf.concat([latent, latent], 0)
|
||||
t_emb = tf.concat([t_emb, t_emb], 0)
|
||||
latent_numpy = self.diffusion_model.forward(
|
||||
[latent.numpy(), t_emb.numpy(), context.numpy()]
|
||||
)
|
||||
latent = tf.convert_to_tensor(latent_numpy, dtype=tf.float32)
|
||||
latent, unconditional_latent = tf.split(latent, 2)
|
||||
|
||||
latent = unconditional_latent + unconditional_guidance_scale * (
|
||||
latent - unconditional_latent
|
||||
)
|
||||
a_t, a_prev = alphas[index], alphas_prev[index]
|
||||
pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(
|
||||
a_t
|
||||
)
|
||||
latent = (
|
||||
latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
|
||||
)
|
||||
iteration += 1
|
||||
progbar.update(iteration)
|
||||
|
||||
# Decoding stage
|
||||
decoded = self.decoder.predict_on_batch(latent)
|
||||
decoded = ((decoded + 1) / 2) * 255
|
||||
return np.clip(decoded, 0, 255).astype("uint8")
|
||||
|
||||
def _get_unconditional_context(self):
|
||||
unconditional_tokens = tf.convert_to_tensor(
|
||||
[_UNCONDITIONAL_TOKENS], dtype=tf.int32
|
||||
)
|
||||
unconditional_context = self.text_encoder.predict_on_batch(
|
||||
[unconditional_tokens, self._get_pos_ids()]
|
||||
)
|
||||
|
||||
return unconditional_context
|
||||
|
||||
def _get_timestep_embedding(
|
||||
self, timestep, batch_size, dim=320, max_period=10000
|
||||
):
|
||||
half = dim // 2
|
||||
freqs = tf.math.exp(
|
||||
-math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
|
||||
)
|
||||
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
|
||||
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
|
||||
embedding = tf.reshape(embedding, [1, -1])
|
||||
return tf.repeat(embedding, batch_size, axis=0)
|
||||
|
||||
def _get_initial_alphas(self, timesteps):
|
||||
alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
|
||||
alphas_prev = [1.0] + alphas[:-1]
|
||||
|
||||
return alphas, alphas_prev
|
||||
|
||||
def _get_initial_diffusion_noise(self, batch_size, seed):
|
||||
return tf.random.normal(
|
||||
(batch_size, self.img_height // 8, self.img_width // 8, 4),
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_pos_ids():
|
||||
return tf.convert_to_tensor(
|
||||
[list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SD = SharkStableDiffusion(device=args.device)
|
||||
images = SD.text_to_image(args.prompt, num_steps=args.steps)
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
save_fname = args.prompt + ".jpg"
|
||||
if args.save_path is not None:
|
||||
save_fname = args.save_path
|
||||
pil_images[0].save(save_fname)
|
||||
@@ -18,7 +18,7 @@ class T5Module(tf.Module):
|
||||
self.m = TFT5Model.from_pretrained("t5-small")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, decoder_input_ids=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs)
|
||||
@tf.function(input_signature=t5_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, decoder_input_ids):
|
||||
return self.m.predict(input_ids, decoder_input_ids)
|
||||
|
||||
|
||||
21
shark/examples/shark_inference/upscaler/main.py
Normal file
21
shark/examples/shark_inference/upscaler/main.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from pipeline_shark_stable_diffusion_upscale import (
|
||||
SharkStableDiffusionUpscalePipeline,
|
||||
)
|
||||
import torch
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipeline = SharkStableDiffusionUpscalePipeline(model_id)
|
||||
|
||||
# let's download an image
|
||||
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
|
||||
response = requests.get(url)
|
||||
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
low_res_img = low_res_img.resize((128, 128))
|
||||
|
||||
prompt = "a white cat"
|
||||
|
||||
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
|
||||
upscaled_image.save("upsampled_cat.png")
|
||||
98
shark/examples/shark_inference/upscaler/model_wrappers.py
Normal file
98
shark/examples/shark_inference/upscaler/model_wrappers.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from utils import compile_through_fx
|
||||
import torch
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
|
||||
model_input = {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 128, 128),),
|
||||
"unet": (
|
||||
torch.randn(2, 7, 128, 128), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, 77, 1024), # embedding
|
||||
torch.randn(2).to(torch.int64), # noise_level
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
)
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
model_input["clip"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
|
||||
def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
return x
|
||||
|
||||
vae = VaeModel()
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
model_input["vae"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, latent, timestep, text_embedding, noise_level):
|
||||
unet_out = self.unet.forward(
|
||||
latent,
|
||||
timestep,
|
||||
text_embedding,
|
||||
noise_level,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
return unet_out
|
||||
|
||||
unet = UnetModel()
|
||||
f16_input_mask = (True, True, True, False)
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
model_input["unet"],
|
||||
model_name=model_name,
|
||||
is_f16=True,
|
||||
f16_input_mask=f16_input_mask,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_unet
|
||||
48
shark/examples/shark_inference/upscaler/opt_params.py
Normal file
48
shark/examples/shark_inference/upscaler/opt_params.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import sys
|
||||
from model_wrappers import (
|
||||
get_vae_mlir,
|
||||
get_unet_mlir,
|
||||
get_clip_mlir,
|
||||
)
|
||||
from upscaler_args import args
|
||||
from utils import get_shark_model
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
if BATCH_SIZE != 1:
|
||||
sys.exit("Only batch size 1 is supported.")
|
||||
|
||||
|
||||
unet_flag = [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
]
|
||||
|
||||
vae_flag = [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
|
||||
clip_flag = [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
|
||||
bucket = "gs://shark_tank/stable_diffusion/"
|
||||
|
||||
|
||||
def get_unet():
|
||||
model_name = "upscaler_unet"
|
||||
if args.import_mlir:
|
||||
return get_unet_mlir(model_name, unet_flag)
|
||||
return get_shark_model(bucket, model_name, unet_flag)
|
||||
|
||||
|
||||
def get_vae():
|
||||
model_name = "upscaler_vae"
|
||||
if args.import_mlir:
|
||||
return get_vae_mlir(model_name, vae_flag)
|
||||
return get_shark_model(bucket, model_name, vae_flag)
|
||||
|
||||
|
||||
def get_clip():
|
||||
model_name = "upscaler_clip"
|
||||
if args.import_mlir:
|
||||
return get_clip_mlir(model_name, clip_flag)
|
||||
return get_shark_model(bucket, model_name, clip_flag)
|
||||
@@ -0,0 +1,489 @@
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers import logging
|
||||
from diffusers.pipeline_utils import ImagePipelineOutput
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
if isinstance(image, torch.Tensor):
|
||||
return image
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
w, h = image[0].size
|
||||
w, h = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
|
||||
image = [np.array(i.resize((w, h)))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = 2.0 * image - 1.0
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
return image
|
||||
|
||||
|
||||
def shark_run_wrapper(model, *args):
|
||||
np_inputs = tuple([x.detach().numpy() for x in args])
|
||||
outputs = model("forward", np_inputs)
|
||||
return torch.from_numpy(outputs)
|
||||
|
||||
|
||||
class SharkStableDiffusionUpscalePipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_id,
|
||||
):
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
||||
model_id, subfolder="tokenizer"
|
||||
)
|
||||
self.low_res_scheduler = DDPMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
self.scheduler = DDIMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
self.vae = get_vae()
|
||||
self.unet = get_unet()
|
||||
self.text_encoder = get_clip()
|
||||
self.max_noise_level = (350,)
|
||||
self._execution_device = "cpu"
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
Args:
|
||||
prompt (`str` or `list(int)`):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
"""
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(
|
||||
prompt, padding="longest", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
||||
-1
|
||||
] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
# if (
|
||||
# hasattr(self.text_encoder.config, "use_attention_mask")
|
||||
# and self.text_encoder.config.use_attention_mask
|
||||
# ):
|
||||
# attention_mask = text_inputs.attention_mask.to(device)
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
text_embeddings = shark_run_wrapper(
|
||||
self.text_encoder, text_input_ids.to(device)
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(
|
||||
bs_embed * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# if (
|
||||
# hasattr(self.text_encoder.config, "use_attention_mask")
|
||||
# and self.text_encoder.config.use_attention_mask
|
||||
# ):
|
||||
# attention_mask = uncond_input.attention_mask.to(device)
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
uncond_embeddings = shark_run_wrapper(
|
||||
self.text_encoder,
|
||||
uncond_input.input_ids.to(device),
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(
|
||||
1, num_images_per_prompt, 1
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
return text_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.08333 * latents
|
||||
image = shark_run_wrapper(self.vae, latents)
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def check_inputs(self, prompt, image, noise_level, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(
|
||||
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
||||
)
|
||||
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor
|
||||
if isinstance(image, list) or isinstance(image, torch.Tensor):
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = len(prompt)
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
image_batch_size = image.shape[0]
|
||||
if batch_size != image_batch_size:
|
||||
raise ValueError(
|
||||
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
|
||||
" Please make sure that passed `prompt` matches the batch size of `image`."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images):
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
if images.shape[-1] == 1:
|
||||
# special case for grayscale (single channel) images
|
||||
pil_images = [
|
||||
Image.fromarray(image.squeeze(), mode="L") for image in images
|
||||
]
|
||||
else:
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
if latents is None:
|
||||
if device == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(
|
||||
shape, generator=generator, device="cpu", dtype=dtype
|
||||
).to(device)
|
||||
else:
|
||||
latents = torch.randn(
|
||||
shape, generator=generator, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(
|
||||
f"Unexpected latents shape, got {latents.shape}, expected {shape}"
|
||||
)
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[
|
||||
torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]
|
||||
],
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
noise_level: int = 20,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[
|
||||
Union[torch.Generator, List[torch.Generator]]
|
||||
] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[
|
||||
Callable[[int, int, torch.FloatTensor], None]
|
||||
] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
):
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, image, noise_level, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
)
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image)
|
||||
image = image.to(dtype=text_embeddings.dtype, device=device)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor(
|
||||
[noise_level], dtype=torch.long, device=device
|
||||
)
|
||||
if device == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(
|
||||
image.shape,
|
||||
generator=generator,
|
||||
device="cpu",
|
||||
dtype=text_embeddings.dtype,
|
||||
).to(device)
|
||||
else:
|
||||
noise = torch.randn(
|
||||
image.shape,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=text_embeddings.dtype,
|
||||
)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
|
||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
|
||||
noise_level = torch.cat([noise_level] * image.shape[0])
|
||||
|
||||
# 6. Prepare latent variables
|
||||
height, width = image.shape[2:]
|
||||
# num_channels_latents = self.vae.config.latent_channels
|
||||
num_channels_latents = 4
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Check that sizes of image and latents match
|
||||
num_channels_image = image.shape[1]
|
||||
# if (
|
||||
# num_channels_latents + num_channels_image
|
||||
# != self.unet.config.in_channels
|
||||
# ):
|
||||
# raise ValueError(
|
||||
# f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
# f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
# f" `num_channels_image`: {num_channels_image} "
|
||||
# f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
||||
# " `pipeline.unet` or your `image` input."
|
||||
# )
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = (
|
||||
len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
)
|
||||
for i, t in tqdm(enumerate(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * 2)
|
||||
if do_classifier_free_guidance
|
||||
else latents
|
||||
)
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
latent_model_input = torch.cat([latent_model_input, image], dim=1)
|
||||
|
||||
timestep = torch.tensor([t]).to(torch.float32)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = shark_run_wrapper(
|
||||
self.unet,
|
||||
latent_model_input.half(),
|
||||
timestep,
|
||||
text_embeddings.half(),
|
||||
noise_level,
|
||||
)
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# # call the callback, if provided
|
||||
# if i == len(timesteps) - 1 or (
|
||||
# (i + 1) > num_warmup_steps
|
||||
# and (i + 1) % self.scheduler.order == 0
|
||||
# ):
|
||||
# progress_bar.update()
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
# 10. Post-processing
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
# self.vae.to(dtype=torch.float32)
|
||||
image = self.decode_latents(latents.float())
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
111
shark/examples/shark_inference/upscaler/upscaler_args.py
Normal file
111
shark/examples/shark_inference/upscaler/upscaler_args.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=["cyberpunk forest by Salvador Dali"],
|
||||
help="text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative-prompts",
|
||||
nargs="+",
|
||||
default=[""],
|
||||
help="text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="the no. of steps to do the sampling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed to use.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="the value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Model Config and Usage Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="vulkan", help="device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp16", help="precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="saves the compiled flatbuffer to the local directory",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree-vulkan-target-triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="4147483648",
|
||||
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for disabling vulkan validation layers when benchmarking",
|
||||
)
|
||||
|
||||
|
||||
args = p.parse_args()
|
||||
232
shark/examples/shark_inference/upscaler/utils.py
Normal file
232
shark/examples/shark_inference/upscaler/utils.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import os
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from upscaler_args import args
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
else:
|
||||
if args.save_vmfb:
|
||||
print("Saving to {}".format(vmfb_path))
|
||||
else:
|
||||
print(
|
||||
"No vmfb found. Compiling and saving to {}".format(
|
||||
vmfb_path
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.parser import shark_args
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
# shark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=tank_url,
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into a shark_module.
|
||||
def compile_through_fx(
|
||||
model, inputs, model_name, is_f16=False, f16_input_mask=None, extra_args=[]
|
||||
):
|
||||
mlir_module, func_name = import_with_fx(
|
||||
model, inputs, is_f16, f16_input_mask
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
|
||||
]
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
"""
|
||||
Inputs: driver_name
|
||||
Returns a list of all the available devices for a given driver sorted by
|
||||
the iree path names of the device as in --list_devices option in iree.
|
||||
"""
|
||||
from iree.runtime import get_driver
|
||||
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return device_list_src
|
||||
|
||||
|
||||
def get_device_mapping(driver, key_combination=3):
|
||||
"""This method ensures consistent device ordering when choosing
|
||||
specific devices for execution
|
||||
Args:
|
||||
driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Returns:
|
||||
dict: map to possible device names user can input mapped to desired combination of name/path.
|
||||
"""
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
driver = iree_device_map(driver)
|
||||
device_list = get_all_devices(driver)
|
||||
device_map = dict()
|
||||
|
||||
def get_output_value(dev_dict):
|
||||
if key_combination == 1:
|
||||
return f"{driver}://{dev_dict['path']}"
|
||||
if key_combination == 2:
|
||||
return dev_dict["name"]
|
||||
if key_combination == 3:
|
||||
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
|
||||
|
||||
# mapping driver name to default device (driver://0)
|
||||
device_map[f"{driver}"] = get_output_value(device_list[0])
|
||||
for i, device in enumerate(device_list):
|
||||
# mapping with index
|
||||
device_map[f"{driver}://{i}"] = get_output_value(device)
|
||||
# mapping with full path
|
||||
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
|
||||
return device_map
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user selected execution device
|
||||
Args:
|
||||
device (str): user
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Raises:
|
||||
ValueError:
|
||||
Returns:
|
||||
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
|
||||
"""
|
||||
driver = device.split("://")[0]
|
||||
device_map = get_device_mapping(driver, key_combination)
|
||||
try:
|
||||
device_mapping = device_map[device]
|
||||
except KeyError:
|
||||
raise ValueError(f"Device '{device}' is not a valid device.")
|
||||
return device_mapping
|
||||
|
||||
|
||||
def set_init_device_flags():
|
||||
if "vulkan" in args.device:
|
||||
# set runtime flags for vulkan.
|
||||
set_iree_runtime_flags()
|
||||
|
||||
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
|
||||
device_name, args.device = map_device_to_name_path(args.device)
|
||||
if not args.iree_vulkan_target_triple:
|
||||
triple = get_vulkan_target_triple(device_name)
|
||||
if triple is not None:
|
||||
args.iree_vulkan_target_triple = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
|
||||
)
|
||||
elif "cuda" in args.device:
|
||||
args.device = "cuda"
|
||||
elif "cpu" in args.device:
|
||||
args.device = "cpu"
|
||||
|
||||
# set max_length based on availability.
|
||||
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
args.max_length = 77
|
||||
elif args.variant == "openjourney":
|
||||
args.max_length = 64
|
||||
|
||||
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
|
||||
if (
|
||||
args.variant in ["openjourney", "dreamlike"]
|
||||
or args.precision != "fp16"
|
||||
or "vulkan" not in args.device
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
):
|
||||
args.use_tuned = False
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
elif args.use_base_vae and args.variant != "stablediffusion":
|
||||
args.use_tuned = False
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
if args.use_tuned:
|
||||
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
|
||||
|
||||
|
||||
# Utility to get list of devices available.
|
||||
def get_available_devices():
|
||||
def get_devices_by_name(driver_name):
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
driver_name = iree_device_map(driver_name)
|
||||
device_list_dict = get_all_devices(driver_name)
|
||||
print(f"{driver_name} devices are available.")
|
||||
except:
|
||||
print(f"{driver_name} devices are not available.")
|
||||
else:
|
||||
for i, device in enumerate(device_list_dict):
|
||||
device_list.append(f"{driver_name}://{i} => {device['name']}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
vulkan_devices = get_devices_by_name("vulkan")
|
||||
available_devices.extend(vulkan_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
available_devices.append("cpu")
|
||||
return available_devices
|
||||
@@ -1,8 +1,10 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"v_diffusion", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch.nn.utils import _stateless
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_runner import SharkTrainer
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
@@ -42,6 +42,7 @@ def forward(params, buffers, args):
|
||||
return params, buffers
|
||||
|
||||
|
||||
shark_module = SharkTrainer(mod, inp, custom_inference_fn=forward)
|
||||
shark_module = SharkTrainer(mod, inp)
|
||||
shark_module.compile(forward)
|
||||
|
||||
print(shark_module.forward())
|
||||
print(shark_module.train())
|
||||
|
||||
@@ -52,7 +52,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user