mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 07:28:06 -05:00
Compare commits
889 Commits
v4.2.9.dev
...
v5.0.0.a5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
865f419ffe | ||
|
|
dc2ed2f1c1 | ||
|
|
818a7a01cc | ||
|
|
218b216ebd | ||
|
|
eee3b7acf3 | ||
|
|
0a923cc77b | ||
|
|
07df5c2d39 | ||
|
|
153533157f | ||
|
|
3898e09b8c | ||
|
|
277498bceb | ||
|
|
08078440cb | ||
|
|
73ea5cb42a | ||
|
|
3aaaae4d1c | ||
|
|
fbe02e3d1d | ||
|
|
790195854e | ||
|
|
a1179eb227 | ||
|
|
e5792278b9 | ||
|
|
452e612235 | ||
|
|
241fc18d69 | ||
|
|
eafd3d1ec7 | ||
|
|
8506d98f34 | ||
|
|
cb53772722 | ||
|
|
5a8ce724db | ||
|
|
867026f11f | ||
|
|
a802b92684 | ||
|
|
fd19b0691a | ||
|
|
91a08086c0 | ||
|
|
afe3942024 | ||
|
|
2b3e4e123d | ||
|
|
2622f7dc02 | ||
|
|
1a598873de | ||
|
|
12cab9fc31 | ||
|
|
6fb8e45761 | ||
|
|
637960d67e | ||
|
|
d2ab668fa0 | ||
|
|
82df16d8ce | ||
|
|
dd3013d333 | ||
|
|
269db8ae19 | ||
|
|
30ea852761 | ||
|
|
c03f80b19c | ||
|
|
96930055e2 | ||
|
|
5fa7f0154f | ||
|
|
ab0e9dfcad | ||
|
|
88dcb388dc | ||
|
|
5a89bf841f | ||
|
|
5b8707a74f | ||
|
|
cfb538bdc2 | ||
|
|
9f06a9b03c | ||
|
|
561db0751b | ||
|
|
248e4a81b2 | ||
|
|
b6aba92426 | ||
|
|
7d15f9381d | ||
|
|
4f2fc65257 | ||
|
|
68237d357a | ||
|
|
bb2db3d6c3 | ||
|
|
ff94146ee8 | ||
|
|
1d09091a67 | ||
|
|
ee4c0efbf7 | ||
|
|
a4250e3ff2 | ||
|
|
67a234c1bb | ||
|
|
420045cb34 | ||
|
|
53792fafb3 | ||
|
|
615eddea6f | ||
|
|
b3d60bd56a | ||
|
|
fd42da5a36 | ||
|
|
bc55791db1 | ||
|
|
c5f3297841 | ||
|
|
cd2c2a7fde | ||
|
|
1cffcc02a5 | ||
|
|
ac9950bdbb | ||
|
|
059d57f447 | ||
|
|
581008b432 | ||
|
|
aeaeec9b9d | ||
|
|
301739c4a8 | ||
|
|
a2e2a31b95 | ||
|
|
88c276cd09 | ||
|
|
457871af93 | ||
|
|
e88d4aa0e8 | ||
|
|
c8a74f969b | ||
|
|
4240817128 | ||
|
|
80877a1f15 | ||
|
|
7fc25e7e01 | ||
|
|
9a355c5585 | ||
|
|
2975ec5467 | ||
|
|
8ab3b938c1 | ||
|
|
f82640b5df | ||
|
|
e3e50abc5a | ||
|
|
061bff2814 | ||
|
|
e5a53be42b | ||
|
|
54c94bd713 | ||
|
|
8d56becf04 | ||
|
|
dc51ccd9a6 | ||
|
|
f5eefedc49 | ||
|
|
136891ec3d | ||
|
|
c5543e42c7 | ||
|
|
edae8a1617 | ||
|
|
9c1cf3e860 | ||
|
|
b6cef9d440 | ||
|
|
ebb92bee26 | ||
|
|
d6c553ca5e | ||
|
|
8b6512cc90 | ||
|
|
a6b998c125 | ||
|
|
5275782533 | ||
|
|
ede3bd8e64 | ||
|
|
da2583b894 | ||
|
|
9210970130 | ||
|
|
2a022a811c | ||
|
|
1a53e8dc5c | ||
|
|
4e12e23b69 | ||
|
|
fd56b35982 | ||
|
|
71e0abe653 | ||
|
|
56956ccf78 | ||
|
|
6d46d82028 | ||
|
|
3ed29a16a8 | ||
|
|
b67c369bdb | ||
|
|
e774b6879e | ||
|
|
e7d95c3724 | ||
|
|
1b65884dbe | ||
|
|
eff9ddc980 | ||
|
|
400ef8cdc3 | ||
|
|
b0ec3de40a | ||
|
|
b38b8bc90c | ||
|
|
a5ab5e5146 | ||
|
|
61fc30b345 | ||
|
|
46d0ba8ce2 | ||
|
|
5a3e0d76d9 | ||
|
|
5eb919f602 | ||
|
|
2301b388e8 | ||
|
|
dbf13999a0 | ||
|
|
a37592f9f3 | ||
|
|
60d4514fd8 | ||
|
|
9709da901c | ||
|
|
44df59e9e9 | ||
|
|
fbe80ceab2 | ||
|
|
a86822db4d | ||
|
|
f024cb1d05 | ||
|
|
6b2d900b54 | ||
|
|
3d6d5affb5 | ||
|
|
99b683fc1f | ||
|
|
d5cd50c3ea | ||
|
|
d7cde0fc23 | ||
|
|
541605edb4 | ||
|
|
0194344de2 | ||
|
|
34f3cb3116 | ||
|
|
5ab4818eb6 | ||
|
|
60d2541934 | ||
|
|
8d87549ebe | ||
|
|
4cb5854990 | ||
|
|
6f4d3d0395 | ||
|
|
93e9e64b3a | ||
|
|
2bdfc340aa | ||
|
|
2a1bc3e044 | ||
|
|
b4d006d14b | ||
|
|
464603e0ea | ||
|
|
864e471e5a | ||
|
|
670e054fe0 | ||
|
|
0abd81ac80 | ||
|
|
1870daffa1 | ||
|
|
d6d27a82a6 | ||
|
|
ff0d2fcc92 | ||
|
|
a2969816fa | ||
|
|
6b20d1564d | ||
|
|
bf484bc90e | ||
|
|
fc58d34d25 | ||
|
|
c15793b794 | ||
|
|
1e32be827e | ||
|
|
8422908b70 | ||
|
|
d10ff59f9c | ||
|
|
eab1f50a6f | ||
|
|
6e346884e3 | ||
|
|
1c9fd1f19a | ||
|
|
28385d06d1 | ||
|
|
12e6f1be89 | ||
|
|
e1a66e22e9 | ||
|
|
b3569e5c0d | ||
|
|
c64693fffd | ||
|
|
ce9f17726f | ||
|
|
5f62dc6699 | ||
|
|
07cb12eef7 | ||
|
|
9e9f465552 | ||
|
|
e148cc810b | ||
|
|
160f54d1ea | ||
|
|
480856a528 | ||
|
|
97aad2ab2f | ||
|
|
2b93dbd96a | ||
|
|
ce4c79a8d9 | ||
|
|
151b4efd3f | ||
|
|
16806e5d8d | ||
|
|
8e01d295db | ||
|
|
fd00e40ca7 | ||
|
|
029158ef3a | ||
|
|
96b74f4a79 | ||
|
|
b1e85f8b60 | ||
|
|
aa418f0aba | ||
|
|
8b747b022b | ||
|
|
ed4b5dfac3 | ||
|
|
b189937bc9 | ||
|
|
e176e48fa3 | ||
|
|
4931bdace5 | ||
|
|
c3b52a1853 | ||
|
|
b201541cb0 | ||
|
|
ba54a05efd | ||
|
|
6746870591 | ||
|
|
542844c6a3 | ||
|
|
4e5f4dadf2 | ||
|
|
1c15c2cb03 | ||
|
|
a041f1f388 | ||
|
|
d0b62c88c9 | ||
|
|
0fd4dd4513 | ||
|
|
4d3ed34232 | ||
|
|
74de22349d | ||
|
|
18ad271225 | ||
|
|
f92730080c | ||
|
|
f83b500645 | ||
|
|
1349e73a1a | ||
|
|
1fdb702557 | ||
|
|
4df531b7c0 | ||
|
|
a5a077964e | ||
|
|
944719cb9c | ||
|
|
92ae679314 | ||
|
|
771c3210b7 | ||
|
|
517946f66e | ||
|
|
eb09253b4e | ||
|
|
d81cd050ef | ||
|
|
ae5ed18f12 | ||
|
|
9026180533 | ||
|
|
437ea1109b | ||
|
|
95177a7389 | ||
|
|
d01af064f9 | ||
|
|
d50ee14d0b | ||
|
|
096e8deac5 | ||
|
|
e3b6ad7076 | ||
|
|
23c93509e0 | ||
|
|
f5eb6a06b5 | ||
|
|
db99b773bc | ||
|
|
daa0064947 | ||
|
|
ea062ab01a | ||
|
|
0c81a435f4 | ||
|
|
be7254dbf8 | ||
|
|
f49cee976d | ||
|
|
c246fc98b3 | ||
|
|
45e155d392 | ||
|
|
c82e17916f | ||
|
|
d9359bac23 | ||
|
|
ae65f89999 | ||
|
|
dd8b25260d | ||
|
|
4f76f5f848 | ||
|
|
3cdc5d869f | ||
|
|
19aa747b8f | ||
|
|
e20ae31d96 | ||
|
|
09fd415527 | ||
|
|
50768a957e | ||
|
|
3942e2a501 | ||
|
|
1a51842277 | ||
|
|
d001a36e14 | ||
|
|
8c65f60e7d | ||
|
|
d48ce8168e | ||
|
|
a955ab6bee | ||
|
|
81bfd4cc08 | ||
|
|
65f1944a93 | ||
|
|
b68845f43f | ||
|
|
bb994751ee | ||
|
|
f3aad7a494 | ||
|
|
80a69e0867 | ||
|
|
e2f2bdbbc2 | ||
|
|
ecda2b1681 | ||
|
|
d00e006784 | ||
|
|
9a6411f2c8 | ||
|
|
b05b0281af | ||
|
|
fb9bce6636 | ||
|
|
92eebd6aaf | ||
|
|
4484981c97 | ||
|
|
8cff753c81 | ||
|
|
b5681f1657 | ||
|
|
abb74fa664 | ||
|
|
ff88536b4a | ||
|
|
cb20c3b313 | ||
|
|
e8335fe7c4 | ||
|
|
749ff3eb71 | ||
|
|
6877db12c9 | ||
|
|
bbdbe36ada | ||
|
|
fca09d79cc | ||
|
|
719cc12d82 | ||
|
|
b8fed9a554 | ||
|
|
e0ea8b72a6 | ||
|
|
df41564c4c | ||
|
|
42ec07daad | ||
|
|
f33e3d63d5 | ||
|
|
451ee78f31 | ||
|
|
65ea492a75 | ||
|
|
afb35d9717 | ||
|
|
f6624322d8 | ||
|
|
00a4504406 | ||
|
|
2d737f824c | ||
|
|
174c136abc | ||
|
|
eb4dcf4453 | ||
|
|
df6ee189db | ||
|
|
d558aefcc7 | ||
|
|
2adffc84d4 | ||
|
|
5b1035d64c | ||
|
|
da48a5d533 | ||
|
|
f22366a427 | ||
|
|
7def35b1c0 | ||
|
|
ace87948dd | ||
|
|
04555f3916 | ||
|
|
dce1fb0d02 | ||
|
|
1617ee0e6f | ||
|
|
ee94ac3d32 | ||
|
|
10066b349b | ||
|
|
db8084fda1 | ||
|
|
f85536de22 | ||
|
|
7c47e7cfc3 | ||
|
|
37ee1ab35b | ||
|
|
488b682489 | ||
|
|
9601d99c01 | ||
|
|
56aa6a3114 | ||
|
|
4f60cec997 | ||
|
|
e012832386 | ||
|
|
b9ce1cfc16 | ||
|
|
17dd8bb37b | ||
|
|
459d59aac4 | ||
|
|
5cb26fac9f | ||
|
|
3b8c9bb34b | ||
|
|
f9d380107c | ||
|
|
f8b60da938 | ||
|
|
f5fd25d235 | ||
|
|
0097958f62 | ||
|
|
7f8e0c00d9 | ||
|
|
1ef5db035d | ||
|
|
89ff9b8b88 | ||
|
|
bac0ce1e69 | ||
|
|
04f78a99ad | ||
|
|
f4d8809758 | ||
|
|
06dd144c92 | ||
|
|
9b3ec12a3e | ||
|
|
82d50bfcc9 | ||
|
|
7563214a6d | ||
|
|
d99dbdfe7c | ||
|
|
d9fe16bab4 | ||
|
|
db50525442 | ||
|
|
e8190f4389 | ||
|
|
e5e59bf801 | ||
|
|
dd7d4da5e3 | ||
|
|
f394584dff | ||
|
|
1a06b5f1c6 | ||
|
|
9a089495a1 | ||
|
|
c5c8859463 | ||
|
|
6a6efc4574 | ||
|
|
e6bc861ebf | ||
|
|
1499cea82e | ||
|
|
f55282f9bf | ||
|
|
452784068b | ||
|
|
e6b841126b | ||
|
|
31ce4f9283 | ||
|
|
60b3dc846e | ||
|
|
7bb2dc0075 | ||
|
|
7f437adaba | ||
|
|
5a1309cf6e | ||
|
|
f56648be3c | ||
|
|
15735dda6e | ||
|
|
1f1777f7a6 | ||
|
|
167c8ba4ec | ||
|
|
cc7ae42baa | ||
|
|
5fe844c5d9 | ||
|
|
23248dad90 | ||
|
|
caeefdf2ed | ||
|
|
d40d6291a0 | ||
|
|
fd38668f55 | ||
|
|
583654d176 | ||
|
|
59cba2f860 | ||
|
|
772f0b80a1 | ||
|
|
8d8272ee53 | ||
|
|
fef1dddd50 | ||
|
|
725da6e875 | ||
|
|
257b18230a | ||
|
|
a8de6406c5 | ||
|
|
dd2e68bf00 | ||
|
|
7825e325df | ||
|
|
33b3268f83 | ||
|
|
3dbd8212aa | ||
|
|
3694f337bc | ||
|
|
ab77997746 | ||
|
|
5fa7910664 | ||
|
|
8dbb473fde | ||
|
|
4a1240a709 | ||
|
|
664987f2aa | ||
|
|
9e391ec431 | ||
|
|
06944b3ea7 | ||
|
|
f48b949aa8 | ||
|
|
b4166083c5 | ||
|
|
56d53b18f0 | ||
|
|
20961215e7 | ||
|
|
49c75ca381 | ||
|
|
cf6751cc06 | ||
|
|
6cc828b628 | ||
|
|
ddeffb3ef1 | ||
|
|
95b606683f | ||
|
|
0598b89738 | ||
|
|
c2be63a811 | ||
|
|
639304197b | ||
|
|
c4a85cf1bf | ||
|
|
cff80524a8 | ||
|
|
2d1b13bde7 | ||
|
|
220b78d0e7 | ||
|
|
efb97c301e | ||
|
|
cd865347eb | ||
|
|
54ccb9846d | ||
|
|
22a2849683 | ||
|
|
2bae67cfe9 | ||
|
|
de8e8d9f68 | ||
|
|
eced34a72a | ||
|
|
591e8162c1 | ||
|
|
f4998bc308 | ||
|
|
39a49fb585 | ||
|
|
2b9073da36 | ||
|
|
d3aa54f7bd | ||
|
|
f0a959f6fe | ||
|
|
9a5b702013 | ||
|
|
018807d678 | ||
|
|
cf5e8bf4ea | ||
|
|
03ae65863c | ||
|
|
3b7b6d6404 | ||
|
|
e9171c80f6 | ||
|
|
0fd3881b3a | ||
|
|
01ac4c3b3e | ||
|
|
f1fcc98a09 | ||
|
|
b2823569f0 | ||
|
|
3bd98e62de | ||
|
|
318672be53 | ||
|
|
c5a05691fe | ||
|
|
04fcb9e8e6 | ||
|
|
a1534b6503 | ||
|
|
0aa4b1575d | ||
|
|
85eb6ad616 | ||
|
|
9fd2841df0 | ||
|
|
bd23dcd751 | ||
|
|
4d480093d9 | ||
|
|
bb0d2b6ce2 | ||
|
|
0d863a876b | ||
|
|
3fadfd3bbb | ||
|
|
401152f16f | ||
|
|
b69350e9ee | ||
|
|
7b429e0a54 | ||
|
|
3d23fe1fe0 | ||
|
|
d4117f5595 | ||
|
|
2686210887 | ||
|
|
9a804b7986 | ||
|
|
ef0699310d | ||
|
|
afa2da3d2d | ||
|
|
ac1132b5bc | ||
|
|
0276dac38f | ||
|
|
5a3dd83167 | ||
|
|
9f587009cd | ||
|
|
c5ed5e866e | ||
|
|
1f10bc1d63 | ||
|
|
311451b3c9 | ||
|
|
a48e5d9cb0 | ||
|
|
ad92010778 | ||
|
|
01e8988fcc | ||
|
|
d6fec0a0df | ||
|
|
37dc7ee595 | ||
|
|
6d79dc61d2 | ||
|
|
966bc67001 | ||
|
|
4c66a0dcd0 | ||
|
|
50051ee147 | ||
|
|
621f12a1bc | ||
|
|
741b22041d | ||
|
|
f358bb9364 | ||
|
|
65bbc0f00f | ||
|
|
7bf0e554ea | ||
|
|
82b1d8dab8 | ||
|
|
5dda364b2c | ||
|
|
c4e95684b5 | ||
|
|
a0d644ac42 | ||
|
|
37198159c9 | ||
|
|
7170adf3a2 | ||
|
|
cc50578faf | ||
|
|
e80d8b4365 | ||
|
|
30050a23b9 | ||
|
|
706a3c8f2b | ||
|
|
384601898a | ||
|
|
94eb5e638f | ||
|
|
5629c54d55 | ||
|
|
1303396d0e | ||
|
|
bcd5bcf8d7 | ||
|
|
787a4422cb | ||
|
|
5d52633c78 | ||
|
|
1d45444104 | ||
|
|
dd84f2ca64 | ||
|
|
b1c4a91de0 | ||
|
|
187ef3548e | ||
|
|
4abf24a2f6 | ||
|
|
2435ce34be | ||
|
|
e7841824ef | ||
|
|
10596073ac | ||
|
|
405994ee7a | ||
|
|
534d4fa495 | ||
|
|
2aa413d44f | ||
|
|
e6ebb0390e | ||
|
|
5fb9ffca6f | ||
|
|
bd62bab91f | ||
|
|
54edd3f101 | ||
|
|
a889a762b8 | ||
|
|
2163f65be7 | ||
|
|
78471b4bc3 | ||
|
|
af99238a96 | ||
|
|
4e5937036d | ||
|
|
6edc7bbd1d | ||
|
|
db437da726 | ||
|
|
95a9bacd01 | ||
|
|
e95e776733 | ||
|
|
760c7a3076 | ||
|
|
7dd1aec767 | ||
|
|
976b1a5fee | ||
|
|
b79a5e46e2 | ||
|
|
02ddfc5aac | ||
|
|
57f3107dba | ||
|
|
acde3d8952 | ||
|
|
be4983fcbb | ||
|
|
39c8bded65 | ||
|
|
e8f678adde | ||
|
|
e1666c85b7 | ||
|
|
6469cd6e24 | ||
|
|
b6032fd186 | ||
|
|
7a546349e4 | ||
|
|
375c7494b6 | ||
|
|
ac0cc91046 | ||
|
|
918254b600 | ||
|
|
814c3bed09 | ||
|
|
d94ceb25b0 | ||
|
|
619d469fa5 | ||
|
|
02c2308938 | ||
|
|
cf66e6d4ce | ||
|
|
8df40d2d94 | ||
|
|
9942d9a1dc | ||
|
|
835431ad9a | ||
|
|
b5c2b8fdec | ||
|
|
bbcc242280 | ||
|
|
e4ff850ca8 | ||
|
|
9117753a70 | ||
|
|
8095a17f0c | ||
|
|
0d1af8e26e | ||
|
|
b5834002a5 | ||
|
|
f2ba9c5d20 | ||
|
|
2fac67d8a5 | ||
|
|
36e07269e8 | ||
|
|
a35a2a6c8f | ||
|
|
050f258c8e | ||
|
|
4bad6d005a | ||
|
|
22287c9362 | ||
|
|
ee4b27c051 | ||
|
|
93c4454b8d | ||
|
|
5fc2a6a4ad | ||
|
|
c7d2766f2e | ||
|
|
06d76ed362 | ||
|
|
4a1fc2a91f | ||
|
|
0578bf0890 | ||
|
|
e3984cd006 | ||
|
|
f2e197f4e7 | ||
|
|
3cf9a53f88 | ||
|
|
c8d42e64c5 | ||
|
|
82e91afed2 | ||
|
|
13e3fc5e7a | ||
|
|
a32a2c3782 | ||
|
|
73611a7d83 | ||
|
|
7a012e4487 | ||
|
|
8935e6e7c2 | ||
|
|
8af572d502 | ||
|
|
8a0e2d9475 | ||
|
|
6d39a86dbd | ||
|
|
25d16bc779 | ||
|
|
805343f525 | ||
|
|
054c3becc0 | ||
|
|
e317f0ce29 | ||
|
|
a98d92a6c7 | ||
|
|
919f8b1386 | ||
|
|
7cd510a501 | ||
|
|
1b9aeaaea0 | ||
|
|
9b176de649 | ||
|
|
bd63cc0562 | ||
|
|
5580131017 | ||
|
|
5ae4bff91c | ||
|
|
67f06b2f6e | ||
|
|
5be89533f2 | ||
|
|
e54cc241cd | ||
|
|
a17d1f2186 | ||
|
|
23952baaff | ||
|
|
3d286ab8c3 | ||
|
|
2bb64b99e6 | ||
|
|
e26fb33ca7 | ||
|
|
6ab3e9048b | ||
|
|
7a1170f96c | ||
|
|
436ee920bb | ||
|
|
cd09b49e77 | ||
|
|
8a4b4ec4fe | ||
|
|
2b7e6b44ec | ||
|
|
989330af83 | ||
|
|
6c8971748f | ||
|
|
906d70b495 | ||
|
|
a036413f6a | ||
|
|
bb52dccc7a | ||
|
|
d19479941d | ||
|
|
820adec14a | ||
|
|
64efb6b486 | ||
|
|
479063564d | ||
|
|
ba0e4bdc62 | ||
|
|
fc34fec30a | ||
|
|
d69ab7fc86 | ||
|
|
eee0ffd6db | ||
|
|
dcf9e8f2a7 | ||
|
|
8adb0d8fa9 | ||
|
|
3d4c18abf6 | ||
|
|
eba1d054ef | ||
|
|
58b6923bc7 | ||
|
|
ad5c815ade | ||
|
|
d0c0b5e7c4 | ||
|
|
758badb05a | ||
|
|
6bad5bf2d7 | ||
|
|
fbae3fca60 | ||
|
|
fd42c82c83 | ||
|
|
35f9bd57fd | ||
|
|
90f7e4851e | ||
|
|
4ec45a22c7 | ||
|
|
c2b746a3e3 | ||
|
|
2c5e76aa8b | ||
|
|
7ea21370b2 | ||
|
|
ae5e7845bb | ||
|
|
f96a83eecf | ||
|
|
9ce74d8eff | ||
|
|
59ff96a085 | ||
|
|
b82c8d87a3 | ||
|
|
513f95e221 | ||
|
|
34729f7703 | ||
|
|
433b9d6380 | ||
|
|
0cbc684cb8 | ||
|
|
56f5698fc6 | ||
|
|
6e4dc2a69a | ||
|
|
137e9aa820 | ||
|
|
13e8710de9 | ||
|
|
767337fb8e | ||
|
|
d4a0e7899b | ||
|
|
181f54afd3 | ||
|
|
7900a7e2c0 | ||
|
|
ffb9b94719 | ||
|
|
115d938e8e | ||
|
|
53b6959bd5 | ||
|
|
184baaf579 | ||
|
|
eeaa17fbee | ||
|
|
beb4d73f04 | ||
|
|
8c9472cf4e | ||
|
|
ebaa6769b0 | ||
|
|
74de066363 | ||
|
|
148ca3b7d8 | ||
|
|
05ca8951a6 | ||
|
|
95b94a2aa7 | ||
|
|
8661152a73 | ||
|
|
145775021d | ||
|
|
2fd9575cd3 | ||
|
|
749cdcc39e | ||
|
|
9fc4008bfc | ||
|
|
f80127772e | ||
|
|
37b02ba467 | ||
|
|
971da20198 | ||
|
|
f55711c14b | ||
|
|
2f6e4c4a4a | ||
|
|
a0fc840835 | ||
|
|
b65866cb2e | ||
|
|
dffa0bb2fe | ||
|
|
8e56452df8 | ||
|
|
839e24e597 | ||
|
|
44c68f8551 | ||
|
|
5b17bbaac2 | ||
|
|
a9ec37ea79 | ||
|
|
8ed4351a9a | ||
|
|
c7b88219d3 | ||
|
|
8189af0f41 | ||
|
|
083b7d99c8 | ||
|
|
682c2f5c75 | ||
|
|
e56b5e6966 | ||
|
|
5a8fb2af90 | ||
|
|
8d08d456b6 | ||
|
|
a6c2497b35 | ||
|
|
0fcd203b6c | ||
|
|
e91562c245 | ||
|
|
9a0a48a939 | ||
|
|
c28224d574 | ||
|
|
a2840d31bd | ||
|
|
847d1c534c | ||
|
|
dc51374601 | ||
|
|
9680bd61fe | ||
|
|
fdb27d836d | ||
|
|
4d0567823a | ||
|
|
d0cfe632c9 | ||
|
|
03809763a6 | ||
|
|
41ff92592c | ||
|
|
3c754032c9 | ||
|
|
92a1d41eac | ||
|
|
8a0f723b28 | ||
|
|
f5474f18d6 | ||
|
|
2c729946a2 | ||
|
|
e7933cdae1 | ||
|
|
a012cc7041 | ||
|
|
fc2bb5014c | ||
|
|
002fddbf6e | ||
|
|
5d1b6452b0 | ||
|
|
1ea31f6952 | ||
|
|
b19bbc9212 | ||
|
|
16ce3da31f | ||
|
|
91bf5ac9a2 | ||
|
|
9d51882192 | ||
|
|
ac99d61e17 | ||
|
|
b21c28e8fe | ||
|
|
361d3383fc | ||
|
|
54ff94ec38 | ||
|
|
07beb170be | ||
|
|
eafa536c56 | ||
|
|
abdb5abbc1 | ||
|
|
a1dbf426ec | ||
|
|
30ba131704 | ||
|
|
e3f0fb539e | ||
|
|
d6667c773b | ||
|
|
3bd180882c | ||
|
|
1bb7f40b0a | ||
|
|
93d1140a31 | ||
|
|
4235885d47 | ||
|
|
6dc8f5b42e | ||
|
|
bf8d2250ca | ||
|
|
1b2d045be1 | ||
|
|
04df9f5873 | ||
|
|
849b775e55 | ||
|
|
728e21b5ae | ||
|
|
d3a183fe1d | ||
|
|
9ab9d0948f | ||
|
|
7bb6f18175 | ||
|
|
ac0f93f2c2 | ||
|
|
8a75b1411a | ||
|
|
0d552d0ba6 | ||
|
|
6ee0064ce0 | ||
|
|
5c6cd1e897 | ||
|
|
5fcaae39df | ||
|
|
7899c0ef78 | ||
|
|
543af856de | ||
|
|
3e21106336 | ||
|
|
9295985082 | ||
|
|
3ccd58af50 | ||
|
|
3f56c93b8c | ||
|
|
1311276a27 | ||
|
|
327788b1d6 | ||
|
|
1c6015ca73 | ||
|
|
4eaedbb981 | ||
|
|
2c52b77187 | ||
|
|
70527bf931 | ||
|
|
2911de8d7b | ||
|
|
62037ce577 | ||
|
|
e5bff7646a | ||
|
|
ce4b1f7f8d | ||
|
|
09bf3e7d29 | ||
|
|
18d61c2408 | ||
|
|
efac5c8f06 | ||
|
|
dd9f71203d | ||
|
|
3b51509f18 | ||
|
|
324033bdf8 | ||
|
|
d5c32dc2e7 | ||
|
|
b8c8276645 | ||
|
|
c6bf9193e2 | ||
|
|
17911ecf64 | ||
|
|
13bb45934c | ||
|
|
54ba852e71 | ||
|
|
bc85ef6e65 | ||
|
|
856b0f81d5 | ||
|
|
060fe11663 | ||
|
|
9dab54c1ed | ||
|
|
0f7a422153 | ||
|
|
058bf94c93 | ||
|
|
1a0600772f | ||
|
|
d54c18f8c3 | ||
|
|
5fc0bc5136 | ||
|
|
6f0a2d1104 | ||
|
|
9be3e0050d | ||
|
|
11596e45d1 | ||
|
|
ca3913a3c8 | ||
|
|
a6c900ef83 | ||
|
|
209f9e26a0 | ||
|
|
f9eb25b861 | ||
|
|
a3a5e81fdb | ||
|
|
0d73d9dfd3 | ||
|
|
7cdea43a37 | ||
|
|
638d16ce6e | ||
|
|
9a860dbab5 | ||
|
|
5c2a48bba8 | ||
|
|
05338bdba3 | ||
|
|
b32eeada1b | ||
|
|
acc1fefa77 | ||
|
|
a850ffa537 | ||
|
|
2bcb53fe03 | ||
|
|
94fc73ed95 | ||
|
|
df9f998671 | ||
|
|
be3ad43a07 | ||
|
|
5aa155c39f | ||
|
|
c21a21c2aa | ||
|
|
91bcdc10eb | ||
|
|
f18c8e2239 | ||
|
|
2db7608401 | ||
|
|
506632206c | ||
|
|
234a1b6571 | ||
|
|
c9d45d864f | ||
|
|
c0177516f2 | ||
|
|
accf2b5831 | ||
|
|
2f14f83a9a | ||
|
|
262968d0c9 | ||
|
|
244ac735af | ||
|
|
b919bcfc8c | ||
|
|
c21e44cf6b | ||
|
|
593ff0be75 | ||
|
|
6fd042df96 | ||
|
|
c3e1cf7230 | ||
|
|
5b3d86ab14 | ||
|
|
5d4bbbd806 | ||
|
|
cfc6d9e439 | ||
|
|
d10954f47a | ||
|
|
c3e1198448 | ||
|
|
fe9f042111 | ||
|
|
32e86ba72d | ||
|
|
28cd39d152 | ||
|
|
25f3e25555 | ||
|
|
699fbb4e55 | ||
|
|
5fa93de8c4 | ||
|
|
74e976aae4 | ||
|
|
dd829e9d6a | ||
|
|
56bca03fbe | ||
|
|
d0572730a8 | ||
|
|
eb816936ed | ||
|
|
e1b9cac1df | ||
|
|
d927b631c5 | ||
|
|
17dc5d98d1 | ||
|
|
cda086093d | ||
|
|
bda579577c | ||
|
|
a16b555d47 | ||
|
|
6667c39c73 | ||
|
|
5219ac12a6 | ||
|
|
445f813fb9 | ||
|
|
87f9e59cfb | ||
|
|
8b03b39aa8 | ||
|
|
e59b6bb971 | ||
|
|
24a7ed467c | ||
|
|
f01f1033ac | ||
|
|
d35f515413 | ||
|
|
125b459e56 | ||
|
|
33edee1ba6 | ||
|
|
d20335dabc | ||
|
|
d10d258213 | ||
|
|
d57ba1ed8b | ||
|
|
2d0e34e57b | ||
|
|
a005d06255 | ||
|
|
a301ef5a5a | ||
|
|
9422df2737 | ||
|
|
6dabe4d3ca | ||
|
|
00e4652d30 | ||
|
|
b6434c5318 | ||
|
|
3f7f9f8d61 | ||
|
|
f3bb592544 | ||
|
|
69f080fb75 | ||
|
|
04272a7cc8 | ||
|
|
8d35af946e | ||
|
|
24065ec6b6 | ||
|
|
627b0bf644 | ||
|
|
b43da46b82 | ||
|
|
4255a01c64 | ||
|
|
23adbd4002 | ||
|
|
fb5a24fcc6 | ||
|
|
cfdd5a1900 | ||
|
|
2313f326df | ||
|
|
2e092a2313 | ||
|
|
763ef06c18 | ||
|
|
8292f6cd42 | ||
|
|
278bba499e | ||
|
|
dd99ed28e0 | ||
|
|
9a8aca69bf | ||
|
|
7ad62512eb | ||
|
|
bd466661ec | ||
|
|
7ebb509d05 | ||
|
|
0aa13c046c | ||
|
|
a7a33d73f5 | ||
|
|
ffa39857d3 | ||
|
|
e85c3bc465 | ||
|
|
8185ba7054 | ||
|
|
d501865bec | ||
|
|
d62310bb5f | ||
|
|
1835bff196 |
37
.github/workflows/build-container.yml
vendored
37
.github/workflows/build-container.yml
vendored
@@ -13,6 +13,12 @@ on:
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
push-to-registry:
|
||||
description: Push the built image to the container registry
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -50,16 +56,15 @@ jobs:
|
||||
df -h
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
images: |
|
||||
ghcr.io/${{ github.repository }}
|
||||
${{ env.DOCKERHUB_REPOSITORY }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
@@ -72,49 +77,33 @@ jobs:
|
||||
suffix=-${{ matrix.gpu-driver }},onlatest=false
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
platforms: ${{ env.PLATFORMS }}
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# - name: Login to Docker Hub
|
||||
# if: github.event_name != 'pull_request' && vars.DOCKERHUB_REPOSITORY != ''
|
||||
# uses: docker/login-action@v2
|
||||
# with:
|
||||
# username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
# password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build container
|
||||
timeout-minutes: 40
|
||||
id: docker_build
|
||||
uses: docker/build-push-action@v4
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
platforms: ${{ env.PLATFORMS }}
|
||||
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }}
|
||||
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: |
|
||||
type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
type=gha,scope=main-${{ matrix.gpu-driver }}
|
||||
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
|
||||
# - name: Docker Hub Description
|
||||
# if: github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' && vars.DOCKERHUB_REPOSITORY != ''
|
||||
# uses: peter-evans/dockerhub-description@v3
|
||||
# with:
|
||||
# username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
# password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
# repository: ${{ vars.DOCKERHUB_REPOSITORY }}
|
||||
# short-description: ${{ github.event.repository.description }}
|
||||
|
||||
@@ -196,6 +196,22 @@ tips to reduce the problem:
|
||||
=== "12GB VRAM GPU"
|
||||
|
||||
This should be sufficient to generate larger images up to about 1280x1280.
|
||||
|
||||
## Checkpoint Models Load Slowly or Use Too Much RAM
|
||||
|
||||
The difference between diffusers models (a folder containing multiple
|
||||
subfolders) and checkpoint models (a file ending with .safetensors or
|
||||
.ckpt) is that InvokeAI is able to load diffusers models into memory
|
||||
incrementally, while checkpoint models must be loaded all at
|
||||
once. With very large models, or systems with limited RAM, you may
|
||||
experience slowdowns and other memory-related issues when loading
|
||||
checkpoint models.
|
||||
|
||||
To solve this, go to the Model Manager tab (the cube), select the
|
||||
checkpoint model that's giving you trouble, and press the "Convert"
|
||||
button in the upper right of your browser window. This will conver the
|
||||
checkpoint into a diffusers model, after which loading should be
|
||||
faster and less memory-intensive.
|
||||
|
||||
## Memory Leak (Linux)
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional, Type
|
||||
|
||||
@@ -17,6 +19,7 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
@@ -31,6 +34,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
@@ -50,6 +54,13 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class CacheType(str, Enum):
|
||||
"""Cache type - one of vram or ram."""
|
||||
|
||||
RAM = "RAM"
|
||||
VRAM = "VRAM"
|
||||
|
||||
|
||||
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
|
||||
"""Add a cover image URL to a model configuration."""
|
||||
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
|
||||
@@ -797,3 +808,83 @@ async def get_starter_models() -> list[StarterModel]:
|
||||
model.dependencies = missing_deps
|
||||
|
||||
return starter_models
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/model_cache",
|
||||
operation_id="get_cache_size",
|
||||
response_model=float,
|
||||
summary="Get maximum size of model manager RAM or VRAM cache.",
|
||||
)
|
||||
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
|
||||
"""Return the current RAM or VRAM cache size setting (in GB)."""
|
||||
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
|
||||
value = 0.0
|
||||
if cache_type == CacheType.RAM:
|
||||
value = cache.max_cache_size
|
||||
elif cache_type == CacheType.VRAM:
|
||||
value = cache.max_vram_cache_size
|
||||
return value
|
||||
|
||||
|
||||
@model_manager_router.put(
|
||||
"/model_cache",
|
||||
operation_id="set_cache_size",
|
||||
response_model=float,
|
||||
summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
|
||||
)
|
||||
async def set_cache_size(
|
||||
value: float = Query(description="The new value for the maximum cache size"),
|
||||
cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
|
||||
persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
|
||||
) -> float:
|
||||
"""Set the current RAM or VRAM cache size setting (in GB). ."""
|
||||
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
|
||||
app_config = get_config()
|
||||
# Record initial state.
|
||||
vram_old = app_config.vram
|
||||
ram_old = app_config.ram
|
||||
|
||||
# Prepare target state.
|
||||
vram_new = vram_old
|
||||
ram_new = ram_old
|
||||
if cache_type == CacheType.RAM:
|
||||
ram_new = value
|
||||
elif cache_type == CacheType.VRAM:
|
||||
vram_new = value
|
||||
else:
|
||||
raise ValueError(f"Unexpected {cache_type=}.")
|
||||
|
||||
config_path = app_config.config_file_path
|
||||
new_config_path = config_path.with_suffix(".yaml.new")
|
||||
|
||||
try:
|
||||
# Try to apply the target state.
|
||||
cache.max_vram_cache_size = vram_new
|
||||
cache.max_cache_size = ram_new
|
||||
app_config.ram = ram_new
|
||||
app_config.vram = vram_new
|
||||
if persist:
|
||||
app_config.write_file(new_config_path)
|
||||
shutil.move(new_config_path, config_path)
|
||||
except Exception as e:
|
||||
# If there was a failure, restore the initial state.
|
||||
cache.max_cache_size = ram_old
|
||||
cache.max_vram_cache_size = vram_old
|
||||
app_config.ram = ram_old
|
||||
app_config.vram = vram_old
|
||||
|
||||
raise RuntimeError("Failed to update cache size") from e
|
||||
return value
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/stats",
|
||||
operation_id="get_stats",
|
||||
response_model=Optional[CacheStats],
|
||||
summary="Get model manager RAM cache performance statistics.",
|
||||
)
|
||||
async def get_stats() -> Optional[CacheStats]:
|
||||
"""Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""
|
||||
|
||||
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
|
||||
|
||||
@@ -11,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByDestinationResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
PruneResult,
|
||||
@@ -105,6 +106,21 @@ async def cancel_by_batch_ids(
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/cancel_by_destination",
|
||||
operation_id="cancel_by_destination",
|
||||
responses={200: {"model": CancelByBatchIDsResult}},
|
||||
)
|
||||
async def cancel_by_destination(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
destination: str = Query(description="The destination to cancel all queue items for"),
|
||||
) -> CancelByDestinationResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/clear",
|
||||
operation_id="clear",
|
||||
|
||||
@@ -20,7 +20,6 @@ from typing import (
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import semver
|
||||
@@ -61,11 +60,13 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
||||
- `Stable`: The invocation, including its inputs/outputs and internal logic, is stable. You may build workflows with it, having confidence that they will not break because of a change in this invocation.
|
||||
- `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
|
||||
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
|
||||
- `Deprecated`: The invocation is deprecated and may be removed in a future version.
|
||||
"""
|
||||
|
||||
Stable = "stable"
|
||||
Beta = "beta"
|
||||
Prototype = "prototype"
|
||||
Deprecated = "deprecated"
|
||||
|
||||
|
||||
class UIConfigBase(BaseModel):
|
||||
@@ -80,7 +81,7 @@ class UIConfigBase(BaseModel):
|
||||
version: str = Field(
|
||||
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||
)
|
||||
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
|
||||
node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
|
||||
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -230,18 +231,16 @@ class BaseInvocation(ABC, BaseModel):
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
||||
if uiconfig is not None:
|
||||
if uiconfig.title is not None:
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig.tags is not None:
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig.category is not None:
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig.node_pack is not None:
|
||||
schema["node_pack"] = uiconfig.node_pack
|
||||
schema["classification"] = uiconfig.classification
|
||||
schema["version"] = uiconfig.version
|
||||
if title := model_class.UIConfig.title:
|
||||
schema["title"] = title
|
||||
if tags := model_class.UIConfig.tags:
|
||||
schema["tags"] = tags
|
||||
if category := model_class.UIConfig.category:
|
||||
schema["category"] = category
|
||||
if node_pack := model_class.UIConfig.node_pack:
|
||||
schema["node_pack"] = node_pack
|
||||
schema["classification"] = model_class.UIConfig.classification
|
||||
schema["version"] = model_class.UIConfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "invocation"
|
||||
@@ -312,7 +311,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
UIConfig: ClassVar[UIConfigBase]
|
||||
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
@@ -441,30 +440,25 @@ def invocation(
|
||||
validate_fields(cls.model_fields, invocation_type)
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconfig_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
|
||||
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
|
||||
cls.UIConfig.title = title
|
||||
cls.UIConfig.tags = tags
|
||||
cls.UIConfig.category = category
|
||||
cls.UIConfig.classification = classification
|
||||
|
||||
# Grab the node pack's name from the module name, if it's a custom node
|
||||
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
|
||||
if is_custom_node:
|
||||
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
|
||||
else:
|
||||
cls.UIConfig.node_pack = None
|
||||
uiconfig: dict[str, Any] = {}
|
||||
uiconfig["title"] = title
|
||||
uiconfig["tags"] = tags
|
||||
uiconfig["category"] = category
|
||||
uiconfig["classification"] = classification
|
||||
# The node pack is the module name - will be "invokeai" for built-in nodes
|
||||
uiconfig["node_pack"] = cls.__module__.split(".")[0]
|
||||
|
||||
if version is not None:
|
||||
try:
|
||||
semver.Version.parse(version)
|
||||
except ValueError as e:
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
cls.UIConfig.version = version
|
||||
uiconfig["version"] = version
|
||||
else:
|
||||
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
||||
cls.UIConfig.version = "1.0.0"
|
||||
uiconfig["version"] = "1.0.0"
|
||||
|
||||
cls.UIConfig = UIConfigBase(**uiconfig)
|
||||
|
||||
if use_cache is not None:
|
||||
cls.model_fields["use_cache"].default = use_cache
|
||||
|
||||
34
invokeai/app/invocations/canny.py
Normal file
34
invokeai/app/invocations/canny.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import cv2
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2
|
||||
|
||||
|
||||
@invocation(
|
||||
"canny_edge_detection",
|
||||
title="Canny Edge Detection",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CannyEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Geneartes an edge map using a cv2's Canny algorithm."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
low_threshold: int = InputField(
|
||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
high_threshold: int = InputField(
|
||||
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
np_img = pil_to_cv2(image)
|
||||
edge_map = cv2.Canny(np_img, self.low_threshold, self.high_threshold)
|
||||
edge_map_pil = cv2_to_pil(edge_map)
|
||||
image_dto = context.images.save(image=edge_map_pil)
|
||||
return ImageOutput.build(image_dto)
|
||||
41
invokeai/app/invocations/color_map.py
Normal file
41
invokeai/app/invocations/color_map.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import cv2
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
@invocation(
|
||||
"color_map",
|
||||
title="Color Map",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ColorMapInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates a color map from the provided image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
height, width = np_image.shape[:2]
|
||||
|
||||
width_tile_size = min(self.tile_size, width)
|
||||
height_tile_size = min(self.tile_size, height)
|
||||
|
||||
color_map = cv2.resize(
|
||||
np_image,
|
||||
(width // width_tile_size, height // height_tile_size),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
color_map_pil = np_to_pil(color_map)
|
||||
|
||||
image_dto = context.images.save(image=color_map_pil)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -19,7 +19,7 @@ from invokeai.app.invocations.model import CLIPField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
@@ -55,7 +55,6 @@ class CompelInvocation(BaseInvocation):
|
||||
clip: CLIPField = InputField(
|
||||
title="CLIP",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
|
||||
25
invokeai/app/invocations/content_shuffle.py
Normal file
25
invokeai/app/invocations/content_shuffle.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.content_shuffle import content_shuffle
|
||||
|
||||
|
||||
@invocation(
|
||||
"content_shuffle",
|
||||
title="Content Shuffle",
|
||||
tags=["controlnet", "normal"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ContentShuffleInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Shuffles the image, similar to a 'liquify' filter."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
scale_factor: int = InputField(default=256, ge=0, description="The scale factor used for the shuffle")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
output_image = content_shuffle(input_image=image, scale_factor=self.scale_factor)
|
||||
image_dto = context.images.save(image=output_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -174,6 +174,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.3.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
@@ -208,6 +209,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
@@ -237,6 +239,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
@@ -259,6 +262,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
@@ -282,6 +286,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
@@ -314,6 +319,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
@@ -330,7 +336,12 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.3"
|
||||
"mlsd_image_processor",
|
||||
title="MLSD Processor",
|
||||
tags=["controlnet", "mlsd"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
@@ -353,7 +364,12 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.3"
|
||||
"pidi_image_processor",
|
||||
title="PIDI Processor",
|
||||
tags=["controlnet", "pidi"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
@@ -381,6 +397,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
@@ -411,6 +428,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
@@ -427,6 +445,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
@@ -454,6 +473,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
@@ -483,6 +503,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
@@ -523,6 +544,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
@@ -570,6 +592,7 @@ class SamDetectorReproducibleColors(SamDetector):
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a color map from the provided image"""
|
||||
@@ -609,6 +632,7 @@ DEPTH_ANYTHING_MODELS = {
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.1.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||
@@ -643,6 +667,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
tags=["controlnet", "dwpose", "openpose"],
|
||||
category="controlnet",
|
||||
version="1.1.1",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates an openpose pose from an image using DWPose"""
|
||||
|
||||
@@ -36,7 +36,7 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||
@@ -185,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.mask,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
input=Input.Connection,
|
||||
ui_order=8,
|
||||
)
|
||||
|
||||
45
invokeai/app/invocations/depth_anything.py
Normal file
45
invokeai/app/invocations/depth_anything.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Literal
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
|
||||
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "LiheYoung/depth-anything-large-hf",
|
||||
"base": "LiheYoung/depth-anything-base-hf",
|
||||
"small": "LiheYoung/depth-anything-small-hf",
|
||||
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"depth_anything_depth_estimation",
|
||||
title="Depth Anything Depth Estimation",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates a depth map using a Depth Anything model."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||
default="small_v2", description="The size of the depth model to use"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
model_url = DEPTH_ANYTHING_MODELS[self.model_size]
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
loaded_model = context.models.load_remote_model(model_url, DepthAnythingPipeline.load_model)
|
||||
|
||||
with loaded_model as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
||||
image_dto = context.images.save(image=depth_map)
|
||||
return ImageOutput.build(image_dto)
|
||||
50
invokeai/app/invocations/dw_openpose.py
Normal file
50
invokeai/app/invocations/dw_openpose.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import onnxruntime as ort
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector2
|
||||
|
||||
|
||||
@invocation(
|
||||
"dw_openpose_detection",
|
||||
title="DW Openpose Detection",
|
||||
tags=["controlnet", "dwpose", "openpose"],
|
||||
category="controlnet",
|
||||
version="1.1.1",
|
||||
)
|
||||
class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an openpose pose from an image using DWPose"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
draw_body: bool = InputField(default=True)
|
||||
draw_face: bool = InputField(default=False)
|
||||
draw_hands: bool = InputField(default=False)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_det())
|
||||
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
|
||||
|
||||
loaded_session_det = context.models.load_local_model(
|
||||
onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
)
|
||||
loaded_session_pose = context.models.load_local_model(
|
||||
onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
)
|
||||
|
||||
with loaded_session_det as session_det, loaded_session_pose as session_pose:
|
||||
assert isinstance(session_det, ort.InferenceSession)
|
||||
assert isinstance(session_pose, ort.InferenceSession)
|
||||
detector = DWOpenposeDetector2(session_det=session_det, session_pose=session_pose)
|
||||
detected_image = detector.run(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
draw_hands=self.draw_hands,
|
||||
draw_body=self.draw_body,
|
||||
)
|
||||
image_dto = context.images.save(image=detected_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -181,7 +181,7 @@ class FieldDescriptions:
|
||||
)
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
denoise_mask = "A mask of the region to apply the denoising process to."
|
||||
board = "The board to save the image to"
|
||||
image = "The image to process"
|
||||
tile_size = "Tile size"
|
||||
|
||||
249
invokeai/app/invocations/flux_denoise.py
Normal file
249
invokeai/app/invocations/flux_denoise.py
Normal file
@@ -0,0 +1,249 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.sampling_utils import (
|
||||
clip_timestep_schedule,
|
||||
generate_img_ids,
|
||||
get_noise,
|
||||
get_schedule,
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_denoise",
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
input=Input.Connection,
|
||||
)
|
||||
denoising_start: float = InputField(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
|
||||
)
|
||||
guidance: float = InputField(
|
||||
default=4.0,
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
latents = latents.detach().to("cpu")
|
||||
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
# Load the input latents, if provided.
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
|
||||
|
||||
# Prepare input noise.
|
||||
noise = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=image_seq_len,
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
# Clip the timesteps schedule based on denoising_start and denoising_end.
|
||||
timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)
|
||||
|
||||
# Prepare input latent image.
|
||||
if init_latents is not None:
|
||||
# If init_latents is provided, we are doing image-to-image.
|
||||
|
||||
if is_schnell:
|
||||
context.logger.warning(
|
||||
"Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
|
||||
"to be poor. Consider using a FLUX dev model instead."
|
||||
)
|
||||
|
||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||
t_0 = timesteps[0]
|
||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
else:
|
||||
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
|
||||
if self.denoising_start > 1e-5:
|
||||
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
|
||||
|
||||
x = noise
|
||||
|
||||
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
|
||||
# denoising steps.
|
||||
if len(timesteps) <= 1:
|
||||
return x
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
b, _c, h, w = x.shape
|
||||
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
|
||||
assert image_seq_len == x.shape[1]
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_extension: InpaintExtension | None = None
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
inpaint_extension = InpaintExtension(
|
||||
init_latents=init_latents,
|
||||
inpaint_mask=inpaint_mask,
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
step_callback=self._build_step_callback(context),
|
||||
guidance=self.guidance,
|
||||
inpaint_extension=inpaint_extension,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
return x
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask.
|
||||
|
||||
- Loads the mask
|
||||
- Resizes if necessary
|
||||
- Casts to same device/dtype as latents
|
||||
- Expands mask to the same shape as latents so that they line up after 'packing'
|
||||
|
||||
Args:
|
||||
context (InvocationContext): The invocation context, for loading the inpaint mask.
|
||||
latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
|
||||
device, and dtype for the inpaint mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor | None: Inpaint mask.
|
||||
"""
|
||||
if self.denoise_mask is None:
|
||||
return None
|
||||
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
mask = tv_resize(
|
||||
img=mask,
|
||||
size=[latent_height, latent_width],
|
||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||
antialias=False,
|
||||
)
|
||||
|
||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
|
||||
# `latents`.
|
||||
return mask.expand_as(latents)
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
|
||||
context.util.flux_step_callback(state)
|
||||
|
||||
return step_callback
|
||||
@@ -1,169 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_to_image",
|
||||
title="FLUX Text to Image",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Text-to-image generation using a FLUX model."""
|
||||
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
|
||||
)
|
||||
guidance: float = InputField(
|
||||
default=4.0,
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
image = self._run_vae_decoding(context, latents)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
|
||||
# Prepare input noise.
|
||||
x = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
x, img_ids = prepare_latent_img_patches(x)
|
||||
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=x.shape[1],
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
def step_callback() -> None:
|
||||
if context.util.is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# TODO: Make this look like the image before re-enabling
|
||||
# latent_image = unpack(img.float(), self.height, self.width)
|
||||
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
|
||||
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
|
||||
|
||||
# # Create a new tensor of the required shape [255, 255, 3]
|
||||
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
|
||||
|
||||
# # Convert to a NumPy array and then to a PIL Image
|
||||
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
|
||||
|
||||
# (width, height) = image.size
|
||||
# width *= 8
|
||||
# height *= 8
|
||||
|
||||
# dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
# # TODO: move this whole function to invocation context to properly reference these variables
|
||||
# context._services.events.emit_invocation_denoise_progress(
|
||||
# context._data.queue_item,
|
||||
# context._data.invocation,
|
||||
# state,
|
||||
# ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
# )
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
step_callback=step_callback,
|
||||
guidance=self.guidance,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
||||
return x
|
||||
|
||||
def _run_vae_decoding(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
latents: torch.Tensor,
|
||||
) -> Image.Image:
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
|
||||
img = vae.decode(latents)
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
return img_pil
|
||||
60
invokeai/app/invocations/flux_vae_decode.py
Normal file
60
invokeai/app/invocations/flux_vae_decode.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_vae_decode",
|
||||
title="FLUX Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i", "flux"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
|
||||
img = vae.decode(latents)
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
return img_pil
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
image = self._vae_decode(vae_info=vae_info, latents=latents)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
67
invokeai/app/invocations/flux_vae_encode.py
Normal file
67
invokeai/app/invocations/flux_vae_encode.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_vae_encode",
|
||||
title="FLUX Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l", "flux"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode.",
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(ryand): Expose seed parameter at the invocation level.
|
||||
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
|
||||
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
|
||||
# should be used for VAE encode sampling.
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
image_tensor = image_tensor.to(
|
||||
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
)
|
||||
latents = vae.encode(image_tensor, sample=True, generator=generator)
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
33
invokeai/app/invocations/hed.py
Normal file
33
invokeai/app/invocations/hed.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from builtins import bool
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetector
|
||||
|
||||
|
||||
@invocation(
|
||||
"hed_edge_detection",
|
||||
title="HED Edge Detection",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Geneartes an edge map using the HED (softedge) model."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), HEDEdgeDetector.load_model)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, ControlNetHED_Apache2)
|
||||
hed_processor = HEDEdgeDetector(model)
|
||||
edge_map = hed_processor.run(image=image, scribble=self.scribble)
|
||||
|
||||
image_dto = context.images.save(image=edge_map)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -6,13 +6,19 @@ import cv2
|
||||
import numpy
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import (
|
||||
ColorField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
@@ -1007,3 +1013,62 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation_output("canvas_v2_mask_and_crop_output")
|
||||
class CanvasV2MaskAndCropOutput(ImageOutput):
|
||||
offset_x: int = OutputField(description="The x offset of the image, after cropping")
|
||||
offset_y: int = OutputField(description="The y offset of the image, after cropping")
|
||||
|
||||
|
||||
@invocation(
|
||||
"canvas_v2_mask_and_crop",
|
||||
title="Canvas V2 Mask and Crop",
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Handles Canvas V2 image output masking and cropping"""
|
||||
|
||||
source_image: ImageField | None = InputField(
|
||||
default=None,
|
||||
description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
|
||||
)
|
||||
generated_image: ImageField = InputField(description="The image to apply the mask to")
|
||||
mask: ImageField = InputField(description="The mask to apply")
|
||||
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
|
||||
|
||||
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
|
||||
mask_array = numpy.array(mask)
|
||||
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
|
||||
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
|
||||
dilated_mask = Image.fromarray(dilated_mask_array)
|
||||
if self.mask_blur > 0:
|
||||
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
||||
return ImageOps.invert(mask.convert("L"))
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
|
||||
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
|
||||
|
||||
if self.source_image:
|
||||
generated_image = context.images.get_pil(self.generated_image.image_name)
|
||||
source_image = context.images.get_pil(self.source_image.image_name)
|
||||
source_image.paste(generated_image, (0, 0), mask)
|
||||
image_dto = context.images.save(image=source_image)
|
||||
else:
|
||||
generated_image = context.images.get_pil(self.generated_image.image_name)
|
||||
generated_image.putalpha(mask)
|
||||
image_dto = context.images.save(image=generated_image)
|
||||
|
||||
# bbox = image.getbbox()
|
||||
# image = image.crop(bbox)
|
||||
|
||||
return CanvasV2MaskAndCropOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
offset_x=0,
|
||||
offset_y=0,
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
34
invokeai/app/invocations/lineart.py
Normal file
34
invokeai/app/invocations/lineart.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from builtins import bool
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.lineart import Generator, LineartEdgeDetector
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_edge_detection",
|
||||
title="Lineart Edge Detection",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an edge map using the Lineart model."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
model_url = LineartEdgeDetector.get_model_url(self.coarse)
|
||||
loaded_model = context.models.load_remote_model(model_url, LineartEdgeDetector.load_model)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, Generator)
|
||||
detector = LineartEdgeDetector(model)
|
||||
edge_map = detector.run(image=image)
|
||||
|
||||
image_dto = context.images.save(image=edge_map)
|
||||
return ImageOutput.build(image_dto)
|
||||
31
invokeai/app/invocations/lineart_anime.py
Normal file
31
invokeai/app/invocations/lineart_anime.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector, UnetGenerator
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_anime_edge_detection",
|
||||
title="Lineart Anime Edge Detection",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Geneartes an edge map using the Lineart model."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
model_url = LineartAnimeEdgeDetector.get_model_url()
|
||||
loaded_model = context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, UnetGenerator)
|
||||
detector = LineartAnimeEdgeDetector(model)
|
||||
edge_map = detector.run(image=image)
|
||||
|
||||
image_dto = context.images.save(image=edge_map)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
||||
title="Tensor Mask to Image",
|
||||
tags=["mask"],
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
)
|
||||
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Convert a mask tensor to an image."""
|
||||
@@ -135,6 +135,11 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.tensors.load(self.mask.tensor_name)
|
||||
|
||||
# Squeeze the channel dimension if it exists.
|
||||
if mask.dim() == 3:
|
||||
mask = mask.squeeze(0)
|
||||
|
||||
# Ensure that the mask is binary.
|
||||
if mask.dtype != torch.bool:
|
||||
mask = mask > 0.5
|
||||
|
||||
26
invokeai/app/invocations/mediapipe_face.py
Normal file
26
invokeai/app/invocations/mediapipe_face.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.mediapipe_face import detect_faces
|
||||
|
||||
|
||||
@invocation(
|
||||
"mediapipe_face_detection",
|
||||
title="MediaPipe Face Detection",
|
||||
tags=["controlnet", "face"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MediaPipeFaceDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Detects faces using MediaPipe."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
detected_faces = detect_faces(image=image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||
image_dto = context.images.save(image=detected_faces)
|
||||
return ImageOutput.build(image_dto)
|
||||
39
invokeai/app/invocations/mlsd.py
Normal file
39
invokeai/app/invocations/mlsd.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.mlsd import MLSDDetector
|
||||
from invokeai.backend.image_util.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_detection",
|
||||
title="MLSD Detection",
|
||||
tags=["controlnet", "mlsd", "edge"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an line segment map using MLSD."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
score_threshold: float = InputField(
|
||||
default=0.1, ge=0, description="The threshold used to score points when determining line segments"
|
||||
)
|
||||
distance_threshold: float = InputField(
|
||||
default=20.0,
|
||||
ge=0,
|
||||
description="Threshold for including a line segment - lines shorter than this distance will be discarded",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(MLSDDetector.get_model_url(), MLSDDetector.load_model)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, MobileV2_MLSD_Large)
|
||||
detector = MLSDDetector(model)
|
||||
edge_map = detector.run(image, self.score_threshold, self.distance_threshold)
|
||||
|
||||
image_dto = context.images.save(image=edge_map)
|
||||
return ImageOutput.build(image_dto)
|
||||
31
invokeai/app/invocations/normal_bae.py
Normal file
31
invokeai/app/invocations/normal_bae.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.normal_bae import NormalMapDetector
|
||||
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
|
||||
|
||||
|
||||
@invocation(
|
||||
"normal_map",
|
||||
title="Normal Map",
|
||||
tags=["controlnet", "normal"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates a normal map."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), NormalMapDetector.load_model)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, NNET)
|
||||
detector = NormalMapDetector(model)
|
||||
normal_map = detector.run(image=image)
|
||||
|
||||
image_dto = context.images.save(image=normal_map)
|
||||
return ImageOutput.build(image_dto)
|
||||
33
invokeai/app/invocations/pidi.py
Normal file
33
invokeai/app/invocations/pidi.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.pidi import PIDINetDetector
|
||||
from invokeai.backend.image_util.pidi.model import PiDiNet
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_edge_detection",
|
||||
title="PiDiNet Edge Detection",
|
||||
tags=["controlnet", "edge"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an edge map using PiDiNet."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
quantize_edges: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(PIDINetDetector.get_model_url(), PIDINetDetector.load_model)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, PiDiNet)
|
||||
detector = PIDINetDetector(model)
|
||||
edge_map = detector.run(image=image, quantize_edges=self.quantize_edges, scribble=self.scribble)
|
||||
|
||||
image_dto = context.images.save(image=edge_map)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -22,7 +22,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||
|
||||
@@ -88,6 +88,8 @@ class QueueItemEventBase(QueueEventBase):
|
||||
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the queue item")
|
||||
destination: str | None = Field(default=None, description="The destination of the queue item")
|
||||
|
||||
|
||||
class InvocationEventBase(QueueItemEventBase):
|
||||
@@ -95,8 +97,6 @@ class InvocationEventBase(QueueItemEventBase):
|
||||
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
invocation: AnyInvocation = Field(description="The ID of the invocation")
|
||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||
@@ -114,6 +114,8 @@ class InvocationStartedEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -147,6 +149,8 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -184,6 +188,8 @@ class InvocationCompleteEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -216,6 +222,8 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -253,6 +261,8 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
status=queue_item.status,
|
||||
error_type=queue_item.error_type,
|
||||
@@ -279,12 +289,14 @@ class BatchEnqueuedEvent(QueueEventBase):
|
||||
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
|
||||
)
|
||||
priority: int = Field(description="The priority of the batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the batch")
|
||||
|
||||
@classmethod
|
||||
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
|
||||
return cls(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
origin=enqueue_result.batch.origin,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
requested=enqueue_result.requested,
|
||||
priority=enqueue_result.priority,
|
||||
|
||||
@@ -103,7 +103,7 @@ class HFModelSource(StringLikeSource):
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
base += f"::{self.subfolder.as_posix()}"
|
||||
return base
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByDestinationResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
@@ -95,6 +96,11 @@ class SessionQueueBase(ABC):
|
||||
"""Cancels all queue items with matching batch IDs"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
"""Cancels all queue items with the given batch destination"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
"""Cancels all queue items with matching queue ID"""
|
||||
|
||||
@@ -77,6 +77,14 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
|
||||
|
||||
class Batch(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
||||
origin: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
|
||||
)
|
||||
destination: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
|
||||
)
|
||||
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
||||
graph: Graph = Field(description="The graph to initialize the session with")
|
||||
workflow: Optional[WorkflowWithoutID] = Field(
|
||||
@@ -195,6 +203,14 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
|
||||
priority: int = Field(default=0, description="The priority of this queue item")
|
||||
batch_id: str = Field(description="The ID of the batch associated with this queue item")
|
||||
origin: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
|
||||
)
|
||||
destination: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
|
||||
)
|
||||
session_id: str = Field(
|
||||
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||
)
|
||||
@@ -294,6 +310,8 @@ class SessionQueueStatus(BaseModel):
|
||||
class BatchStatus(BaseModel):
|
||||
queue_id: str = Field(..., description="The ID of the queue")
|
||||
batch_id: str = Field(..., description="The ID of the batch")
|
||||
origin: str | None = Field(..., description="The origin of the batch")
|
||||
destination: str | None = Field(..., description="The destination of the batch")
|
||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||
@@ -328,6 +346,12 @@ class CancelByBatchIDsResult(BaseModel):
|
||||
canceled: int = Field(..., description="Number of queue items canceled")
|
||||
|
||||
|
||||
class CancelByDestinationResult(CancelByBatchIDsResult):
|
||||
"""Result of canceling by a destination"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||
"""Result of canceling by queue id"""
|
||||
|
||||
@@ -433,6 +457,8 @@ class SessionQueueValueToInsert(NamedTuple):
|
||||
field_values: Optional[str] # field_values json
|
||||
priority: int # priority
|
||||
workflow: Optional[str] # workflow json
|
||||
origin: str | None
|
||||
destination: str | None
|
||||
|
||||
|
||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
@@ -453,6 +479,8 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
|
||||
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
|
||||
priority, # priority
|
||||
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
|
||||
batch.origin, # origin
|
||||
batch.destination, # destination
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
||||
@@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByDestinationResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
@@ -127,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
self.__cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
@@ -417,11 +418,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
@@ -429,6 +426,46 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__lock.release()
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND destination == ?
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = (queue_id, destination)
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
params,
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
{where};
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
@@ -541,7 +578,9 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id
|
||||
queue_id,
|
||||
origin,
|
||||
destination
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
@@ -621,7 +660,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
@@ -633,6 +672,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
destination = result[0]["destination"] if result else None
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
@@ -641,6 +682,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
return BatchStatus(
|
||||
batch_id=batch_id,
|
||||
origin=origin,
|
||||
destination=destination,
|
||||
queue_id=queue_id,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
|
||||
@@ -14,7 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@@ -557,6 +557,24 @@ class UtilInterface(InvocationContextInterface):
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
def flux_step_callback(self, intermediate_state: PipelineIntermediateState) -> None:
|
||||
"""
|
||||
The step callback emits a progress event with the current step, the total number of
|
||||
steps, a preview image, and some other internal metadata.
|
||||
|
||||
This should be called after each denoising step.
|
||||
|
||||
Args:
|
||||
intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
flux_step_callback(
|
||||
context_data=self._data,
|
||||
intermediate_state=intermediate_state,
|
||||
events=self._services.events,
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Provides access to various services and data for the current invocation.
|
||||
|
||||
@@ -17,6 +17,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -51,6 +52,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_12(app_config=config))
|
||||
migrator.register_migration(build_migration_13())
|
||||
migrator.register_migration(build_migration_14())
|
||||
migrator.register_migration(build_migration_15())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration15Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._add_origin_col(cursor)
|
||||
|
||||
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
- Adds `origin` column to the session queue table.
|
||||
- Adds `destination` column to the session queue table.
|
||||
"""
|
||||
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN destination TEXT;")
|
||||
|
||||
|
||||
def build_migration_15() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 14 to 15.
|
||||
|
||||
This migration does the following:
|
||||
- Adds `origin` column to the session queue table.
|
||||
- Adds `destination` column to the session queue table.
|
||||
"""
|
||||
migration_15 = Migration(
|
||||
from_version=14,
|
||||
to_version=15,
|
||||
callback=Migration15Callback(),
|
||||
)
|
||||
|
||||
return migration_15
|
||||
@@ -0,0 +1,407 @@
|
||||
{
|
||||
"name": "FLUX Image to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple image-to-image workflow using a FLUX dev model. ",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "image2image, flux, image-to-image",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "clip_embed_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "vae_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"fieldName": "denoising_start"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"fieldName": "num_steps"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"type": "flux_vae_encode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"image": {
|
||||
"name": "image",
|
||||
"label": "",
|
||||
"value": {
|
||||
"image_name": "8a5c62aa-9335-45d2-9c71-89af9fc1f8d4.png"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 732.7680166609682,
|
||||
"y": -24.37398171806909
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"type": "flux_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"denoise_mask": {
|
||||
"name": "denoise_mask",
|
||||
"label": ""
|
||||
},
|
||||
"denoising_start": {
|
||||
"name": "denoising_start",
|
||||
"label": "",
|
||||
"value": 0.04
|
||||
},
|
||||
"denoising_end": {
|
||||
"name": "denoising_end",
|
||||
"label": "",
|
||||
"value": 1
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1182.8836633018684,
|
||||
"y": -251.38882958913183
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "flux_vae_decode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1575.5797431839133,
|
||||
"y": -209.00150975507415
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "Model (dev variant recommended for Image-to-Image)"
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "fa23a584-b623-415d-832a-21b5098ff1a1",
|
||||
"hash": "blake3:17c19f0ef941c3b7609a9c94a659ca5364de0be364a91d4179f0e39ba17c3b70",
|
||||
"name": "clip-vit-large-patch14",
|
||||
"base": "any",
|
||||
"type": "clip_embed"
|
||||
}
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "74fc82ba-c0a8-479d-a890-2126f82da758",
|
||||
"hash": "blake3:ce21cb76364aa6e2421311cf4a4b5eb052a76c4f1cd207b50703d8978198a068",
|
||||
"name": "FLUX.1-schnell_ae",
|
||||
"base": "flux",
|
||||
"type": "vae"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 328.1809894659957,
|
||||
"y": -90.2241133566946
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "flux_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"clip": {
|
||||
"name": "clip",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"t5_max_seq_len": {
|
||||
"name": "t5_max_seq_len",
|
||||
"label": "T5 Max Seq Len",
|
||||
"value": 256
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": "a cat wearing a birthday hat"
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 745.8823365057267,
|
||||
"y": -299.60249175851914
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 725.834098928012,
|
||||
"y": 496.2710031089931
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-ace0258f-67d7-4eee-a218-6fff27065214height",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "height",
|
||||
"targetHandle": "height"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-ace0258f-67d7-4eee-a218-6fff27065214width",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "width",
|
||||
"targetHandle": "width"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-ace0258f-67d7-4eee-a218-6fff27065214latents",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-ace0258f-67d7-4eee-a218-6fff27065214latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||
"type": "default",
|
||||
"source": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-ace0258f-67d7-4eee-a218-6fff27065214seed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-ace0258f-67d7-4eee-a218-6fff27065214transformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-ace0258f-67d7-4eee-a218-6fff27065214positive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"name": "FLUX Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models.",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "text2image, flux",
|
||||
@@ -11,17 +11,25 @@
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "clip_embed_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "vae_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"nodeId": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"fieldName": "num_steps"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
@@ -29,6 +37,121 @@
|
||||
"category": "default"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"type": "flux_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"denoise_mask": {
|
||||
"name": "denoise_mask",
|
||||
"label": ""
|
||||
},
|
||||
"denoising_start": {
|
||||
"name": "denoising_start",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"denoising_end": {
|
||||
"name": "denoising_end",
|
||||
"label": "",
|
||||
"value": 1
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1186.1868226120378,
|
||||
"y": -214.9459927686657
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "flux_vae_decode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1575.5797431839133,
|
||||
"y": -209.00150975507415
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
@@ -99,8 +222,8 @@
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 824.1970602278849,
|
||||
"y": 146.98251001061735
|
||||
"x": 778.4899149328337,
|
||||
"y": -100.36469216659502
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -129,77 +252,52 @@
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 822.9899179655476,
|
||||
"y": 360.9657214885052
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"type": "flux_text_to_image",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1216.3900791301849,
|
||||
"y": 5.500841807102248
|
||||
"x": 800.9667463219505,
|
||||
"y": 285.8297267547506
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4fe24f07-f906-4f55-ab2c-9beee56ef5bdtransformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4fe24f07-f906-4f55-ab2c-9beee56ef5bdpositive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4fe24f07-f906-4f55-ab2c-9beee56ef5bdseed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4fe24f07-f906-4f55-ab2c-9beee56ef5bdlatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||
"type": "default",
|
||||
"source": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
@@ -208,14 +306,6 @@
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
@@ -231,30 +321,6 @@
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -38,6 +38,25 @@ SD1_5_LATENT_RGB_FACTORS = [
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
]
|
||||
|
||||
FLUX_LATENT_RGB_FACTORS = [
|
||||
[-0.0412, 0.0149, 0.0521],
|
||||
[0.0056, 0.0291, 0.0768],
|
||||
[0.0342, -0.0681, -0.0427],
|
||||
[-0.0258, 0.0092, 0.0463],
|
||||
[0.0863, 0.0784, 0.0547],
|
||||
[-0.0017, 0.0402, 0.0158],
|
||||
[0.0501, 0.1058, 0.1152],
|
||||
[-0.0209, -0.0218, -0.0329],
|
||||
[-0.0314, 0.0083, 0.0896],
|
||||
[0.0851, 0.0665, -0.0472],
|
||||
[-0.0534, 0.0238, -0.0024],
|
||||
[0.0452, -0.0026, 0.0048],
|
||||
[0.0892, 0.0831, 0.0881],
|
||||
[-0.1117, -0.0304, -0.0789],
|
||||
[0.0027, -0.0479, -0.0043],
|
||||
[-0.1146, -0.0827, -0.0598],
|
||||
]
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(
|
||||
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
|
||||
@@ -94,3 +113,32 @@ def stable_diffusion_step_callback(
|
||||
intermediate_state,
|
||||
ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
)
|
||||
|
||||
|
||||
def flux_step_callback(
|
||||
context_data: "InvocationContextData",
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
events: "EventServiceBase",
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> None:
|
||||
if is_canceled():
|
||||
raise CanceledException
|
||||
sample = intermediate_state.latents
|
||||
latent_rgb_factors = torch.tensor(FLUX_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
latent_image_perm = sample.permute(1, 2, 0).to(dtype=sample.dtype, device=sample.device)
|
||||
latent_image = latent_image_perm @ latent_rgb_factors
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF) # change scale from -1..1 to 0..1 # to 0..255
|
||||
).to(device="cpu", dtype=torch.uint8)
|
||||
image = Image.fromarray(latents_ubyte.cpu().numpy())
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
events.emit_invocation_denoise_progress(
|
||||
context_data.queue_item,
|
||||
context_data.invocation,
|
||||
intermediate_state,
|
||||
ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
)
|
||||
|
||||
56
invokeai/backend/flux/denoise.py
Normal file
56
invokeai/backend/flux/denoise.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
guidance: float,
|
||||
inpaint_extension: InpaintExtension | None,
|
||||
):
|
||||
step = 0
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
preview_img = img - t_curr * pred
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
if inpaint_extension is not None:
|
||||
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=step,
|
||||
order=1,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t_curr),
|
||||
latents=preview_img,
|
||||
),
|
||||
)
|
||||
step += 1
|
||||
|
||||
return img
|
||||
35
invokeai/backend/flux/inpaint_extension.py
Normal file
35
invokeai/backend/flux/inpaint_extension.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
|
||||
|
||||
class InpaintExtension:
|
||||
"""A class for managing inpainting with FLUX."""
|
||||
|
||||
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
|
||||
"""Initialize InpaintExtension.
|
||||
|
||||
Args:
|
||||
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
|
||||
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
|
||||
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
|
||||
inpainted region with the background. In 'packed' format.
|
||||
noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
|
||||
"""
|
||||
assert init_latents.shape == inpaint_mask.shape == noise.shape
|
||||
self._init_latents = init_latents
|
||||
self._inpaint_mask = inpaint_mask
|
||||
self._noise = noise
|
||||
|
||||
def merge_intermediate_latents_with_init_latents(
|
||||
self, intermediate_latents: torch.Tensor, timestep: float
|
||||
) -> torch.Tensor:
|
||||
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
|
||||
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
|
||||
trajectory.
|
||||
|
||||
This function should be called after each denoising step.
|
||||
"""
|
||||
# Noise the init latents for the current timestep.
|
||||
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
|
||||
|
||||
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
|
||||
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
|
||||
@@ -258,16 +258,17 @@ class Decoder(nn.Module):
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
|
||||
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
||||
def __init__(self, chunk_dim: int = 1):
|
||||
super().__init__()
|
||||
self.sample = sample
|
||||
self.chunk_dim = chunk_dim
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||
if self.sample:
|
||||
if sample:
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
# Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
|
||||
# have to use torch.randn(...) instead.
|
||||
return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
|
||||
else:
|
||||
return mean
|
||||
|
||||
@@ -297,8 +298,21 @@ class AutoEncoder(nn.Module):
|
||||
self.scale_factor = params.scale_factor
|
||||
self.shift_factor = params.shift_factor
|
||||
|
||||
def encode(self, x: Tensor) -> Tensor:
|
||||
z = self.reg(self.encoder(x))
|
||||
def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
"""Run VAE encoding on input tensor x.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
|
||||
sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean.
|
||||
Defaults to True.
|
||||
generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
|
||||
"""
|
||||
|
||||
z = self.reg(self.encoder(x), sample=sample, generator=generator)
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
if txt.shape[0] == 1 and bs > 1:
|
||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||
|
||||
vec = clip(prompt)
|
||||
if vec.shape[0] == 1 and bs > 1:
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
}
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# eastimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
vec: Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[], None],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
step_callback()
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
||||
|
||||
|
||||
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert an input image in latent space to patches for diffusion.
|
||||
|
||||
This implementation was extracted from:
|
||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||
"""
|
||||
bs, c, h, w = latent_img.shape
|
||||
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
return img, img_ids
|
||||
135
invokeai/backend/flux/sampling_utils.py
Normal file
135
invokeai/backend/flux/sampling_utils.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# estimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
|
||||
"""Find the last index in timesteps that is >= val.
|
||||
|
||||
We use epsilon-close equality to avoid potential floating point errors.
|
||||
"""
|
||||
idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
|
||||
assert idx >= 0
|
||||
return idx
|
||||
|
||||
|
||||
def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
|
||||
"""Clip the timestep schedule to the denoising range.
|
||||
|
||||
Args:
|
||||
timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
|
||||
denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
|
||||
would mean that the denoising process start at the last timestep in the schedule >= 0.8.
|
||||
denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would
|
||||
mean that the denoising process end at the last timestep in the schedule >= 0.2.
|
||||
|
||||
Returns:
|
||||
list[float]: The clipped timestep schedule.
|
||||
"""
|
||||
assert 0.0 <= denoising_start <= 1.0
|
||||
assert 0.0 <= denoising_end <= 1.0
|
||||
assert denoising_start <= denoising_end
|
||||
|
||||
t_start_val = 1.0 - denoising_start
|
||||
t_end_val = 1.0 - denoising_end
|
||||
|
||||
t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
|
||||
t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)
|
||||
|
||||
clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
|
||||
|
||||
return clipped_timesteps
|
||||
|
||||
|
||||
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""Unpack flat array of patch embeddings to latent image."""
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
||||
|
||||
|
||||
def pack(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pack latent image to flattented array of patch embeddings."""
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
|
||||
|
||||
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""Generate tensor of image position ids.
|
||||
|
||||
Args:
|
||||
h (int): Height of image in latent space.
|
||||
w (int): Width of image in latent space.
|
||||
batch_size (int): Batch size.
|
||||
device (torch.device): Device.
|
||||
dtype (torch.dtype): dtype.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Image position ids.
|
||||
"""
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
return img_ids
|
||||
40
invokeai/backend/image_util/content_shuffle.py
Normal file
40
invokeai/backend/image_util/content_shuffle.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Adapted from https://github.com/huggingface/controlnet_aux
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
def make_noise_disk(H, W, C, F):
|
||||
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
|
||||
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
|
||||
noise = noise[F : F + H, F : F + W]
|
||||
noise -= np.min(noise)
|
||||
noise /= np.max(noise)
|
||||
if C == 1:
|
||||
noise = noise[:, :, None]
|
||||
return noise
|
||||
|
||||
|
||||
def content_shuffle(input_image: Image.Image, scale_factor: int | None = None) -> Image.Image:
|
||||
"""Shuffles the content of an image using a disk noise pattern, similar to a 'liquify' effect."""
|
||||
|
||||
np_img = pil_to_np(input_image)
|
||||
|
||||
height, width, _channels = np_img.shape
|
||||
|
||||
if scale_factor is None:
|
||||
scale_factor = 256
|
||||
|
||||
x = make_noise_disk(height, width, 1, scale_factor) * float(width - 1)
|
||||
y = make_noise_disk(height, width, 1, scale_factor) * float(height - 1)
|
||||
|
||||
flow = np.concatenate([x, y], axis=2).astype(np.float32)
|
||||
|
||||
shuffled_img = cv2.remap(np_img, flow, None, cv2.INTER_LINEAR)
|
||||
|
||||
output_img = np_to_pil(shuffled_img)
|
||||
|
||||
return output_img
|
||||
@@ -1,7 +1,9 @@
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
@@ -29,3 +31,11 @@ class DepthAnythingPipeline(RawModel):
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._pipeline.model)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path: pathlib.Path):
|
||||
"""Load the model from the given path and return a DepthAnythingPipeline instance."""
|
||||
|
||||
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
||||
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
||||
return cls(depth_anything_pipeline)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from controlnet_aux.util import resize_image
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
|
||||
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
|
||||
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
|
||||
from invokeai.backend.image_util.util import np_to_pil
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
@@ -109,4 +115,142 @@ class DWOpenposeDetector:
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]
|
||||
class DWOpenposeDetector2:
|
||||
"""
|
||||
Code from the original implementation of the DW Openpose Detector.
|
||||
Credits: https://github.com/IDEA-Research/DWPose
|
||||
|
||||
This implementation is similar to DWOpenposeDetector, with some alterations to allow the onnx models to be loaded
|
||||
and managed by the model manager.
|
||||
"""
|
||||
|
||||
hf_repo_id = "yzd-v/DWPose"
|
||||
hf_filename_onnx_det = "yolox_l.onnx"
|
||||
hf_filename_onnx_pose = "dw-ll_ucoco_384.onnx"
|
||||
|
||||
@classmethod
|
||||
def get_model_url_det(cls) -> str:
|
||||
"""Returns the URL for the detection model."""
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename_onnx_det)
|
||||
|
||||
@classmethod
|
||||
def get_model_url_pose(cls) -> str:
|
||||
"""Returns the URL for the pose model."""
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename_onnx_pose)
|
||||
|
||||
@staticmethod
|
||||
def create_onnx_inference_session(model_path: Path) -> ort.InferenceSession:
|
||||
"""Creates an ONNX Inference Session for the given model path, using the appropriate execution provider based on
|
||||
the device type."""
|
||||
|
||||
device = TorchDevice.choose_torch_device()
|
||||
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
|
||||
return ort.InferenceSession(path_or_bytes=model_path, providers=providers)
|
||||
|
||||
def __init__(self, session_det: ort.InferenceSession, session_pose: ort.InferenceSession):
|
||||
self.session_det = session_det
|
||||
self.session_pose = session_pose
|
||||
|
||||
def pose_estimation(self, np_image: np.ndarray):
|
||||
"""Does the pose estimation on the given image and returns the keypoints and scores."""
|
||||
|
||||
det_result = inference_detector(self.session_det, np_image)
|
||||
keypoints, scores = inference_pose(self.session_pose, det_result, np_image)
|
||||
|
||||
keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
|
||||
# compute neck joint
|
||||
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
||||
# neck score when visualizing pred
|
||||
neck[:, 2:4] = np.logical_and(keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
||||
new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
|
||||
mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
|
||||
openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
|
||||
new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
|
||||
keypoints_info = new_keypoints_info
|
||||
|
||||
keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
|
||||
|
||||
return keypoints, scores
|
||||
|
||||
def run(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw_face: bool = False,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = False,
|
||||
) -> Image.Image:
|
||||
"""Detects the pose in the given image and returns an solid black image with pose drawn on top, suitable for
|
||||
use with a ControlNet."""
|
||||
|
||||
np_image = np.array(image)
|
||||
H, W, C = np_image.shape
|
||||
|
||||
with torch.no_grad():
|
||||
candidate, subset = self.pose_estimation(np_image)
|
||||
nums, keys, locs = candidate.shape
|
||||
candidate[..., 0] /= float(W)
|
||||
candidate[..., 1] /= float(H)
|
||||
body = candidate[:, :18].copy()
|
||||
body = body.reshape(nums * 18, locs)
|
||||
score = subset[:, :18]
|
||||
for i in range(len(score)):
|
||||
for j in range(len(score[i])):
|
||||
if score[i][j] > 0.3:
|
||||
score[i][j] = int(18 * i + j)
|
||||
else:
|
||||
score[i][j] = -1
|
||||
|
||||
un_visible = subset < 0.3
|
||||
candidate[un_visible] = -1
|
||||
|
||||
# foot = candidate[:, 18:24]
|
||||
|
||||
faces = candidate[:, 24:92]
|
||||
|
||||
hands = candidate[:, 92:113]
|
||||
hands = np.vstack([hands, candidate[:, 113:]])
|
||||
|
||||
bodies = {"candidate": body, "subset": score}
|
||||
pose = {"bodies": bodies, "hands": hands, "faces": faces}
|
||||
|
||||
return DWOpenposeDetector2.draw_pose(
|
||||
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def draw_pose(
|
||||
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
|
||||
H: int,
|
||||
W: int,
|
||||
draw_face: bool = True,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = True,
|
||||
) -> Image.Image:
|
||||
"""Draws the pose on a black image and returns it as a PIL Image."""
|
||||
|
||||
bodies = pose["bodies"]
|
||||
faces = pose["faces"]
|
||||
hands = pose["hands"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
candidate = bodies["candidate"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
subset = bodies["subset"]
|
||||
|
||||
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
||||
|
||||
if draw_body:
|
||||
canvas = draw_bodypose(canvas, candidate, subset)
|
||||
|
||||
if draw_hands:
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_handpose(canvas, hands)
|
||||
|
||||
if draw_face:
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_facepose(canvas, faces) # type: ignore
|
||||
|
||||
dwpose_image = np_to_pil(canvas)
|
||||
|
||||
return dwpose_image
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
|
||||
# Adapted from https://github.com/huggingface/controlnet_aux
|
||||
|
||||
import pathlib
|
||||
|
||||
import cv2
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
@@ -140,3 +143,74 @@ class HEDProcessor:
|
||||
detected_map[detected_map < 255] = 0
|
||||
|
||||
return np_to_pil(detected_map)
|
||||
|
||||
|
||||
class HEDEdgeDetector:
|
||||
"""Simple wrapper around the HED model for detecting edges in an image."""
|
||||
|
||||
hf_repo_id = "lllyasviel/Annotators"
|
||||
hf_filename = "ControlNetHED.pth"
|
||||
|
||||
def __init__(self, model: ControlNetHED_Apache2):
|
||||
self.model = model
|
||||
|
||||
@classmethod
|
||||
def get_model_url(cls) -> str:
|
||||
"""Get the URL to download the model from the Hugging Face Hub."""
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path: pathlib.Path) -> ControlNetHED_Apache2:
|
||||
"""Load the model from a file."""
|
||||
model = ControlNetHED_Apache2()
|
||||
model.load_state_dict(torch.load(model_path, map_location="cpu"))
|
||||
model.float().eval()
|
||||
return model
|
||||
|
||||
def to(self, device: torch.device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def run(self, image: Image.Image, safe: bool = False, scribble: bool = False) -> Image.Image:
|
||||
"""Processes an image and returns the detected edges.
|
||||
|
||||
Args:
|
||||
image: The input image.
|
||||
safe: Whether to apply safe step to the detected edges.
|
||||
scribble: Whether to apply non-maximum suppression and Gaussian blur to the detected edges.
|
||||
|
||||
Returns:
|
||||
The detected edges.
|
||||
"""
|
||||
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
height, width, _channels = np_image.shape
|
||||
|
||||
with torch.no_grad():
|
||||
image_hed = torch.from_numpy(np_image.copy()).float().to(device)
|
||||
image_hed = rearrange(image_hed, "h w c -> 1 c h w")
|
||||
edges = self.model(image_hed)
|
||||
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
|
||||
edges = [cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) for e in edges]
|
||||
edges = np.stack(edges, axis=2)
|
||||
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
|
||||
if safe:
|
||||
edge = safe_step(edge)
|
||||
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
|
||||
|
||||
detected_map = edge
|
||||
|
||||
detected_map = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if scribble:
|
||||
detected_map = nms(detected_map, 127, 3.0)
|
||||
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
|
||||
detected_map[detected_map > 4] = 255
|
||||
detected_map[detected_map < 255] = 0
|
||||
|
||||
output = np_to_pil(detected_map)
|
||||
|
||||
return output
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
|
||||
|
||||
import pathlib
|
||||
|
||||
import cv2
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -156,3 +159,63 @@ class LineartProcessor:
|
||||
detected_map = 255 - detected_map
|
||||
|
||||
return np_to_pil(detected_map)
|
||||
|
||||
|
||||
class LineartEdgeDetector:
|
||||
"""Simple wrapper around the fine and coarse lineart models for detecting edges in an image."""
|
||||
|
||||
hf_repo_id = "lllyasviel/Annotators"
|
||||
hf_filename_fine = "sk_model.pth"
|
||||
hf_filename_coarse = "sk_model2.pth"
|
||||
|
||||
@classmethod
|
||||
def get_model_url(cls, coarse: bool = False) -> str:
|
||||
"""Get the URL to download the model from the Hugging Face Hub."""
|
||||
if coarse:
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename_coarse)
|
||||
else:
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename_fine)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path: pathlib.Path) -> Generator:
|
||||
"""Load the model from a file."""
|
||||
model = Generator(3, 1, 3)
|
||||
model.load_state_dict(torch.load(model_path, map_location="cpu"))
|
||||
model.float().eval()
|
||||
return model
|
||||
|
||||
def __init__(self, model: Generator) -> None:
|
||||
self.model = model
|
||||
|
||||
def to(self, device: torch.device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def run(self, image: Image.Image) -> Image.Image:
|
||||
"""Detects edges in the input image with the selected lineart model.
|
||||
|
||||
Args:
|
||||
input: The input image.
|
||||
coarse: Whether to use the coarse model.
|
||||
|
||||
Returns:
|
||||
The detected edges.
|
||||
"""
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
with torch.no_grad():
|
||||
np_image = torch.from_numpy(np_image).float().to(device)
|
||||
np_image = np_image / 255.0
|
||||
np_image = rearrange(np_image, "h w c -> 1 c h w")
|
||||
line = self.model(np_image)[0][0]
|
||||
|
||||
line = line.cpu().numpy()
|
||||
line = (line * 255.0).clip(0, 255).astype(np.uint8)
|
||||
|
||||
detected_map = line
|
||||
|
||||
detected_map = 255 - detected_map
|
||||
|
||||
return np_to_pil(detected_map)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
|
||||
|
||||
import functools
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -201,3 +203,65 @@ class LineartAnimeProcessor:
|
||||
detected_map = 255 - detected_map
|
||||
|
||||
return np_to_pil(detected_map)
|
||||
|
||||
|
||||
class LineartAnimeEdgeDetector:
|
||||
"""Simple wrapper around the Lineart Anime model for detecting edges in an image."""
|
||||
|
||||
hf_repo_id = "lllyasviel/Annotators"
|
||||
hf_filename = "netG.pth"
|
||||
|
||||
@classmethod
|
||||
def get_model_url(cls) -> str:
|
||||
"""Get the URL to download the model from the Hugging Face Hub."""
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path: pathlib.Path) -> UnetGenerator:
|
||||
"""Load the model from a file."""
|
||||
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
||||
model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
|
||||
ckpt = torch.load(model_path)
|
||||
for key in list(ckpt.keys()):
|
||||
if "module." in key:
|
||||
ckpt[key.replace("module.", "")] = ckpt[key]
|
||||
del ckpt[key]
|
||||
model.load_state_dict(ckpt)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def __init__(self, model: UnetGenerator) -> None:
|
||||
self.model = model
|
||||
|
||||
def to(self, device: torch.device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def run(self, image: Image.Image) -> Image.Image:
|
||||
"""Processes an image and returns the detected edges."""
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
height, width, _channels = np_image.shape
|
||||
new_height = 256 * int(np.ceil(float(height) / 256.0))
|
||||
new_width = 256 * int(np.ceil(float(width) / 256.0))
|
||||
|
||||
resized_img = cv2.resize(np_image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
with torch.no_grad():
|
||||
image_feed = torch.from_numpy(resized_img).float().to(device)
|
||||
image_feed = image_feed / 127.5 - 1.0
|
||||
image_feed = rearrange(image_feed, "h w c -> 1 c h w")
|
||||
|
||||
line = self.model(image_feed)[0, 0] * 127.5 + 127.5
|
||||
line = line.cpu().numpy()
|
||||
|
||||
line = cv2.resize(line, (width, height), interpolation=cv2.INTER_CUBIC)
|
||||
line = line.clip(0, 255).astype(np.uint8)
|
||||
|
||||
detected_map = line
|
||||
detected_map = 255 - detected_map
|
||||
output = np_to_pil(detected_map)
|
||||
|
||||
return output
|
||||
|
||||
15
invokeai/backend/image_util/mediapipe_face/__init__.py
Normal file
15
invokeai/backend/image_util/mediapipe_face/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Adapted from https://github.com/huggingface/controlnet_aux
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.mediapipe_face.mediapipe_face_common import generate_annotation
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
def detect_faces(image: Image.Image, max_faces: int = 1, min_confidence: float = 0.5) -> Image.Image:
|
||||
"""Detects faces in an image using MediaPipe."""
|
||||
|
||||
np_img = pil_to_np(image)
|
||||
detected_map = generate_annotation(np_img, max_faces, min_confidence)
|
||||
detected_map_pil = np_to_pil(detected_map)
|
||||
return detected_map_pil
|
||||
@@ -0,0 +1,149 @@
|
||||
from typing import Mapping
|
||||
|
||||
import mediapipe as mp
|
||||
import numpy
|
||||
|
||||
mp_drawing = mp.solutions.drawing_utils
|
||||
mp_drawing_styles = mp.solutions.drawing_styles
|
||||
mp_face_detection = mp.solutions.face_detection # Only for counting faces.
|
||||
mp_face_mesh = mp.solutions.face_mesh
|
||||
mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION
|
||||
mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS
|
||||
mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS
|
||||
|
||||
DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
|
||||
PoseLandmark = mp.solutions.drawing_styles.PoseLandmark
|
||||
|
||||
min_face_size_pixels: int = 64
|
||||
f_thick = 2
|
||||
f_rad = 1
|
||||
right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
|
||||
right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
|
||||
right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
|
||||
left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
|
||||
left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
|
||||
left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
|
||||
mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad)
|
||||
head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
|
||||
|
||||
# mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
|
||||
face_connection_spec = {}
|
||||
for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
|
||||
face_connection_spec[edge] = head_draw
|
||||
for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
|
||||
face_connection_spec[edge] = left_eye_draw
|
||||
for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
|
||||
face_connection_spec[edge] = left_eyebrow_draw
|
||||
# for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
|
||||
# face_connection_spec[edge] = left_iris_draw
|
||||
for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
|
||||
face_connection_spec[edge] = right_eye_draw
|
||||
for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
|
||||
face_connection_spec[edge] = right_eyebrow_draw
|
||||
# for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
|
||||
# face_connection_spec[edge] = right_iris_draw
|
||||
for edge in mp_face_mesh.FACEMESH_LIPS:
|
||||
face_connection_spec[edge] = mouth_draw
|
||||
iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
|
||||
|
||||
|
||||
def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):
|
||||
"""We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
|
||||
landmarks. Until our PR is merged into mediapipe, we need this separate method."""
|
||||
if len(image.shape) != 3:
|
||||
raise ValueError("Input image must be H,W,C.")
|
||||
image_rows, image_cols, image_channels = image.shape
|
||||
if image_channels != 3: # BGR channels
|
||||
raise ValueError("Input image must contain three channel bgr data.")
|
||||
for idx, landmark in enumerate(landmark_list.landmark):
|
||||
if (landmark.HasField("visibility") and landmark.visibility < 0.9) or (
|
||||
landmark.HasField("presence") and landmark.presence < 0.5
|
||||
):
|
||||
continue
|
||||
if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
|
||||
continue
|
||||
image_x = int(image_cols * landmark.x)
|
||||
image_y = int(image_rows * landmark.y)
|
||||
draw_color = None
|
||||
if isinstance(drawing_spec, Mapping):
|
||||
if drawing_spec.get(idx) is None:
|
||||
continue
|
||||
else:
|
||||
draw_color = drawing_spec[idx].color
|
||||
elif isinstance(drawing_spec, DrawingSpec):
|
||||
draw_color = drawing_spec.color
|
||||
image[image_y - halfwidth : image_y + halfwidth, image_x - halfwidth : image_x + halfwidth, :] = draw_color
|
||||
|
||||
|
||||
def reverse_channels(image):
|
||||
"""Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB."""
|
||||
# im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order.
|
||||
# im[:,:,::[2,1,0]] would also work but makes a copy of the data.
|
||||
return image[:, :, ::-1]
|
||||
|
||||
|
||||
def generate_annotation(img_rgb, max_faces: int, min_confidence: float):
|
||||
"""
|
||||
Find up to 'max_faces' inside the provided input image.
|
||||
If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many
|
||||
pixels in the image.
|
||||
"""
|
||||
with mp_face_mesh.FaceMesh(
|
||||
static_image_mode=True,
|
||||
max_num_faces=max_faces,
|
||||
refine_landmarks=True,
|
||||
min_detection_confidence=min_confidence,
|
||||
) as facemesh:
|
||||
img_height, img_width, img_channels = img_rgb.shape
|
||||
assert img_channels == 3
|
||||
|
||||
results = facemesh.process(img_rgb).multi_face_landmarks
|
||||
|
||||
if results is None:
|
||||
print("No faces detected in controlnet image for Mediapipe face annotator.")
|
||||
return numpy.zeros_like(img_rgb)
|
||||
|
||||
# Filter faces that are too small
|
||||
filtered_landmarks = []
|
||||
for lm in results:
|
||||
landmarks = lm.landmark
|
||||
face_rect = [
|
||||
landmarks[0].x,
|
||||
landmarks[0].y,
|
||||
landmarks[0].x,
|
||||
landmarks[0].y,
|
||||
] # Left, up, right, down.
|
||||
for i in range(len(landmarks)):
|
||||
face_rect[0] = min(face_rect[0], landmarks[i].x)
|
||||
face_rect[1] = min(face_rect[1], landmarks[i].y)
|
||||
face_rect[2] = max(face_rect[2], landmarks[i].x)
|
||||
face_rect[3] = max(face_rect[3], landmarks[i].y)
|
||||
if min_face_size_pixels > 0:
|
||||
face_width = abs(face_rect[2] - face_rect[0])
|
||||
face_height = abs(face_rect[3] - face_rect[1])
|
||||
face_width_pixels = face_width * img_width
|
||||
face_height_pixels = face_height * img_height
|
||||
face_size = min(face_width_pixels, face_height_pixels)
|
||||
if face_size >= min_face_size_pixels:
|
||||
filtered_landmarks.append(lm)
|
||||
else:
|
||||
filtered_landmarks.append(lm)
|
||||
|
||||
# Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start.
|
||||
empty = numpy.zeros_like(img_rgb)
|
||||
|
||||
# Draw detected faces:
|
||||
for face_landmarks in filtered_landmarks:
|
||||
mp_drawing.draw_landmarks(
|
||||
empty,
|
||||
face_landmarks,
|
||||
connections=face_connection_spec.keys(),
|
||||
landmark_drawing_spec=None,
|
||||
connection_drawing_spec=face_connection_spec,
|
||||
)
|
||||
draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)
|
||||
|
||||
# Flip BGR back to RGB.
|
||||
empty = reverse_channels(empty).copy()
|
||||
|
||||
return empty
|
||||
66
invokeai/backend/image_util/mlsd/__init__.py
Normal file
66
invokeai/backend/image_util/mlsd/__init__.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Adapted from https://github.com/huggingface/controlnet_aux
|
||||
|
||||
import pathlib
|
||||
|
||||
import cv2
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large
|
||||
from invokeai.backend.image_util.mlsd.utils import pred_lines
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
|
||||
|
||||
|
||||
class MLSDDetector:
|
||||
"""Simple wrapper around a MLSD model for detecting edges as line segments in an image."""
|
||||
|
||||
hf_repo_id = "lllyasviel/ControlNet"
|
||||
hf_filename = "annotator/ckpts/mlsd_large_512_fp32.pth"
|
||||
|
||||
@classmethod
|
||||
def get_model_url(cls) -> str:
|
||||
"""Get the URL to download the model from the Hugging Face Hub."""
|
||||
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path: pathlib.Path) -> MobileV2_MLSD_Large:
|
||||
"""Load the model from a file."""
|
||||
|
||||
model = MobileV2_MLSD_Large()
|
||||
model.load_state_dict(torch.load(model_path), strict=True)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def __init__(self, model: MobileV2_MLSD_Large) -> None:
|
||||
self.model = model
|
||||
|
||||
def to(self, device: torch.device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def run(self, image: Image.Image, score_threshold: float = 0.1, distance_threshold: float = 20.0) -> Image.Image:
|
||||
"""Processes an image and returns the detected edges."""
|
||||
|
||||
np_img = pil_to_np(image)
|
||||
|
||||
height, width, _channels = np_img.shape
|
||||
|
||||
# This model requires the input image to have a resolution that is a multiple of 64
|
||||
np_img = resize_to_multiple(np_img, 64)
|
||||
img_output = np.zeros_like(np_img)
|
||||
|
||||
with torch.no_grad():
|
||||
lines = pred_lines(np_img, self.model, [np_img.shape[0], np_img.shape[1]], score_threshold, distance_threshold)
|
||||
for line in lines:
|
||||
x_start, y_start, x_end, y_end = [int(val) for val in line]
|
||||
cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
|
||||
|
||||
detected_map = img_output[:, :, 0]
|
||||
|
||||
# Back to the original size
|
||||
output_image = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
return np_to_pil(output_image)
|
||||
0
invokeai/backend/image_util/mlsd/models/__init__.py
Normal file
0
invokeai/backend/image_util/mlsd/models/__init__.py
Normal file
290
invokeai/backend/image_util/mlsd/models/mbv2_mlsd_large.py
Normal file
290
invokeai/backend/image_util/mlsd/models/mbv2_mlsd_large.py
Normal file
@@ -0,0 +1,290 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class BlockTypeA(nn.Module):
|
||||
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
|
||||
super(BlockTypeA, self).__init__()
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(in_c2, out_c2, kernel_size=1),
|
||||
nn.BatchNorm2d(out_c2),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(in_c1, out_c1, kernel_size=1),
|
||||
nn.BatchNorm2d(out_c1),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.upscale = upscale
|
||||
|
||||
def forward(self, a, b):
|
||||
b = self.conv1(b)
|
||||
a = self.conv2(a)
|
||||
if self.upscale:
|
||||
b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
|
||||
return torch.cat((a, b), dim=1)
|
||||
|
||||
|
||||
class BlockTypeB(nn.Module):
|
||||
def __init__(self, in_c, out_c):
|
||||
super(BlockTypeB, self).__init__()
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(in_c),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(out_c),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x) + x
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
class BlockTypeC(nn.Module):
|
||||
def __init__(self, in_c, out_c):
|
||||
super(BlockTypeC, self).__init__()
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
|
||||
nn.BatchNorm2d(in_c),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(in_c),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
return x
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
self.channel_pad = out_planes - in_planes
|
||||
self.stride = stride
|
||||
#padding = (kernel_size - 1) // 2
|
||||
|
||||
# TFLite uses slightly different padding than PyTorch
|
||||
if stride == 2:
|
||||
padding = 0
|
||||
else:
|
||||
padding = (kernel_size - 1) // 2
|
||||
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# TFLite uses different padding
|
||||
if self.stride == 2:
|
||||
x = F.pad(x, (0, 1, 0, 1), "constant", 0)
|
||||
#print(x.shape)
|
||||
|
||||
for module in self:
|
||||
if not isinstance(module, nn.MaxPool2d):
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self, pretrained=True):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
block: Module specifying inverted residual building block for mobilenet
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
|
||||
block = InvertedResidual
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
width_mult = 1.0
|
||||
round_nearest = 8
|
||||
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
#[6, 160, 3, 2],
|
||||
#[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
features = [ConvBNReLU(4, input_channel, stride=2)]
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
||||
input_channel = output_channel
|
||||
|
||||
self.features = nn.Sequential(*features)
|
||||
self.fpn_selected = [1, 3, 6, 10, 13]
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
if pretrained:
|
||||
self._load_pretrained_model()
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
||||
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
||||
fpn_features = []
|
||||
for i, f in enumerate(self.features):
|
||||
if i > self.fpn_selected[-1]:
|
||||
break
|
||||
x = f(x)
|
||||
if i in self.fpn_selected:
|
||||
fpn_features.append(x)
|
||||
|
||||
c1, c2, c3, c4, c5 = fpn_features
|
||||
return c1, c2, c3, c4, c5
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def _load_pretrained_model(self):
|
||||
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
|
||||
model_dict = {}
|
||||
state_dict = self.state_dict()
|
||||
for k, v in pretrain_dict.items():
|
||||
if k in state_dict:
|
||||
model_dict[k] = v
|
||||
state_dict.update(model_dict)
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
|
||||
class MobileV2_MLSD_Large(nn.Module):
|
||||
def __init__(self):
|
||||
super(MobileV2_MLSD_Large, self).__init__()
|
||||
|
||||
self.backbone = MobileNetV2(pretrained=False)
|
||||
## A, B
|
||||
self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
|
||||
out_c1= 64, out_c2=64,
|
||||
upscale=False)
|
||||
self.block16 = BlockTypeB(128, 64)
|
||||
|
||||
## A, B
|
||||
self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
|
||||
out_c1= 64, out_c2= 64)
|
||||
self.block18 = BlockTypeB(128, 64)
|
||||
|
||||
## A, B
|
||||
self.block19 = BlockTypeA(in_c1=24, in_c2=64,
|
||||
out_c1=64, out_c2=64)
|
||||
self.block20 = BlockTypeB(128, 64)
|
||||
|
||||
## A, B, C
|
||||
self.block21 = BlockTypeA(in_c1=16, in_c2=64,
|
||||
out_c1=64, out_c2=64)
|
||||
self.block22 = BlockTypeB(128, 64)
|
||||
|
||||
self.block23 = BlockTypeC(64, 16)
|
||||
|
||||
def forward(self, x):
|
||||
c1, c2, c3, c4, c5 = self.backbone(x)
|
||||
|
||||
x = self.block15(c4, c5)
|
||||
x = self.block16(x)
|
||||
|
||||
x = self.block17(c3, x)
|
||||
x = self.block18(x)
|
||||
|
||||
x = self.block19(c2, x)
|
||||
x = self.block20(x)
|
||||
|
||||
x = self.block21(c1, x)
|
||||
x = self.block22(x)
|
||||
x = self.block23(x)
|
||||
x = x[:, 7:, :, :]
|
||||
|
||||
return x
|
||||
273
invokeai/backend/image_util/mlsd/models/mbv2_mlsd_tiny.py
Normal file
273
invokeai/backend/image_util/mlsd/models/mbv2_mlsd_tiny.py
Normal file
@@ -0,0 +1,273 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class BlockTypeA(nn.Module):
|
||||
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
|
||||
super(BlockTypeA, self).__init__()
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(in_c2, out_c2, kernel_size=1),
|
||||
nn.BatchNorm2d(out_c2),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(in_c1, out_c1, kernel_size=1),
|
||||
nn.BatchNorm2d(out_c1),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.upscale = upscale
|
||||
|
||||
def forward(self, a, b):
|
||||
b = self.conv1(b)
|
||||
a = self.conv2(a)
|
||||
b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
|
||||
return torch.cat((a, b), dim=1)
|
||||
|
||||
|
||||
class BlockTypeB(nn.Module):
|
||||
def __init__(self, in_c, out_c):
|
||||
super(BlockTypeB, self).__init__()
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(in_c),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(out_c),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x) + x
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
class BlockTypeC(nn.Module):
|
||||
def __init__(self, in_c, out_c):
|
||||
super(BlockTypeC, self).__init__()
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
|
||||
nn.BatchNorm2d(in_c),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(in_c),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
return x
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
self.channel_pad = out_planes - in_planes
|
||||
self.stride = stride
|
||||
#padding = (kernel_size - 1) // 2
|
||||
|
||||
# TFLite uses slightly different padding than PyTorch
|
||||
if stride == 2:
|
||||
padding = 0
|
||||
else:
|
||||
padding = (kernel_size - 1) // 2
|
||||
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# TFLite uses different padding
|
||||
if self.stride == 2:
|
||||
x = F.pad(x, (0, 1, 0, 1), "constant", 0)
|
||||
#print(x.shape)
|
||||
|
||||
for module in self:
|
||||
if not isinstance(module, nn.MaxPool2d):
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self, pretrained=True):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
block: Module specifying inverted residual building block for mobilenet
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
|
||||
block = InvertedResidual
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
width_mult = 1.0
|
||||
round_nearest = 8
|
||||
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
#[6, 96, 3, 1],
|
||||
#[6, 160, 3, 2],
|
||||
#[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
features = [ConvBNReLU(4, input_channel, stride=2)]
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
||||
input_channel = output_channel
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
self.fpn_selected = [3, 6, 10]
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
#if pretrained:
|
||||
# self._load_pretrained_model()
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
||||
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
||||
fpn_features = []
|
||||
for i, f in enumerate(self.features):
|
||||
if i > self.fpn_selected[-1]:
|
||||
break
|
||||
x = f(x)
|
||||
if i in self.fpn_selected:
|
||||
fpn_features.append(x)
|
||||
|
||||
c2, c3, c4 = fpn_features
|
||||
return c2, c3, c4
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def _load_pretrained_model(self):
|
||||
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
|
||||
model_dict = {}
|
||||
state_dict = self.state_dict()
|
||||
for k, v in pretrain_dict.items():
|
||||
if k in state_dict:
|
||||
model_dict[k] = v
|
||||
state_dict.update(model_dict)
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
|
||||
class MobileV2_MLSD_Tiny(nn.Module):
|
||||
def __init__(self):
|
||||
super(MobileV2_MLSD_Tiny, self).__init__()
|
||||
|
||||
self.backbone = MobileNetV2(pretrained=True)
|
||||
|
||||
self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
|
||||
out_c1= 64, out_c2=64)
|
||||
self.block13 = BlockTypeB(128, 64)
|
||||
|
||||
self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
|
||||
out_c1= 32, out_c2= 32)
|
||||
self.block15 = BlockTypeB(64, 64)
|
||||
|
||||
self.block16 = BlockTypeC(64, 16)
|
||||
|
||||
def forward(self, x):
|
||||
c2, c3, c4 = self.backbone(x)
|
||||
|
||||
x = self.block12(c3, c4)
|
||||
x = self.block13(x)
|
||||
x = self.block14(c2, x)
|
||||
x = self.block15(x)
|
||||
x = self.block16(x)
|
||||
x = x[:, 7:, :, :]
|
||||
#print(x.shape)
|
||||
x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
|
||||
|
||||
return x
|
||||
587
invokeai/backend/image_util/mlsd/utils.py
Normal file
587
invokeai/backend/image_util/mlsd/utils.py
Normal file
@@ -0,0 +1,587 @@
|
||||
'''
|
||||
modified by lihaoweicv
|
||||
pytorch version
|
||||
'''
|
||||
|
||||
'''
|
||||
M-LSD
|
||||
Copyright 2021-present NAVER Corp.
|
||||
Apache License v2.0
|
||||
'''
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
|
||||
'''
|
||||
tpMap:
|
||||
center: tpMap[1, 0, :, :]
|
||||
displacement: tpMap[1, 1:5, :, :]
|
||||
'''
|
||||
b, c, h, w = tpMap.shape
|
||||
assert b==1, 'only support bsize==1'
|
||||
displacement = tpMap[:, 1:5, :, :][0]
|
||||
center = tpMap[:, 0, :, :]
|
||||
heat = torch.sigmoid(center)
|
||||
hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
|
||||
keep = (hmax == heat).float()
|
||||
heat = heat * keep
|
||||
heat = heat.reshape(-1, )
|
||||
|
||||
scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
|
||||
yy = torch.floor_divide(indices, w).unsqueeze(-1)
|
||||
xx = torch.fmod(indices, w).unsqueeze(-1)
|
||||
ptss = torch.cat((yy, xx),dim=-1)
|
||||
|
||||
ptss = ptss.detach().cpu().numpy()
|
||||
scores = scores.detach().cpu().numpy()
|
||||
displacement = displacement.detach().cpu().numpy()
|
||||
displacement = displacement.transpose((1,2,0))
|
||||
return ptss, scores, displacement
|
||||
|
||||
|
||||
def pred_lines(image, model,
|
||||
input_shape=[512, 512],
|
||||
score_thr=0.10,
|
||||
dist_thr=20.0):
|
||||
h, w, _ = image.shape
|
||||
|
||||
device = next(iter(model.parameters())).device
|
||||
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
|
||||
|
||||
resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
|
||||
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
|
||||
|
||||
resized_image = resized_image.transpose((2,0,1))
|
||||
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
|
||||
batch_image = (batch_image / 127.5) - 1.0
|
||||
|
||||
batch_image = torch.from_numpy(batch_image).float()
|
||||
batch_image = batch_image.to(device)
|
||||
outputs = model(batch_image)
|
||||
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
|
||||
start = vmap[:, :, :2]
|
||||
end = vmap[:, :, 2:]
|
||||
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
|
||||
|
||||
segments_list = []
|
||||
for center, score in zip(pts, pts_score, strict=False):
|
||||
y, x = center
|
||||
distance = dist_map[y, x]
|
||||
if score > score_thr and distance > dist_thr:
|
||||
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
|
||||
x_start = x + disp_x_start
|
||||
y_start = y + disp_y_start
|
||||
x_end = x + disp_x_end
|
||||
y_end = y + disp_y_end
|
||||
segments_list.append([x_start, y_start, x_end, y_end])
|
||||
|
||||
if segments_list:
|
||||
lines = 2 * np.array(segments_list) # 256 > 512
|
||||
lines[:, 0] = lines[:, 0] * w_ratio
|
||||
lines[:, 1] = lines[:, 1] * h_ratio
|
||||
lines[:, 2] = lines[:, 2] * w_ratio
|
||||
lines[:, 3] = lines[:, 3] * h_ratio
|
||||
else:
|
||||
# No segments detected - return empty array
|
||||
lines = np.array([])
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def pred_squares(image,
|
||||
model,
|
||||
input_shape=[512, 512],
|
||||
params={'score': 0.06,
|
||||
'outside_ratio': 0.28,
|
||||
'inside_ratio': 0.45,
|
||||
'w_overlap': 0.0,
|
||||
'w_degree': 1.95,
|
||||
'w_length': 0.0,
|
||||
'w_area': 1.86,
|
||||
'w_center': 0.14}):
|
||||
'''
|
||||
shape = [height, width]
|
||||
'''
|
||||
h, w, _ = image.shape
|
||||
original_shape = [h, w]
|
||||
device = next(iter(model.parameters())).device
|
||||
|
||||
resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
|
||||
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
|
||||
resized_image = resized_image.transpose((2, 0, 1))
|
||||
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
|
||||
batch_image = (batch_image / 127.5) - 1.0
|
||||
|
||||
batch_image = torch.from_numpy(batch_image).float().to(device)
|
||||
outputs = model(batch_image)
|
||||
|
||||
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
|
||||
start = vmap[:, :, :2] # (x, y)
|
||||
end = vmap[:, :, 2:] # (x, y)
|
||||
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
|
||||
|
||||
junc_list = []
|
||||
segments_list = []
|
||||
for junc, score in zip(pts, pts_score, strict=False):
|
||||
y, x = junc
|
||||
distance = dist_map[y, x]
|
||||
if score > params['score'] and distance > 20.0:
|
||||
junc_list.append([x, y])
|
||||
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
|
||||
d_arrow = 1.0
|
||||
x_start = x + d_arrow * disp_x_start
|
||||
y_start = y + d_arrow * disp_y_start
|
||||
x_end = x + d_arrow * disp_x_end
|
||||
y_end = y + d_arrow * disp_y_end
|
||||
segments_list.append([x_start, y_start, x_end, y_end])
|
||||
|
||||
segments = np.array(segments_list)
|
||||
|
||||
####### post processing for squares
|
||||
# 1. get unique lines
|
||||
point = np.array([[0, 0]])
|
||||
point = point[0]
|
||||
start = segments[:, :2]
|
||||
end = segments[:, 2:]
|
||||
diff = start - end
|
||||
a = diff[:, 1]
|
||||
b = -diff[:, 0]
|
||||
c = a * start[:, 0] + b * start[:, 1]
|
||||
|
||||
d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
|
||||
theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
|
||||
theta[theta < 0.0] += 180
|
||||
hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
|
||||
|
||||
d_quant = 1
|
||||
theta_quant = 2
|
||||
hough[:, 0] //= d_quant
|
||||
hough[:, 1] //= theta_quant
|
||||
_, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
|
||||
|
||||
acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
|
||||
idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
|
||||
yx_indices = hough[indices, :].astype('int32')
|
||||
acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
|
||||
idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
|
||||
|
||||
acc_map_np = acc_map
|
||||
# acc_map = acc_map[None, :, :, None]
|
||||
#
|
||||
# ### fast suppression using tensorflow op
|
||||
# acc_map = tf.constant(acc_map, dtype=tf.float32)
|
||||
# max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
|
||||
# acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
|
||||
# flatten_acc_map = tf.reshape(acc_map, [1, -1])
|
||||
# topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
|
||||
# _, h, w, _ = acc_map.shape
|
||||
# y = tf.expand_dims(topk_indices // w, axis=-1)
|
||||
# x = tf.expand_dims(topk_indices % w, axis=-1)
|
||||
# yx = tf.concat([y, x], axis=-1)
|
||||
|
||||
### fast suppression using pytorch op
|
||||
acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
|
||||
_,_, h, w = acc_map.shape
|
||||
max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
|
||||
acc_map = acc_map * ( (acc_map == max_acc_map).float() )
|
||||
flatten_acc_map = acc_map.reshape([-1, ])
|
||||
|
||||
scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
|
||||
yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
|
||||
xx = torch.fmod(indices, w).unsqueeze(-1)
|
||||
yx = torch.cat((yy, xx), dim=-1)
|
||||
|
||||
yx = yx.detach().cpu().numpy()
|
||||
|
||||
topk_values = scores.detach().cpu().numpy()
|
||||
indices = idx_map[yx[:, 0], yx[:, 1]]
|
||||
basis = 5 // 2
|
||||
|
||||
merged_segments = []
|
||||
for yx_pt, max_indice, value in zip(yx, indices, topk_values, strict=False):
|
||||
y, x = yx_pt
|
||||
if max_indice == -1 or value == 0:
|
||||
continue
|
||||
segment_list = []
|
||||
for y_offset in range(-basis, basis + 1):
|
||||
for x_offset in range(-basis, basis + 1):
|
||||
indice = idx_map[y + y_offset, x + x_offset]
|
||||
cnt = int(acc_map_np[y + y_offset, x + x_offset])
|
||||
if indice != -1:
|
||||
segment_list.append(segments[indice])
|
||||
if cnt > 1:
|
||||
check_cnt = 1
|
||||
current_hough = hough[indice]
|
||||
for new_indice, new_hough in enumerate(hough):
|
||||
if (current_hough == new_hough).all() and indice != new_indice:
|
||||
segment_list.append(segments[new_indice])
|
||||
check_cnt += 1
|
||||
if check_cnt == cnt:
|
||||
break
|
||||
group_segments = np.array(segment_list).reshape([-1, 2])
|
||||
sorted_group_segments = np.sort(group_segments, axis=0)
|
||||
x_min, y_min = sorted_group_segments[0, :]
|
||||
x_max, y_max = sorted_group_segments[-1, :]
|
||||
|
||||
deg = theta[max_indice]
|
||||
if deg >= 90:
|
||||
merged_segments.append([x_min, y_max, x_max, y_min])
|
||||
else:
|
||||
merged_segments.append([x_min, y_min, x_max, y_max])
|
||||
|
||||
# 2. get intersections
|
||||
new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
|
||||
start = new_segments[:, :2] # (x1, y1)
|
||||
end = new_segments[:, 2:] # (x2, y2)
|
||||
new_centers = (start + end) / 2.0
|
||||
diff = start - end
|
||||
dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
|
||||
|
||||
# ax + by = c
|
||||
a = diff[:, 1]
|
||||
b = -diff[:, 0]
|
||||
c = a * start[:, 0] + b * start[:, 1]
|
||||
pre_det = a[:, None] * b[None, :]
|
||||
det = pre_det - np.transpose(pre_det)
|
||||
|
||||
pre_inter_y = a[:, None] * c[None, :]
|
||||
inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
|
||||
pre_inter_x = c[:, None] * b[None, :]
|
||||
inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
|
||||
inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
|
||||
|
||||
# 3. get corner information
|
||||
# 3.1 get distance
|
||||
'''
|
||||
dist_segments:
|
||||
| dist(0), dist(1), dist(2), ...|
|
||||
dist_inter_to_segment1:
|
||||
| dist(inter,0), dist(inter,0), dist(inter,0), ... |
|
||||
| dist(inter,1), dist(inter,1), dist(inter,1), ... |
|
||||
...
|
||||
dist_inter_to_semgnet2:
|
||||
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
|
||||
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
|
||||
...
|
||||
'''
|
||||
|
||||
dist_inter_to_segment1_start = np.sqrt(
|
||||
np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
||||
dist_inter_to_segment1_end = np.sqrt(
|
||||
np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
||||
dist_inter_to_segment2_start = np.sqrt(
|
||||
np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
||||
dist_inter_to_segment2_end = np.sqrt(
|
||||
np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
||||
|
||||
# sort ascending
|
||||
dist_inter_to_segment1 = np.sort(
|
||||
np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
|
||||
axis=-1) # [n_batch, n_batch, 2]
|
||||
dist_inter_to_segment2 = np.sort(
|
||||
np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
|
||||
axis=-1) # [n_batch, n_batch, 2]
|
||||
|
||||
# 3.2 get degree
|
||||
inter_to_start = new_centers[:, None, :] - inter_pts
|
||||
deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
|
||||
deg_inter_to_start[deg_inter_to_start < 0.0] += 360
|
||||
inter_to_end = new_centers[None, :, :] - inter_pts
|
||||
deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
|
||||
deg_inter_to_end[deg_inter_to_end < 0.0] += 360
|
||||
|
||||
'''
|
||||
B -- G
|
||||
| |
|
||||
C -- R
|
||||
B : blue / G: green / C: cyan / R: red
|
||||
|
||||
0 -- 1
|
||||
| |
|
||||
3 -- 2
|
||||
'''
|
||||
# rename variables
|
||||
deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
|
||||
# sort deg ascending
|
||||
deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
|
||||
|
||||
deg_diff_map = np.abs(deg1_map - deg2_map)
|
||||
# we only consider the smallest degree of intersect
|
||||
deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
|
||||
|
||||
# define available degree range
|
||||
deg_range = [60, 120]
|
||||
|
||||
corner_dict = {corner_info: [] for corner_info in range(4)}
|
||||
inter_points = []
|
||||
for i in range(inter_pts.shape[0]):
|
||||
for j in range(i + 1, inter_pts.shape[1]):
|
||||
# i, j > line index, always i < j
|
||||
x, y = inter_pts[i, j, :]
|
||||
deg1, deg2 = deg_sort[i, j, :]
|
||||
deg_diff = deg_diff_map[i, j]
|
||||
|
||||
check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
|
||||
|
||||
outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
|
||||
inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
|
||||
check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
|
||||
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
|
||||
(dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
|
||||
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
|
||||
((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
|
||||
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
|
||||
(dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
|
||||
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
|
||||
|
||||
if check_degree and check_distance:
|
||||
corner_info = None
|
||||
|
||||
if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
|
||||
(deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
|
||||
corner_info, color_info = 0, 'blue'
|
||||
elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
|
||||
corner_info, color_info = 1, 'green'
|
||||
elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
|
||||
corner_info, color_info = 2, 'black'
|
||||
elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
|
||||
(deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
|
||||
corner_info, color_info = 3, 'cyan'
|
||||
else:
|
||||
corner_info, color_info = 4, 'red' # we don't use it
|
||||
continue
|
||||
|
||||
corner_dict[corner_info].append([x, y, i, j])
|
||||
inter_points.append([x, y])
|
||||
|
||||
square_list = []
|
||||
connect_list = []
|
||||
segments_list = []
|
||||
for corner0 in corner_dict[0]:
|
||||
for corner1 in corner_dict[1]:
|
||||
connect01 = False
|
||||
for corner0_line in corner0[2:]:
|
||||
if corner0_line in corner1[2:]:
|
||||
connect01 = True
|
||||
break
|
||||
if connect01:
|
||||
for corner2 in corner_dict[2]:
|
||||
connect12 = False
|
||||
for corner1_line in corner1[2:]:
|
||||
if corner1_line in corner2[2:]:
|
||||
connect12 = True
|
||||
break
|
||||
if connect12:
|
||||
for corner3 in corner_dict[3]:
|
||||
connect23 = False
|
||||
for corner2_line in corner2[2:]:
|
||||
if corner2_line in corner3[2:]:
|
||||
connect23 = True
|
||||
break
|
||||
if connect23:
|
||||
for corner3_line in corner3[2:]:
|
||||
if corner3_line in corner0[2:]:
|
||||
# SQUARE!!!
|
||||
'''
|
||||
0 -- 1
|
||||
| |
|
||||
3 -- 2
|
||||
square_list:
|
||||
order: 0 > 1 > 2 > 3
|
||||
| x0, y0, x1, y1, x2, y2, x3, y3 |
|
||||
| x0, y0, x1, y1, x2, y2, x3, y3 |
|
||||
...
|
||||
connect_list:
|
||||
order: 01 > 12 > 23 > 30
|
||||
| line_idx01, line_idx12, line_idx23, line_idx30 |
|
||||
| line_idx01, line_idx12, line_idx23, line_idx30 |
|
||||
...
|
||||
segments_list:
|
||||
order: 0 > 1 > 2 > 3
|
||||
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
|
||||
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
|
||||
...
|
||||
'''
|
||||
square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
|
||||
connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
|
||||
segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
|
||||
|
||||
def check_outside_inside(segments_info, connect_idx):
|
||||
# return 'outside or inside', min distance, cover_param, peri_param
|
||||
if connect_idx == segments_info[0]:
|
||||
check_dist_mat = dist_inter_to_segment1
|
||||
else:
|
||||
check_dist_mat = dist_inter_to_segment2
|
||||
|
||||
i, j = segments_info
|
||||
min_dist, max_dist = check_dist_mat[i, j, :]
|
||||
connect_dist = dist_segments[connect_idx]
|
||||
if max_dist > connect_dist:
|
||||
return 'outside', min_dist, 0, 1
|
||||
else:
|
||||
return 'inside', min_dist, -1, -1
|
||||
|
||||
top_square = None
|
||||
|
||||
try:
|
||||
map_size = input_shape[0] / 2
|
||||
squares = np.array(square_list).reshape([-1, 4, 2])
|
||||
score_array = []
|
||||
connect_array = np.array(connect_list)
|
||||
segments_array = np.array(segments_list).reshape([-1, 4, 2])
|
||||
|
||||
# get degree of corners:
|
||||
squares_rollup = np.roll(squares, 1, axis=1)
|
||||
squares_rolldown = np.roll(squares, -1, axis=1)
|
||||
vec1 = squares_rollup - squares
|
||||
normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
|
||||
vec2 = squares_rolldown - squares
|
||||
normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
|
||||
inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
|
||||
squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
|
||||
|
||||
# get square score
|
||||
overlap_scores = []
|
||||
degree_scores = []
|
||||
length_scores = []
|
||||
|
||||
for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree, strict=False):
|
||||
'''
|
||||
0 -- 1
|
||||
| |
|
||||
3 -- 2
|
||||
|
||||
# segments: [4, 2]
|
||||
# connects: [4]
|
||||
'''
|
||||
|
||||
###################################### OVERLAP SCORES
|
||||
cover = 0
|
||||
perimeter = 0
|
||||
# check 0 > 1 > 2 > 3
|
||||
square_length = []
|
||||
|
||||
for start_idx in range(4):
|
||||
end_idx = (start_idx + 1) % 4
|
||||
|
||||
connect_idx = connects[start_idx] # segment idx of segment01
|
||||
start_segments = segments[start_idx]
|
||||
end_segments = segments[end_idx]
|
||||
|
||||
start_point = square[start_idx]
|
||||
end_point = square[end_idx]
|
||||
|
||||
# check whether outside or inside
|
||||
start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
|
||||
connect_idx)
|
||||
end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
|
||||
|
||||
cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
|
||||
perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
|
||||
|
||||
square_length.append(
|
||||
dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
|
||||
|
||||
overlap_scores.append(cover / perimeter)
|
||||
######################################
|
||||
###################################### DEGREE SCORES
|
||||
'''
|
||||
deg0 vs deg2
|
||||
deg1 vs deg3
|
||||
'''
|
||||
deg0, deg1, deg2, deg3 = degree
|
||||
deg_ratio1 = deg0 / deg2
|
||||
if deg_ratio1 > 1.0:
|
||||
deg_ratio1 = 1 / deg_ratio1
|
||||
deg_ratio2 = deg1 / deg3
|
||||
if deg_ratio2 > 1.0:
|
||||
deg_ratio2 = 1 / deg_ratio2
|
||||
degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
|
||||
######################################
|
||||
###################################### LENGTH SCORES
|
||||
'''
|
||||
len0 vs len2
|
||||
len1 vs len3
|
||||
'''
|
||||
len0, len1, len2, len3 = square_length
|
||||
len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
|
||||
len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
|
||||
length_scores.append((len_ratio1 + len_ratio2) / 2)
|
||||
|
||||
######################################
|
||||
|
||||
overlap_scores = np.array(overlap_scores)
|
||||
overlap_scores /= np.max(overlap_scores)
|
||||
|
||||
degree_scores = np.array(degree_scores)
|
||||
# degree_scores /= np.max(degree_scores)
|
||||
|
||||
length_scores = np.array(length_scores)
|
||||
|
||||
###################################### AREA SCORES
|
||||
area_scores = np.reshape(squares, [-1, 4, 2])
|
||||
area_x = area_scores[:, :, 0]
|
||||
area_y = area_scores[:, :, 1]
|
||||
correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
|
||||
area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
|
||||
area_scores = 0.5 * np.abs(area_scores + correction)
|
||||
area_scores /= (map_size * map_size) # np.max(area_scores)
|
||||
######################################
|
||||
|
||||
###################################### CENTER SCORES
|
||||
centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
|
||||
# squares: [n, 4, 2]
|
||||
square_centers = np.mean(squares, axis=1) # [n, 2]
|
||||
center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
|
||||
center_scores = center2center / (map_size / np.sqrt(2.0))
|
||||
|
||||
'''
|
||||
score_w = [overlap, degree, area, center, length]
|
||||
'''
|
||||
score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
|
||||
score_array = params['w_overlap'] * overlap_scores \
|
||||
+ params['w_degree'] * degree_scores \
|
||||
+ params['w_area'] * area_scores \
|
||||
- params['w_center'] * center_scores \
|
||||
+ params['w_length'] * length_scores
|
||||
|
||||
best_square = []
|
||||
|
||||
sorted_idx = np.argsort(score_array)[::-1]
|
||||
score_array = score_array[sorted_idx]
|
||||
squares = squares[sorted_idx]
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
'''return list
|
||||
merged_lines, squares, scores
|
||||
'''
|
||||
|
||||
try:
|
||||
new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
|
||||
new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
|
||||
new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
|
||||
new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
|
||||
except:
|
||||
new_segments = []
|
||||
|
||||
try:
|
||||
squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
|
||||
squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
|
||||
except:
|
||||
squares = []
|
||||
score_array = []
|
||||
|
||||
try:
|
||||
inter_points = np.array(inter_points)
|
||||
inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
|
||||
inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
|
||||
except:
|
||||
inter_points = []
|
||||
|
||||
return new_segments, squares, score_array, inter_points
|
||||
21
invokeai/backend/image_util/normal_bae/LICENSE
Normal file
21
invokeai/backend/image_util/normal_bae/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Caroline Chan
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
93
invokeai/backend/image_util/normal_bae/__init__.py
Normal file
93
invokeai/backend/image_util/normal_bae/__init__.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Adapted from https://github.com/huggingface/controlnet_aux
|
||||
|
||||
import pathlib
|
||||
import types
|
||||
|
||||
import cv2
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
|
||||
|
||||
|
||||
class NormalMapDetector:
|
||||
"""Simple wrapper around the Normal BAE model for normal map generation."""
|
||||
|
||||
hf_repo_id = "lllyasviel/Annotators"
|
||||
hf_filename = "scannet.pt"
|
||||
|
||||
@classmethod
|
||||
def get_model_url(cls) -> str:
|
||||
"""Get the URL to download the model from the Hugging Face Hub."""
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path: pathlib.Path) -> NNET:
|
||||
"""Load the model from a file."""
|
||||
|
||||
args = types.SimpleNamespace()
|
||||
args.mode = "client"
|
||||
args.architecture = "BN"
|
||||
args.pretrained = "scannet"
|
||||
args.sampling_ratio = 0.4
|
||||
args.importance_ratio = 0.7
|
||||
|
||||
model = NNET(args)
|
||||
|
||||
ckpt = torch.load(model_path, map_location="cpu")["model"]
|
||||
load_dict = {}
|
||||
for k, v in ckpt.items():
|
||||
if k.startswith("module."):
|
||||
k_ = k.replace("module.", "")
|
||||
load_dict[k_] = v
|
||||
else:
|
||||
load_dict[k] = v
|
||||
|
||||
model.load_state_dict(load_dict)
|
||||
model.eval()
|
||||
|
||||
return model
|
||||
|
||||
def __init__(self, model: NNET) -> None:
|
||||
self.model = model
|
||||
self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
def to(self, device: torch.device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def run(self, image: Image.Image):
|
||||
"""Processes an image and returns the detected normal map."""
|
||||
|
||||
device = next(iter(self.model.parameters())).device
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
height, width, _channels = np_image.shape
|
||||
|
||||
# The model requires the image to be a multiple of 8
|
||||
np_image = resize_to_multiple(np_image, 8)
|
||||
|
||||
image_normal = np_image
|
||||
|
||||
with torch.no_grad():
|
||||
image_normal = torch.from_numpy(image_normal).float().to(device)
|
||||
image_normal = image_normal / 255.0
|
||||
image_normal = rearrange(image_normal, "h w c -> 1 c h w")
|
||||
image_normal = self.norm(image_normal)
|
||||
|
||||
normal = self.model(image_normal)
|
||||
normal = normal[0][-1][:, :3]
|
||||
normal = ((normal + 1) * 0.5).clip(0, 1)
|
||||
|
||||
normal = rearrange(normal[0], "c h w -> h w c").cpu().numpy()
|
||||
normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
|
||||
|
||||
# Back to the original size
|
||||
output_image = cv2.resize(normal_image, (width, height), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
return np_to_pil(output_image)
|
||||
22
invokeai/backend/image_util/normal_bae/nets/NNET.py
Normal file
22
invokeai/backend/image_util/normal_bae/nets/NNET.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .submodules.encoder import Encoder
|
||||
from .submodules.decoder import Decoder
|
||||
|
||||
|
||||
class NNET(nn.Module):
|
||||
def __init__(self, args):
|
||||
super(NNET, self).__init__()
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder(args)
|
||||
|
||||
def get_1x_lr_params(self): # lr/10 learning rate
|
||||
return self.encoder.parameters()
|
||||
|
||||
def get_10x_lr_params(self): # lr learning rate
|
||||
return self.decoder.parameters()
|
||||
|
||||
def forward(self, img, **kwargs):
|
||||
return self.decoder(self.encoder(img), **kwargs)
|
||||
85
invokeai/backend/image_util/normal_bae/nets/baseline.py
Normal file
85
invokeai/backend/image_util/normal_bae/nets/baseline.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .submodules.submodules import UpSampleBN, norm_normalize
|
||||
|
||||
|
||||
# This is the baseline encoder-decoder we used in the ablation study
|
||||
class NNET(nn.Module):
|
||||
def __init__(self, args=None):
|
||||
super(NNET, self).__init__()
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder(num_classes=4)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
out = self.decoder(self.encoder(x), **kwargs)
|
||||
|
||||
# Bilinearly upsample the output to match the input resolution
|
||||
up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False)
|
||||
|
||||
# L2-normalize the first three channels / ensure positive value for concentration parameters (kappa)
|
||||
up_out = norm_normalize(up_out)
|
||||
return up_out
|
||||
|
||||
def get_1x_lr_params(self): # lr/10 learning rate
|
||||
return self.encoder.parameters()
|
||||
|
||||
def get_10x_lr_params(self): # lr learning rate
|
||||
modules = [self.decoder]
|
||||
for m in modules:
|
||||
yield from m.parameters()
|
||||
|
||||
|
||||
# Encoder
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
basemodel_name = 'tf_efficientnet_b5_ap'
|
||||
basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
|
||||
|
||||
# Remove last layer
|
||||
basemodel.global_pool = nn.Identity()
|
||||
basemodel.classifier = nn.Identity()
|
||||
|
||||
self.original_model = basemodel
|
||||
|
||||
def forward(self, x):
|
||||
features = [x]
|
||||
for k, v in self.original_model._modules.items():
|
||||
if (k == 'blocks'):
|
||||
for ki, vi in v._modules.items():
|
||||
features.append(vi(features[-1]))
|
||||
else:
|
||||
features.append(v(features[-1]))
|
||||
return features
|
||||
|
||||
|
||||
# Decoder (no pixel-wise MLP, no uncertainty-guided sampling)
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, num_classes=4):
|
||||
super(Decoder, self).__init__()
|
||||
self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
|
||||
self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
|
||||
self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
|
||||
self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
|
||||
self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
|
||||
self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, features):
|
||||
x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
|
||||
x_d0 = self.conv2(x_block4)
|
||||
x_d1 = self.up1(x_d0, x_block3)
|
||||
x_d2 = self.up2(x_d1, x_block2)
|
||||
x_d3 = self.up3(x_d2, x_block1)
|
||||
x_d4 = self.up4(x_d3, x_block0)
|
||||
out = self.conv3(x_d4)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = Baseline()
|
||||
x = torch.rand(2, 3, 480, 640)
|
||||
out = model(x)
|
||||
print(out.shape)
|
||||
@@ -0,0 +1,202 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
# hyper-parameter for sampling
|
||||
self.sampling_ratio = args.sampling_ratio
|
||||
self.importance_ratio = args.importance_ratio
|
||||
|
||||
# feature-map
|
||||
self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
|
||||
if args.architecture == 'BN':
|
||||
self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
|
||||
self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
|
||||
self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
|
||||
self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
|
||||
|
||||
elif args.architecture == 'GN':
|
||||
self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024)
|
||||
self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512)
|
||||
self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256)
|
||||
self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128)
|
||||
|
||||
else:
|
||||
raise Exception('invalid architecture')
|
||||
|
||||
# produces 1/8 res output
|
||||
self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# produces 1/4 res output
|
||||
self.out_conv_res4 = nn.Sequential(
|
||||
nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 4, kernel_size=1),
|
||||
)
|
||||
|
||||
# produces 1/2 res output
|
||||
self.out_conv_res2 = nn.Sequential(
|
||||
nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 4, kernel_size=1),
|
||||
)
|
||||
|
||||
# produces 1/1 res output
|
||||
self.out_conv_res1 = nn.Sequential(
|
||||
nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 4, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, features, gt_norm_mask=None, mode='test'):
|
||||
x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
|
||||
|
||||
# generate feature-map
|
||||
|
||||
x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res
|
||||
x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res
|
||||
x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res
|
||||
x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res
|
||||
x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res
|
||||
|
||||
# 1/8 res output
|
||||
out_res8 = self.out_conv_res8(x_d2) # out_res8: [2, 4, 60, 80] 1/8 res output
|
||||
out_res8 = norm_normalize(out_res8) # out_res8: [2, 4, 60, 80] 1/8 res output
|
||||
|
||||
################################################################################################################
|
||||
# out_res4
|
||||
################################################################################################################
|
||||
|
||||
if mode == 'train':
|
||||
# upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160]
|
||||
out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
B, _, H, W = out_res8_res4.shape
|
||||
|
||||
# samples: [B, 1, N, 2]
|
||||
point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask,
|
||||
sampling_ratio=self.sampling_ratio,
|
||||
beta=self.importance_ratio)
|
||||
|
||||
# output (needed for evaluation / visualization)
|
||||
out_res4 = out_res8_res4
|
||||
|
||||
# grid_sample feature-map
|
||||
feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True) # (B, 512, 1, N)
|
||||
init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True) # (B, 4, 1, N)
|
||||
feat_res4 = torch.cat([feat_res4, init_pred], dim=1) # (B, 512+4, 1, N)
|
||||
|
||||
# prediction (needed to compute loss)
|
||||
samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :]) # (B, 4, N)
|
||||
samples_pred_res4 = norm_normalize(samples_pred_res4) # (B, 4, N) - normalized
|
||||
|
||||
for i in range(B):
|
||||
out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :]
|
||||
|
||||
else:
|
||||
# grid_sample feature-map
|
||||
feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
|
||||
B, _, H, W = feat_map.shape
|
||||
|
||||
# try all pixels
|
||||
out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N)
|
||||
out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized
|
||||
out_res4 = out_res4.view(B, 4, H, W)
|
||||
samples_pred_res4 = point_coords_res4 = None
|
||||
|
||||
################################################################################################################
|
||||
# out_res2
|
||||
################################################################################################################
|
||||
|
||||
if mode == 'train':
|
||||
|
||||
# upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
|
||||
out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
B, _, H, W = out_res4_res2.shape
|
||||
|
||||
# samples: [B, 1, N, 2]
|
||||
point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask,
|
||||
sampling_ratio=self.sampling_ratio,
|
||||
beta=self.importance_ratio)
|
||||
|
||||
# output (needed for evaluation / visualization)
|
||||
out_res2 = out_res4_res2
|
||||
|
||||
# grid_sample feature-map
|
||||
feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True) # (B, 256, 1, N)
|
||||
init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True) # (B, 4, 1, N)
|
||||
feat_res2 = torch.cat([feat_res2, init_pred], dim=1) # (B, 256+4, 1, N)
|
||||
|
||||
# prediction (needed to compute loss)
|
||||
samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :]) # (B, 4, N)
|
||||
samples_pred_res2 = norm_normalize(samples_pred_res2) # (B, 4, N) - normalized
|
||||
|
||||
for i in range(B):
|
||||
out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :]
|
||||
|
||||
else:
|
||||
# grid_sample feature-map
|
||||
feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
|
||||
B, _, H, W = feat_map.shape
|
||||
|
||||
out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N)
|
||||
out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized
|
||||
out_res2 = out_res2.view(B, 4, H, W)
|
||||
samples_pred_res2 = point_coords_res2 = None
|
||||
|
||||
################################################################################################################
|
||||
# out_res1
|
||||
################################################################################################################
|
||||
|
||||
if mode == 'train':
|
||||
# upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
|
||||
out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
B, _, H, W = out_res2_res1.shape
|
||||
|
||||
# samples: [B, 1, N, 2]
|
||||
point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask,
|
||||
sampling_ratio=self.sampling_ratio,
|
||||
beta=self.importance_ratio)
|
||||
|
||||
# output (needed for evaluation / visualization)
|
||||
out_res1 = out_res2_res1
|
||||
|
||||
# grid_sample feature-map
|
||||
feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True) # (B, 128, 1, N)
|
||||
init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True) # (B, 4, 1, N)
|
||||
feat_res1 = torch.cat([feat_res1, init_pred], dim=1) # (B, 128+4, 1, N)
|
||||
|
||||
# prediction (needed to compute loss)
|
||||
samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :]) # (B, 4, N)
|
||||
samples_pred_res1 = norm_normalize(samples_pred_res1) # (B, 4, N) - normalized
|
||||
|
||||
for i in range(B):
|
||||
out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :]
|
||||
|
||||
else:
|
||||
# grid_sample feature-map
|
||||
feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
|
||||
B, _, H, W = feat_map.shape
|
||||
|
||||
out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N)
|
||||
out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized
|
||||
out_res1 = out_res1.view(B, 4, H, W)
|
||||
samples_pred_res1 = point_coords_res1 = None
|
||||
|
||||
return [out_res8, out_res4, out_res2, out_res1], \
|
||||
[out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \
|
||||
[None, point_coords_res4, point_coords_res2, point_coords_res1]
|
||||
|
||||
109
invokeai/backend/image_util/normal_bae/nets/submodules/efficientnet_repo/.gitignore
vendored
Normal file
109
invokeai/backend/image_util/normal_bae/nets/submodules/efficientnet_repo/.gitignore
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# pytorch stuff
|
||||
*.pth
|
||||
*.onnx
|
||||
*.pb
|
||||
|
||||
trained_models/
|
||||
.fuse_hidden*
|
||||
@@ -0,0 +1,555 @@
|
||||
# Model Performance Benchmarks
|
||||
|
||||
All benchmarks run as per:
|
||||
|
||||
```
|
||||
python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx
|
||||
python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx
|
||||
python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3
|
||||
python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt
|
||||
python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb
|
||||
python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb
|
||||
```
|
||||
|
||||
## EfficientNet-B0
|
||||
|
||||
### Unoptimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897
|
||||
Time per operator type:
|
||||
29.7378 ms. 60.5145%. Conv
|
||||
12.1785 ms. 24.7824%. Sigmoid
|
||||
3.62811 ms. 7.38297%. SpatialBN
|
||||
2.98444 ms. 6.07314%. Mul
|
||||
0.326902 ms. 0.665225%. AveragePool
|
||||
0.197317 ms. 0.401528%. FC
|
||||
0.0852877 ms. 0.173555%. Add
|
||||
0.0032607 ms. 0.00663532%. Squeeze
|
||||
49.1416 ms in Total
|
||||
FLOP per operator type:
|
||||
0.76907 GFLOP. 95.2696%. Conv
|
||||
0.0269508 GFLOP. 3.33857%. SpatialBN
|
||||
0.00846444 GFLOP. 1.04855%. Mul
|
||||
0.002561 GFLOP. 0.317248%. FC
|
||||
0.000210112 GFLOP. 0.0260279%. Add
|
||||
0.807256 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
58.5253 MB. 43.0891%. Mul
|
||||
43.2015 MB. 31.807%. Conv
|
||||
27.2869 MB. 20.0899%. SpatialBN
|
||||
5.12912 MB. 3.77631%. FC
|
||||
1.6809 MB. 1.23756%. Add
|
||||
135.824 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
33.8578 MB. 38.1965%. Mul
|
||||
26.9881 MB. 30.4465%. Conv
|
||||
26.9508 MB. 30.4044%. SpatialBN
|
||||
0.840448 MB. 0.948147%. Add
|
||||
0.004 MB. 0.00451258%. FC
|
||||
88.6412 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
15.8248 MB. 74.9391%. Conv
|
||||
5.124 MB. 24.265%. FC
|
||||
0.168064 MB. 0.795877%. SpatialBN
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
21.1168 MB in Total
|
||||
```
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 46.0838. Iters per second: 21.6996
|
||||
Time per operator type:
|
||||
29.776 ms. 65.002%. Conv
|
||||
12.2803 ms. 26.8084%. Sigmoid
|
||||
3.15073 ms. 6.87815%. Mul
|
||||
0.328651 ms. 0.717456%. AveragePool
|
||||
0.186237 ms. 0.406563%. FC
|
||||
0.0832429 ms. 0.181722%. Add
|
||||
0.0026184 ms. 0.00571606%. Squeeze
|
||||
45.8078 ms in Total
|
||||
FLOP per operator type:
|
||||
0.76907 GFLOP. 98.5601%. Conv
|
||||
0.00846444 GFLOP. 1.08476%. Mul
|
||||
0.002561 GFLOP. 0.328205%. FC
|
||||
0.000210112 GFLOP. 0.0269269%. Add
|
||||
0.780305 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
58.5253 MB. 53.8803%. Mul
|
||||
43.2855 MB. 39.8501%. Conv
|
||||
5.12912 MB. 4.72204%. FC
|
||||
1.6809 MB. 1.54749%. Add
|
||||
108.621 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
33.8578 MB. 54.8834%. Mul
|
||||
26.9881 MB. 43.7477%. Conv
|
||||
0.840448 MB. 1.36237%. Add
|
||||
0.004 MB. 0.00648399%. FC
|
||||
61.6904 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
15.8248 MB. 75.5403%. Conv
|
||||
5.124 MB. 24.4597%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
20.9488 MB in Total
|
||||
```
|
||||
|
||||
## EfficientNet-B1
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256
|
||||
Time per operator type:
|
||||
45.7915 ms. 66.3206%. Conv
|
||||
17.8718 ms. 25.8841%. Sigmoid
|
||||
4.44132 ms. 6.43244%. Mul
|
||||
0.51001 ms. 0.738658%. AveragePool
|
||||
0.233283 ms. 0.337868%. Add
|
||||
0.194986 ms. 0.282402%. FC
|
||||
0.00268255 ms. 0.00388519%. Squeeze
|
||||
69.0456 ms in Total
|
||||
FLOP per operator type:
|
||||
1.37105 GFLOP. 98.7673%. Conv
|
||||
0.0138759 GFLOP. 0.99959%. Mul
|
||||
0.002561 GFLOP. 0.184489%. FC
|
||||
0.000674432 GFLOP. 0.0485847%. Add
|
||||
1.38816 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
94.624 MB. 54.0789%. Mul
|
||||
69.8255 MB. 39.9062%. Conv
|
||||
5.39546 MB. 3.08357%. Add
|
||||
5.12912 MB. 2.93136%. FC
|
||||
174.974 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
55.5035 MB. 54.555%. Mul
|
||||
43.5333 MB. 42.7894%. Conv
|
||||
2.69773 MB. 2.65163%. Add
|
||||
0.004 MB. 0.00393165%. FC
|
||||
101.739 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
25.7479 MB. 83.4024%. Conv
|
||||
5.124 MB. 16.5976%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
30.8719 MB in Total
|
||||
```
|
||||
|
||||
## EfficientNet-B2
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 92.28. Iters per second: 10.8366
|
||||
Time per operator type:
|
||||
61.4627 ms. 67.5845%. Conv
|
||||
22.7458 ms. 25.0113%. Sigmoid
|
||||
5.59931 ms. 6.15701%. Mul
|
||||
0.642567 ms. 0.706568%. AveragePool
|
||||
0.272795 ms. 0.299965%. Add
|
||||
0.216178 ms. 0.237709%. FC
|
||||
0.00268895 ms. 0.00295677%. Squeeze
|
||||
90.942 ms in Total
|
||||
FLOP per operator type:
|
||||
1.98431 GFLOP. 98.9343%. Conv
|
||||
0.0177039 GFLOP. 0.882686%. Mul
|
||||
0.002817 GFLOP. 0.140451%. FC
|
||||
0.000853984 GFLOP. 0.0425782%. Add
|
||||
2.00568 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
120.609 MB. 54.9637%. Mul
|
||||
86.3512 MB. 39.3519%. Conv
|
||||
6.83187 MB. 3.11341%. Add
|
||||
5.64163 MB. 2.571%. FC
|
||||
219.433 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
70.8155 MB. 54.6573%. Mul
|
||||
55.3273 MB. 42.7031%. Conv
|
||||
3.41594 MB. 2.63651%. Add
|
||||
0.004 MB. 0.00308731%. FC
|
||||
129.563 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
30.4721 MB. 84.3913%. Conv
|
||||
5.636 MB. 15.6087%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
36.1081 MB in Total
|
||||
```
|
||||
|
||||
## MixNet-M
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448
|
||||
Time per operator type:
|
||||
48.1139 ms. 75.2052%. Conv
|
||||
7.1341 ms. 11.1511%. Sigmoid
|
||||
2.63706 ms. 4.12189%. SpatialBN
|
||||
1.73186 ms. 2.70701%. Mul
|
||||
1.38707 ms. 2.16809%. Split
|
||||
1.29322 ms. 2.02139%. Concat
|
||||
1.00093 ms. 1.56452%. Relu
|
||||
0.235309 ms. 0.367803%. Add
|
||||
0.221579 ms. 0.346343%. FC
|
||||
0.219315 ms. 0.342803%. AveragePool
|
||||
0.00250145 ms. 0.00390993%. Squeeze
|
||||
63.9768 ms in Total
|
||||
FLOP per operator type:
|
||||
0.675273 GFLOP. 95.5827%. Conv
|
||||
0.0221072 GFLOP. 3.12921%. SpatialBN
|
||||
0.00538445 GFLOP. 0.762152%. Mul
|
||||
0.003073 GFLOP. 0.434973%. FC
|
||||
0.000642488 GFLOP. 0.0909421%. Add
|
||||
0 GFLOP. 0%. Concat
|
||||
0 GFLOP. 0%. Relu
|
||||
0.70648 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
46.8424 MB. 30.502%. Conv
|
||||
36.8626 MB. 24.0036%. Mul
|
||||
22.3152 MB. 14.5309%. SpatialBN
|
||||
22.1074 MB. 14.3955%. Concat
|
||||
14.1496 MB. 9.21372%. Relu
|
||||
6.15414 MB. 4.00735%. FC
|
||||
5.1399 MB. 3.34692%. Add
|
||||
153.571 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
32.7672 MB. 28.4331%. Conv
|
||||
22.1072 MB. 19.1831%. Concat
|
||||
22.1072 MB. 19.1831%. SpatialBN
|
||||
21.5378 MB. 18.689%. Mul
|
||||
14.1496 MB. 12.2781%. Relu
|
||||
2.56995 MB. 2.23003%. Add
|
||||
0.004 MB. 0.00347092%. FC
|
||||
115.243 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
13.7059 MB. 68.674%. Conv
|
||||
6.148 MB. 30.8049%. FC
|
||||
0.104 MB. 0.521097%. SpatialBN
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Concat
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
19.9579 MB in Total
|
||||
```
|
||||
|
||||
## TF MobileNet-V3 Large 1.0
|
||||
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525
|
||||
Time per operator type:
|
||||
17.437 ms. 80.0087%. Conv
|
||||
1.27662 ms. 5.8577%. Add
|
||||
1.12759 ms. 5.17387%. Div
|
||||
0.701155 ms. 3.21721%. Mul
|
||||
0.562654 ms. 2.58171%. Relu
|
||||
0.431144 ms. 1.97828%. Clip
|
||||
0.156902 ms. 0.719936%. FC
|
||||
0.0996858 ms. 0.457402%. AveragePool
|
||||
0.00112455 ms. 0.00515993%. Flatten
|
||||
21.7939 ms in Total
|
||||
FLOP per operator type:
|
||||
0.43062 GFLOP. 98.1484%. Conv
|
||||
0.002561 GFLOP. 0.583713%. FC
|
||||
0.00210867 GFLOP. 0.480616%. Mul
|
||||
0.00193868 GFLOP. 0.441871%. Add
|
||||
0.00151532 GFLOP. 0.345377%. Div
|
||||
0 GFLOP. 0%. Relu
|
||||
0.438743 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
34.7967 MB. 43.9391%. Conv
|
||||
14.496 MB. 18.3046%. Mul
|
||||
9.44828 MB. 11.9307%. Add
|
||||
9.26157 MB. 11.6949%. Relu
|
||||
6.0614 MB. 7.65395%. Div
|
||||
5.12912 MB. 6.47673%. FC
|
||||
79.193 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
17.6247 MB. 35.8656%. Conv
|
||||
9.26157 MB. 18.847%. Relu
|
||||
8.43469 MB. 17.1643%. Mul
|
||||
7.75472 MB. 15.7806%. Add
|
||||
6.06128 MB. 12.3345%. Div
|
||||
0.004 MB. 0.00813985%. FC
|
||||
49.1409 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
16.6851 MB. 76.5052%. Conv
|
||||
5.124 MB. 23.4948%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Div
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
21.8091 MB in Total
|
||||
```
|
||||
|
||||
## MobileNet-V3 (RW)
|
||||
|
||||
### Unoptimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712
|
||||
Time per operator type:
|
||||
15.9266 ms. 69.2624%. Conv
|
||||
2.36551 ms. 10.2873%. SpatialBN
|
||||
1.39102 ms. 6.04936%. Add
|
||||
1.30327 ms. 5.66773%. Div
|
||||
0.737014 ms. 3.20517%. Mul
|
||||
0.639697 ms. 2.78195%. Relu
|
||||
0.375681 ms. 1.63378%. Clip
|
||||
0.153126 ms. 0.665921%. FC
|
||||
0.0993787 ms. 0.432184%. AveragePool
|
||||
0.0032632 ms. 0.0141912%. Squeeze
|
||||
22.9946 ms in Total
|
||||
FLOP per operator type:
|
||||
0.430616 GFLOP. 94.4041%. Conv
|
||||
0.0175992 GFLOP. 3.85829%. SpatialBN
|
||||
0.002561 GFLOP. 0.561449%. FC
|
||||
0.00210961 GFLOP. 0.46249%. Mul
|
||||
0.00173891 GFLOP. 0.381223%. Add
|
||||
0.00151626 GFLOP. 0.33241%. Div
|
||||
0 GFLOP. 0%. Relu
|
||||
0.456141 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
34.7354 MB. 36.4363%. Conv
|
||||
17.7944 MB. 18.6658%. SpatialBN
|
||||
14.5035 MB. 15.2137%. Mul
|
||||
9.25778 MB. 9.71113%. Relu
|
||||
7.84641 MB. 8.23064%. Add
|
||||
6.06516 MB. 6.36216%. Div
|
||||
5.12912 MB. 5.38029%. FC
|
||||
95.3317 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
17.6246 MB. 26.7264%. Conv
|
||||
17.5992 MB. 26.6878%. SpatialBN
|
||||
9.25778 MB. 14.0387%. Relu
|
||||
8.43843 MB. 12.7962%. Mul
|
||||
6.95565 MB. 10.5477%. Add
|
||||
6.06502 MB. 9.19713%. Div
|
||||
0.004 MB. 0.00606568%. FC
|
||||
65.9447 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
16.6778 MB. 76.1564%. Conv
|
||||
5.124 MB. 23.3979%. FC
|
||||
0.0976 MB. 0.445674%. SpatialBN
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Div
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
21.8994 MB in Total
|
||||
|
||||
```
|
||||
### Optimized
|
||||
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527
|
||||
Time per operator type:
|
||||
17.146 ms. 78.8965%. Conv
|
||||
1.38453 ms. 6.37084%. Add
|
||||
1.30991 ms. 6.02749%. Div
|
||||
0.685417 ms. 3.15391%. Mul
|
||||
0.532589 ms. 2.45068%. Relu
|
||||
0.418263 ms. 1.92461%. Clip
|
||||
0.15128 ms. 0.696106%. FC
|
||||
0.102065 ms. 0.469648%. AveragePool
|
||||
0.0022143 ms. 0.010189%. Squeeze
|
||||
21.7323 ms in Total
|
||||
FLOP per operator type:
|
||||
0.430616 GFLOP. 98.1927%. Conv
|
||||
0.002561 GFLOP. 0.583981%. FC
|
||||
0.00210961 GFLOP. 0.481051%. Mul
|
||||
0.00173891 GFLOP. 0.396522%. Add
|
||||
0.00151626 GFLOP. 0.34575%. Div
|
||||
0 GFLOP. 0%. Relu
|
||||
0.438542 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
34.7842 MB. 44.833%. Conv
|
||||
14.5035 MB. 18.6934%. Mul
|
||||
9.25778 MB. 11.9323%. Relu
|
||||
7.84641 MB. 10.1132%. Add
|
||||
6.06516 MB. 7.81733%. Div
|
||||
5.12912 MB. 6.61087%. FC
|
||||
77.5861 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
17.6246 MB. 36.4556%. Conv
|
||||
9.25778 MB. 19.1492%. Relu
|
||||
8.43843 MB. 17.4544%. Mul
|
||||
6.95565 MB. 14.3874%. Add
|
||||
6.06502 MB. 12.5452%. Div
|
||||
0.004 MB. 0.00827378%. FC
|
||||
48.3455 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
16.6778 MB. 76.4973%. Conv
|
||||
5.124 MB. 23.5027%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Div
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
21.8018 MB in Total
|
||||
|
||||
```
|
||||
|
||||
## MnasNet-A1
|
||||
|
||||
### Unoptimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345
|
||||
Time per operator type:
|
||||
24.4656 ms. 79.0905%. Conv
|
||||
4.14958 ms. 13.4144%. SpatialBN
|
||||
1.60598 ms. 5.19169%. Relu
|
||||
0.295219 ms. 0.95436%. Mul
|
||||
0.187609 ms. 0.606486%. FC
|
||||
0.120556 ms. 0.389724%. AveragePool
|
||||
0.09036 ms. 0.292109%. Add
|
||||
0.015727 ms. 0.050841%. Sigmoid
|
||||
0.00306205 ms. 0.00989875%. Squeeze
|
||||
30.9337 ms in Total
|
||||
FLOP per operator type:
|
||||
0.620598 GFLOP. 95.6434%. Conv
|
||||
0.0248873 GFLOP. 3.8355%. SpatialBN
|
||||
0.002561 GFLOP. 0.394688%. FC
|
||||
0.000597408 GFLOP. 0.0920695%. Mul
|
||||
0.000222656 GFLOP. 0.0343146%. Add
|
||||
0 GFLOP. 0%. Relu
|
||||
0.648867 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
35.5457 MB. 38.4109%. Conv
|
||||
25.1552 MB. 27.1829%. SpatialBN
|
||||
22.5235 MB. 24.339%. Relu
|
||||
5.12912 MB. 5.54256%. FC
|
||||
2.40586 MB. 2.59978%. Mul
|
||||
1.78125 MB. 1.92483%. Add
|
||||
92.5406 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
24.9042 MB. 32.9424%. Conv
|
||||
24.8873 MB. 32.92%. SpatialBN
|
||||
22.5235 MB. 29.7932%. Relu
|
||||
2.38963 MB. 3.16092%. Mul
|
||||
0.890624 MB. 1.17809%. Add
|
||||
0.004 MB. 0.00529106%. FC
|
||||
75.5993 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
10.2732 MB. 66.1459%. Conv
|
||||
5.124 MB. 32.9917%. FC
|
||||
0.133952 MB. 0.86247%. SpatialBN
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
15.5312 MB in Total
|
||||
```
|
||||
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597
|
||||
Time per operator type:
|
||||
22.0547 ms. 91.1375%. Conv
|
||||
1.49096 ms. 6.16116%. Relu
|
||||
0.253417 ms. 1.0472%. Mul
|
||||
0.18506 ms. 0.76473%. FC
|
||||
0.112942 ms. 0.466717%. AveragePool
|
||||
0.086769 ms. 0.358559%. Add
|
||||
0.0127889 ms. 0.0528479%. Sigmoid
|
||||
0.0027346 ms. 0.0113003%. Squeeze
|
||||
24.1994 ms in Total
|
||||
FLOP per operator type:
|
||||
0.620598 GFLOP. 99.4581%. Conv
|
||||
0.002561 GFLOP. 0.41043%. FC
|
||||
0.000597408 GFLOP. 0.0957417%. Mul
|
||||
0.000222656 GFLOP. 0.0356832%. Add
|
||||
0 GFLOP. 0%. Relu
|
||||
0.623979 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
35.6127 MB. 52.7968%. Conv
|
||||
22.5235 MB. 33.3917%. Relu
|
||||
5.12912 MB. 7.60406%. FC
|
||||
2.40586 MB. 3.56675%. Mul
|
||||
1.78125 MB. 2.64075%. Add
|
||||
67.4524 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
24.9042 MB. 49.1092%. Conv
|
||||
22.5235 MB. 44.4145%. Relu
|
||||
2.38963 MB. 4.71216%. Mul
|
||||
0.890624 MB. 1.75624%. Add
|
||||
0.004 MB. 0.00788768%. FC
|
||||
50.712 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
10.2732 MB. 66.7213%. Conv
|
||||
5.124 MB. 33.2787%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
15.3972 MB in Total
|
||||
```
|
||||
## MnasNet-B1
|
||||
|
||||
### Unoptimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322
|
||||
Time per operator type:
|
||||
29.1121 ms. 83.3081%. Conv
|
||||
4.14959 ms. 11.8746%. SpatialBN
|
||||
1.35823 ms. 3.88675%. Relu
|
||||
0.186188 ms. 0.532802%. FC
|
||||
0.116244 ms. 0.332647%. Add
|
||||
0.018641 ms. 0.0533437%. AveragePool
|
||||
0.0040904 ms. 0.0117052%. Squeeze
|
||||
34.9451 ms in Total
|
||||
FLOP per operator type:
|
||||
0.626272 GFLOP. 96.2088%. Conv
|
||||
0.0218266 GFLOP. 3.35303%. SpatialBN
|
||||
0.002561 GFLOP. 0.393424%. FC
|
||||
0.000291648 GFLOP. 0.0448034%. Add
|
||||
0 GFLOP. 0%. Relu
|
||||
0.650951 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
34.4354 MB. 41.3788%. Conv
|
||||
22.1299 MB. 26.5921%. SpatialBN
|
||||
19.1923 MB. 23.0622%. Relu
|
||||
5.12912 MB. 6.16333%. FC
|
||||
2.33318 MB. 2.80364%. Add
|
||||
83.2199 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
21.8266 MB. 34.0955%. Conv
|
||||
21.8266 MB. 34.0955%. SpatialBN
|
||||
19.1923 MB. 29.9805%. Relu
|
||||
1.16659 MB. 1.82234%. Add
|
||||
0.004 MB. 0.00624844%. FC
|
||||
64.016 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
12.2576 MB. 69.9104%. Conv
|
||||
5.124 MB. 29.2245%. FC
|
||||
0.15168 MB. 0.865099%. SpatialBN
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Relu
|
||||
17.5332 MB in Total
|
||||
```
|
||||
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 26.6364. Iters per second: 37.5426
|
||||
Time per operator type:
|
||||
24.9888 ms. 94.0962%. Conv
|
||||
1.26147 ms. 4.75011%. Relu
|
||||
0.176234 ms. 0.663619%. FC
|
||||
0.113309 ms. 0.426672%. Add
|
||||
0.0138708 ms. 0.0522311%. AveragePool
|
||||
0.00295685 ms. 0.0111341%. Squeeze
|
||||
26.5566 ms in Total
|
||||
FLOP per operator type:
|
||||
0.626272 GFLOP. 99.5466%. Conv
|
||||
0.002561 GFLOP. 0.407074%. FC
|
||||
0.000291648 GFLOP. 0.0463578%. Add
|
||||
0 GFLOP. 0%. Relu
|
||||
0.629124 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
34.5112 MB. 56.4224%. Conv
|
||||
19.1923 MB. 31.3775%. Relu
|
||||
5.12912 MB. 8.3856%. FC
|
||||
2.33318 MB. 3.81452%. Add
|
||||
61.1658 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
21.8266 MB. 51.7346%. Conv
|
||||
19.1923 MB. 45.4908%. Relu
|
||||
1.16659 MB. 2.76513%. Add
|
||||
0.004 MB. 0.00948104%. FC
|
||||
42.1895 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
12.2576 MB. 70.5205%. Conv
|
||||
5.124 MB. 29.4795%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Relu
|
||||
17.3816 MB in Total
|
||||
```
|
||||
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -0,0 +1,323 @@
|
||||
# (Generic) EfficientNets for PyTorch
|
||||
|
||||
A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search.
|
||||
|
||||
All models are implemented by GenEfficientNet or MobileNetV3 classes, with string based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py))
|
||||
|
||||
## What's New
|
||||
|
||||
### Aug 19, 2020
|
||||
* Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1)
|
||||
* Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1)
|
||||
* Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX
|
||||
* ONNX runtime based validation script added
|
||||
* activations (mostly) brought in sync with `timm` equivalents
|
||||
|
||||
|
||||
### April 5, 2020
|
||||
* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite
|
||||
* 3.5M param MobileNet-V2 100 @ 73%
|
||||
* 4.5M param MobileNet-V2 110d @ 75%
|
||||
* 6.1M param MobileNet-V2 140 @ 76.5%
|
||||
* 5.8M param MobileNet-V2 120d @ 77.3%
|
||||
|
||||
### March 23, 2020
|
||||
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
|
||||
* Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1
|
||||
* IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior
|
||||
|
||||
### Feb 12, 2020
|
||||
* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
|
||||
* Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization.
|
||||
* Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin)
|
||||
|
||||
### Jan 22, 2020
|
||||
* Update weights for EfficientNet B0, B2, B3 and MixNet-XL with latest RandAugment trained weights. Trained with (https://github.com/rwightman/pytorch-image-models)
|
||||
* Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict
|
||||
* Test models, torchscript, onnx export with PyTorch 1.4 -- no issues
|
||||
|
||||
### Nov 22, 2019
|
||||
* New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and B8 model spec. Created a new set of `ap` models since they use a different
|
||||
preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights.
|
||||
|
||||
### Nov 15, 2019
|
||||
* Ported official TF MobileNet-V3 float32 large/small/minimalistic weights
|
||||
* Modifications to MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine
|
||||
|
||||
### Oct 30, 2019
|
||||
* Many of the models will now work with torch.jit.script, MixNet being the biggest exception
|
||||
* Improved interface for enabling torchscript or ONNX export compatible modes (via config)
|
||||
* Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to memory-efficient autgrad.fn
|
||||
* Activation factory to select best version of activation by name or override one globally
|
||||
* Add pretrained checkpoint load helper that handles input conv and classifier changes
|
||||
|
||||
### Oct 27, 2019
|
||||
* Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
|
||||
* Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
|
||||
* Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base
|
||||
* Switch activations and global pooling to modules
|
||||
* Add memory-efficient Swish/Mish impl
|
||||
* Add as_sequential() method to all models and allow as an argument in entrypoint fns
|
||||
* Move MobileNetV3 into own file since it has a different head
|
||||
* Remove ChamNet, MobileNet V2/V1 since they will likely never be used here
|
||||
|
||||
## Models
|
||||
|
||||
Implemented models include:
|
||||
* EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252)
|
||||
* EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665)
|
||||
* EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946)
|
||||
* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
|
||||
* EfficientNet-CondConv (https://arxiv.org/abs/1904.04971)
|
||||
* EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
|
||||
* MixNet (https://arxiv.org/abs/1907.09595)
|
||||
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
|
||||
* MobileNet-V3 (https://arxiv.org/abs/1905.02244)
|
||||
* FBNet-C (https://arxiv.org/abs/1812.03443)
|
||||
* Single-Path NAS (https://arxiv.org/abs/1904.02877)
|
||||
|
||||
I originally implemented and trained some these models with code [here](https://github.com/rwightman/pytorch-image-models), this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code.
|
||||
|
||||
## Pretrained
|
||||
|
||||
I've managed to train several of the models to accuracies close to or above the originating papers and official impl. My training code is here: https://github.com/rwightman/pytorch-image-models
|
||||
|
||||
|
||||
|Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop |
|
||||
|---|---|---|---|---|---|---|---|
|
||||
| efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 |
|
||||
| efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 |
|
||||
| mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 |
|
||||
| efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 |
|
||||
| mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 |
|
||||
| efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 |
|
||||
| mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 |
|
||||
| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 |
|
||||
| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 |
|
||||
| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 |
|
||||
| mobilenetv2_120d | 77.294 (22.706 | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 |
|
||||
| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 |
|
||||
| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 |
|
||||
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 |
|
||||
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 |
|
||||
| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 |
|
||||
| efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 |
|
||||
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 |
|
||||
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 |
|
||||
| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 |
|
||||
| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 |
|
||||
| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 |
|
||||
| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 |
|
||||
|
||||
|
||||
More pretrained models to come...
|
||||
|
||||
|
||||
## Ported Weights
|
||||
|
||||
The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args.
|
||||
|
||||
**IMPORTANT:**
|
||||
* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std.
|
||||
* Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl.
|
||||
|
||||
To run validation for tf_efficientnet_b5:
|
||||
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic`
|
||||
|
||||
To run validation w/ TF preprocessing for tf_efficientnet_b5:
|
||||
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing`
|
||||
|
||||
To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp:
|
||||
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5`
|
||||
|
||||
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop |
|
||||
|---|---|---|---|---|---|---|
|
||||
| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A |
|
||||
| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 |
|
||||
| tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 |
|
||||
| tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A |
|
||||
| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A |
|
||||
| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A |
|
||||
| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A |
|
||||
| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A |
|
||||
| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A |
|
||||
| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A |
|
||||
| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A |
|
||||
| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A |
|
||||
| tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 |
|
||||
| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 |
|
||||
| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A |
|
||||
| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 |
|
||||
| tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A |
|
||||
| tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 |
|
||||
| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A |
|
||||
| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 |
|
||||
| tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 |
|
||||
| tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A |
|
||||
| tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A |
|
||||
| tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 |
|
||||
| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A |
|
||||
| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 |
|
||||
| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A |
|
||||
| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | .904 |
|
||||
| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A |
|
||||
| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 |
|
||||
| tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A |
|
||||
| tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 |
|
||||
| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 |
|
||||
| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A |
|
||||
| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A |
|
||||
| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 |
|
||||
| tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
|
||||
| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 |
|
||||
| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 |
|
||||
| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
|
||||
| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 |
|
||||
| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A |
|
||||
| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A |
|
||||
| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 |
|
||||
| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 |
|
||||
| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A |
|
||||
| tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A |
|
||||
| tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 |
|
||||
| tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A |
|
||||
| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 |
|
||||
| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 |
|
||||
| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A |
|
||||
| tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A |
|
||||
| tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 |
|
||||
| tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 |
|
||||
| tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 |
|
||||
| tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | N/A |
|
||||
| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A |
|
||||
| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A |
|
||||
| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A |
|
||||
| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 |
|
||||
| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 |
|
||||
| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A |
|
||||
| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A |
|
||||
| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 |
|
||||
| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A |
|
||||
| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A |
|
||||
| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A |
|
||||
| tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 |
|
||||
| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A |
|
||||
| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A |
|
||||
| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 |
|
||||
| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A |
|
||||
| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A |
|
||||
| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 |
|
||||
| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 |
|
||||
| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A |
|
||||
| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 |
|
||||
| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |N/A |
|
||||
| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 |
|
||||
| tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A |
|
||||
| tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 |
|
||||
| tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042 | 2.54 | bilinear | 224 | N/A |
|
||||
| tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 |
|
||||
| tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A |
|
||||
| tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 |
|
||||
| tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A |
|
||||
| tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 |
|
||||
|
||||
|
||||
*tfp models validated with `tf-preprocessing` pipeline
|
||||
|
||||
Google tf and tflite weights ported from official Tensorflow repositories
|
||||
* https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
||||
* https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
|
||||
* https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment
|
||||
|
||||
All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x.
|
||||
|
||||
Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself.
|
||||
|
||||
PyTorch versions 1.4, 1.5, 1.6 have been tested with this code.
|
||||
|
||||
I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda:
|
||||
```
|
||||
conda create -n torch-env
|
||||
conda activate torch-env
|
||||
conda install -c pytorch pytorch torchvision cudatoolkit=10.2
|
||||
```
|
||||
|
||||
### PyTorch Hub
|
||||
|
||||
Models can be accessed via the PyTorch Hub API
|
||||
|
||||
```
|
||||
>>> torch.hub.list('rwightman/gen-efficientnet-pytorch')
|
||||
['efficientnet_b0', ...]
|
||||
>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True)
|
||||
>>> model.eval()
|
||||
>>> output = model(torch.randn(1,3,224,224))
|
||||
```
|
||||
|
||||
### Pip
|
||||
This package can be installed via pip.
|
||||
|
||||
Install (after conda env/install):
|
||||
```
|
||||
pip install geffnet
|
||||
```
|
||||
|
||||
Eval use:
|
||||
```
|
||||
>>> import geffnet
|
||||
>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True)
|
||||
>>> m.eval()
|
||||
```
|
||||
|
||||
Train use:
|
||||
```
|
||||
>>> import geffnet
|
||||
>>> # models can also be created by using the entrypoint directly
|
||||
>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2)
|
||||
>>> m.train()
|
||||
```
|
||||
|
||||
Create in a nn.Sequential container, for fast.ai, etc:
|
||||
```
|
||||
>>> import geffnet
|
||||
>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True)
|
||||
```
|
||||
|
||||
### Exporting
|
||||
|
||||
Scripts are included to
|
||||
* export models to ONNX (`onnx_export.py`)
|
||||
* optimized ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg)
|
||||
* validate with ONNX runtime (`onnx_validate.py`)
|
||||
* convert ONNX model to Caffe2 (`onnx_to_caffe.py`)
|
||||
* validate in Caffe2 (`caffe2_validate.py`)
|
||||
* benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`)
|
||||
|
||||
As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation:
|
||||
```
|
||||
python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx
|
||||
python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx
|
||||
```
|
||||
|
||||
These scripts were tested to be working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible
|
||||
export now requires additional args mentioned in the export script (not needed in earlier versions).
|
||||
|
||||
#### Export Notes
|
||||
1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script.
|
||||
2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working.
|
||||
3. ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, the onnxruntime based inference is working very well now and includes on the fly optimization.
|
||||
3. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here.
|
||||
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
""" Caffe2 validation script
|
||||
|
||||
This script runs Caffe2 benchmark on exported ONNX model.
|
||||
It is a useful tool for reporting model FLOPS.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import argparse
|
||||
from caffe2.python import core, workspace, model_helper
|
||||
from caffe2.proto import caffe2_pb2
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark')
|
||||
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
|
||||
help='caffe2 model pb name prefix')
|
||||
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
|
||||
help='caffe2 model init .pb')
|
||||
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
|
||||
help='caffe2 model predict .pb')
|
||||
parser.add_argument('-b', '--batch-size', default=1, type=int,
|
||||
metavar='N', help='mini-batch size (default: 1)')
|
||||
parser.add_argument('--img-size', default=224, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
args.gpu_id = 0
|
||||
if args.c2_prefix:
|
||||
args.c2_init = args.c2_prefix + '.init.pb'
|
||||
args.c2_predict = args.c2_prefix + '.predict.pb'
|
||||
|
||||
model = model_helper.ModelHelper(name="le_net", init_params=False)
|
||||
|
||||
# Bring in the init net from init_net.pb
|
||||
init_net_proto = caffe2_pb2.NetDef()
|
||||
with open(args.c2_init, "rb") as f:
|
||||
init_net_proto.ParseFromString(f.read())
|
||||
model.param_init_net = core.Net(init_net_proto)
|
||||
|
||||
# bring in the predict net from predict_net.pb
|
||||
predict_net_proto = caffe2_pb2.NetDef()
|
||||
with open(args.c2_predict, "rb") as f:
|
||||
predict_net_proto.ParseFromString(f.read())
|
||||
model.net = core.Net(predict_net_proto)
|
||||
|
||||
# CUDA performance not impressive
|
||||
#device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
|
||||
#model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
||||
#model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
||||
|
||||
input_blob = model.net.external_inputs[0]
|
||||
model.param_init_net.GaussianFill(
|
||||
[],
|
||||
input_blob.GetUnscopedName(),
|
||||
shape=(args.batch_size, 3, args.img_size, args.img_size),
|
||||
mean=0.0,
|
||||
std=1.0)
|
||||
workspace.RunNetOnce(model.param_init_net)
|
||||
workspace.CreateNet(model.net, overwrite=True)
|
||||
workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,138 @@
|
||||
""" Caffe2 validation script
|
||||
|
||||
This script is created to verify exported ONNX models running in Caffe2
|
||||
It utilizes the same PyTorch dataloader/processing pipeline for a
|
||||
fair comparison against the originals.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
from caffe2.python import core, workspace, model_helper
|
||||
from caffe2.proto import caffe2_pb2
|
||||
from data import create_loader, resolve_data_config, Dataset
|
||||
from utils import AverageMeter
|
||||
import time
|
||||
|
||||
parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
help='path to dataset')
|
||||
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
|
||||
help='caffe2 model pb name prefix')
|
||||
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
|
||||
help='caffe2 model init .pb')
|
||||
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
|
||||
help='caffe2 model predict .pb')
|
||||
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||
metavar='N', help='mini-batch size (default: 256)')
|
||||
parser.add_argument('--img-size', default=None, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override mean pixel value of dataset')
|
||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||
help='Override std deviation of of dataset')
|
||||
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
|
||||
help='Override default crop pct of 0.875')
|
||||
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
|
||||
help='Image resize interpolation type (overrides model)')
|
||||
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
|
||||
help='use tensorflow mnasnet preporcessing')
|
||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
args.gpu_id = 0
|
||||
if args.c2_prefix:
|
||||
args.c2_init = args.c2_prefix + '.init.pb'
|
||||
args.c2_predict = args.c2_prefix + '.predict.pb'
|
||||
|
||||
model = model_helper.ModelHelper(name="validation_net", init_params=False)
|
||||
|
||||
# Bring in the init net from init_net.pb
|
||||
init_net_proto = caffe2_pb2.NetDef()
|
||||
with open(args.c2_init, "rb") as f:
|
||||
init_net_proto.ParseFromString(f.read())
|
||||
model.param_init_net = core.Net(init_net_proto)
|
||||
|
||||
# bring in the predict net from predict_net.pb
|
||||
predict_net_proto = caffe2_pb2.NetDef()
|
||||
with open(args.c2_predict, "rb") as f:
|
||||
predict_net_proto.ParseFromString(f.read())
|
||||
model.net = core.Net(predict_net_proto)
|
||||
|
||||
data_config = resolve_data_config(None, args)
|
||||
loader = create_loader(
|
||||
Dataset(args.data, load_bytes=args.tf_preprocessing),
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.batch_size,
|
||||
use_prefetcher=False,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=args.workers,
|
||||
crop_pct=data_config['crop_pct'],
|
||||
tensorflow_preprocessing=args.tf_preprocessing)
|
||||
|
||||
# this is so obvious, wonderful interface </sarcasm>
|
||||
input_blob = model.net.external_inputs[0]
|
||||
output_blob = model.net.external_outputs[0]
|
||||
|
||||
if True:
|
||||
device_opts = None
|
||||
else:
|
||||
# CUDA is crashing, no idea why, awesome error message, give it a try for kicks
|
||||
device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
|
||||
model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
||||
model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
||||
|
||||
model.param_init_net.GaussianFill(
|
||||
[], input_blob.GetUnscopedName(),
|
||||
shape=(1,) + data_config['input_size'], mean=0.0, std=1.0)
|
||||
workspace.RunNetOnce(model.param_init_net)
|
||||
workspace.CreateNet(model.net, overwrite=True)
|
||||
|
||||
batch_time = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
end = time.time()
|
||||
for i, (input, target) in enumerate(loader):
|
||||
# run the net and return prediction
|
||||
caffe2_in = input.data.numpy()
|
||||
workspace.FeedBlob(input_blob, caffe2_in, device_opts)
|
||||
workspace.RunNet(model.net, num_iter=1)
|
||||
output = workspace.FetchBlob(output_blob)
|
||||
|
||||
# measure accuracy and record loss
|
||||
prec1, prec5 = accuracy_np(output.data, target.numpy())
|
||||
top1.update(prec1.item(), input.size(0))
|
||||
top5.update(prec5.item(), input.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
print('Test: [{0}/{1}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
|
||||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
|
||||
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
|
||||
i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
|
||||
ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5))
|
||||
|
||||
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
|
||||
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
|
||||
|
||||
|
||||
def accuracy_np(output, target):
|
||||
max_indices = np.argsort(output, axis=1)[:, ::-1]
|
||||
top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
|
||||
top1 = 100 * np.equal(max_indices[:, 0], target).mean()
|
||||
return top1, top5
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,5 @@
|
||||
from .gen_efficientnet import *
|
||||
from .mobilenetv3 import *
|
||||
from .model_factory import create_model
|
||||
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
|
||||
from .activations import *
|
||||
@@ -0,0 +1,137 @@
|
||||
from geffnet import config
|
||||
from geffnet.activations.activations_me import *
|
||||
from geffnet.activations.activations_jit import *
|
||||
from geffnet.activations.activations import *
|
||||
import torch
|
||||
|
||||
_has_silu = 'silu' in dir(torch.nn.functional)
|
||||
|
||||
_ACT_FN_DEFAULT = dict(
|
||||
silu=F.silu if _has_silu else swish,
|
||||
swish=F.silu if _has_silu else swish,
|
||||
mish=mish,
|
||||
relu=F.relu,
|
||||
relu6=F.relu6,
|
||||
sigmoid=sigmoid,
|
||||
tanh=tanh,
|
||||
hard_sigmoid=hard_sigmoid,
|
||||
hard_swish=hard_swish,
|
||||
)
|
||||
|
||||
_ACT_FN_JIT = dict(
|
||||
silu=F.silu if _has_silu else swish_jit,
|
||||
swish=F.silu if _has_silu else swish_jit,
|
||||
mish=mish_jit,
|
||||
)
|
||||
|
||||
_ACT_FN_ME = dict(
|
||||
silu=F.silu if _has_silu else swish_me,
|
||||
swish=F.silu if _has_silu else swish_me,
|
||||
mish=mish_me,
|
||||
hard_swish=hard_swish_me,
|
||||
hard_sigmoid_jit=hard_sigmoid_me,
|
||||
)
|
||||
|
||||
_ACT_LAYER_DEFAULT = dict(
|
||||
silu=nn.SiLU if _has_silu else Swish,
|
||||
swish=nn.SiLU if _has_silu else Swish,
|
||||
mish=Mish,
|
||||
relu=nn.ReLU,
|
||||
relu6=nn.ReLU6,
|
||||
sigmoid=Sigmoid,
|
||||
tanh=Tanh,
|
||||
hard_sigmoid=HardSigmoid,
|
||||
hard_swish=HardSwish,
|
||||
)
|
||||
|
||||
_ACT_LAYER_JIT = dict(
|
||||
silu=nn.SiLU if _has_silu else SwishJit,
|
||||
swish=nn.SiLU if _has_silu else SwishJit,
|
||||
mish=MishJit,
|
||||
)
|
||||
|
||||
_ACT_LAYER_ME = dict(
|
||||
silu=nn.SiLU if _has_silu else SwishMe,
|
||||
swish=nn.SiLU if _has_silu else SwishMe,
|
||||
mish=MishMe,
|
||||
hard_swish=HardSwishMe,
|
||||
hard_sigmoid=HardSigmoidMe
|
||||
)
|
||||
|
||||
_OVERRIDE_FN = dict()
|
||||
_OVERRIDE_LAYER = dict()
|
||||
|
||||
|
||||
def add_override_act_fn(name, fn):
|
||||
global _OVERRIDE_FN
|
||||
_OVERRIDE_FN[name] = fn
|
||||
|
||||
|
||||
def update_override_act_fn(overrides):
|
||||
assert isinstance(overrides, dict)
|
||||
global _OVERRIDE_FN
|
||||
_OVERRIDE_FN.update(overrides)
|
||||
|
||||
|
||||
def clear_override_act_fn():
|
||||
global _OVERRIDE_FN
|
||||
_OVERRIDE_FN = dict()
|
||||
|
||||
|
||||
def add_override_act_layer(name, fn):
|
||||
_OVERRIDE_LAYER[name] = fn
|
||||
|
||||
|
||||
def update_override_act_layer(overrides):
|
||||
assert isinstance(overrides, dict)
|
||||
global _OVERRIDE_LAYER
|
||||
_OVERRIDE_LAYER.update(overrides)
|
||||
|
||||
|
||||
def clear_override_act_layer():
|
||||
global _OVERRIDE_LAYER
|
||||
_OVERRIDE_LAYER = dict()
|
||||
|
||||
|
||||
def get_act_fn(name='relu'):
|
||||
""" Activation Function Factory
|
||||
Fetching activation fns by name with this function allows export or torch script friendly
|
||||
functions to be returned dynamically based on current config.
|
||||
"""
|
||||
if name in _OVERRIDE_FN:
|
||||
return _OVERRIDE_FN[name]
|
||||
use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
|
||||
if use_me and name in _ACT_FN_ME:
|
||||
# If not exporting or scripting the model, first look for a memory optimized version
|
||||
# activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin
|
||||
return _ACT_FN_ME[name]
|
||||
if config.is_exportable() and name in ('silu', 'swish'):
|
||||
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
|
||||
return swish
|
||||
use_jit = not (config.is_exportable() or config.is_no_jit())
|
||||
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
|
||||
if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
|
||||
return _ACT_FN_JIT[name]
|
||||
return _ACT_FN_DEFAULT[name]
|
||||
|
||||
|
||||
def get_act_layer(name='relu'):
|
||||
""" Activation Layer Factory
|
||||
Fetching activation layers by name with this function allows export or torch script friendly
|
||||
functions to be returned dynamically based on current config.
|
||||
"""
|
||||
if name in _OVERRIDE_LAYER:
|
||||
return _OVERRIDE_LAYER[name]
|
||||
use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
|
||||
if use_me and name in _ACT_LAYER_ME:
|
||||
return _ACT_LAYER_ME[name]
|
||||
if config.is_exportable() and name in ('silu', 'swish'):
|
||||
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
|
||||
return Swish
|
||||
use_jit = not (config.is_exportable() or config.is_no_jit())
|
||||
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
|
||||
if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
|
||||
return _ACT_LAYER_JIT[name]
|
||||
return _ACT_LAYER_DEFAULT[name]
|
||||
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
""" Activations
|
||||
|
||||
A collection of activations fn and modules with a common interface so that they can
|
||||
easily be swapped. All have an `inplace` arg even if not used.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def swish(x, inplace: bool = False):
|
||||
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
||||
and also as Swish (https://arxiv.org/abs/1710.05941).
|
||||
|
||||
TODO Rename to SiLU with addition to PyTorch
|
||||
"""
|
||||
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(Swish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return swish(x, self.inplace)
|
||||
|
||||
|
||||
def mish(x, inplace: bool = False):
|
||||
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
||||
"""
|
||||
return x.mul(F.softplus(x).tanh())
|
||||
|
||||
|
||||
class Mish(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(Mish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return mish(x, self.inplace)
|
||||
|
||||
|
||||
def sigmoid(x, inplace: bool = False):
|
||||
return x.sigmoid_() if inplace else x.sigmoid()
|
||||
|
||||
|
||||
# PyTorch has this, but not with a consistent inplace argmument interface
|
||||
class Sigmoid(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(Sigmoid, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return x.sigmoid_() if self.inplace else x.sigmoid()
|
||||
|
||||
|
||||
def tanh(x, inplace: bool = False):
|
||||
return x.tanh_() if inplace else x.tanh()
|
||||
|
||||
|
||||
# PyTorch has this, but not with a consistent inplace argmument interface
|
||||
class Tanh(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(Tanh, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return x.tanh_() if self.inplace else x.tanh()
|
||||
|
||||
|
||||
def hard_swish(x, inplace: bool = False):
|
||||
inner = F.relu6(x + 3.).div_(6.)
|
||||
return x.mul_(inner) if inplace else x.mul(inner)
|
||||
|
||||
|
||||
class HardSwish(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSwish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return hard_swish(x, self.inplace)
|
||||
|
||||
|
||||
def hard_sigmoid(x, inplace: bool = False):
|
||||
if inplace:
|
||||
return x.add_(3.).clamp_(0., 6.).div_(6.)
|
||||
else:
|
||||
return F.relu6(x + 3.) / 6.
|
||||
|
||||
|
||||
class HardSigmoid(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSigmoid, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return hard_sigmoid(x, self.inplace)
|
||||
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
""" Activations (jit)
|
||||
|
||||
A collection of jit-scripted activations fn and modules with a common interface so that they can
|
||||
easily be swapped. All have an `inplace` arg even if not used.
|
||||
|
||||
All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
|
||||
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
|
||||
versions if they contain in-place ops.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit',
|
||||
'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit']
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish_jit(x, inplace: bool = False):
|
||||
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
||||
and also as Swish (https://arxiv.org/abs/1710.05941).
|
||||
|
||||
TODO Rename to SiLU with addition to PyTorch
|
||||
"""
|
||||
return x.mul(x.sigmoid())
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def mish_jit(x, _inplace: bool = False):
|
||||
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
||||
"""
|
||||
return x.mul(F.softplus(x).tanh())
|
||||
|
||||
|
||||
class SwishJit(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(SwishJit, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return swish_jit(x)
|
||||
|
||||
|
||||
class MishJit(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(MishJit, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return mish_jit(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_sigmoid_jit(x, inplace: bool = False):
|
||||
# return F.relu6(x + 3.) / 6.
|
||||
return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
|
||||
|
||||
|
||||
class HardSigmoidJit(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSigmoidJit, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return hard_sigmoid_jit(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_swish_jit(x, inplace: bool = False):
|
||||
# return x * (F.relu6(x + 3.) / 6)
|
||||
return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
|
||||
|
||||
|
||||
class HardSwishJit(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSwishJit, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return hard_swish_jit(x)
|
||||
@@ -0,0 +1,174 @@
|
||||
""" Activations (memory-efficient w/ custom autograd)
|
||||
|
||||
A collection of activations fn and modules with a common interface so that they can
|
||||
easily be swapped. All have an `inplace` arg even if not used.
|
||||
|
||||
These activations are not compatible with jit scripting or ONNX export of the model, please use either
|
||||
the JIT or basic versions of the activations.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe',
|
||||
'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe']
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish_jit_fwd(x):
|
||||
return x.mul(torch.sigmoid(x))
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish_jit_bwd(x, grad_output):
|
||||
x_sigmoid = torch.sigmoid(x)
|
||||
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
|
||||
|
||||
|
||||
class SwishJitAutoFn(torch.autograd.Function):
|
||||
""" torch.jit.script optimised Swish w/ memory-efficient checkpoint
|
||||
Inspired by conversation btw Jeremy Howard & Adam Pazske
|
||||
https://twitter.com/jeremyphoward/status/1188251041835315200
|
||||
|
||||
Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
||||
and also as Swish (https://arxiv.org/abs/1710.05941).
|
||||
|
||||
TODO Rename to SiLU with addition to PyTorch
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return swish_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
return swish_jit_bwd(x, grad_output)
|
||||
|
||||
|
||||
def swish_me(x, inplace=False):
|
||||
return SwishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
class SwishMe(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(SwishMe, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return SwishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def mish_jit_fwd(x):
|
||||
return x.mul(torch.tanh(F.softplus(x)))
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def mish_jit_bwd(x, grad_output):
|
||||
x_sigmoid = torch.sigmoid(x)
|
||||
x_tanh_sp = F.softplus(x).tanh()
|
||||
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
|
||||
|
||||
|
||||
class MishJitAutoFn(torch.autograd.Function):
|
||||
""" Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
||||
A memory efficient, jit scripted variant of Mish
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return mish_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
return mish_jit_bwd(x, grad_output)
|
||||
|
||||
|
||||
def mish_me(x, inplace=False):
|
||||
return MishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
class MishMe(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(MishMe, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return MishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
|
||||
return (x + 3).clamp(min=0, max=6).div(6.)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_sigmoid_jit_bwd(x, grad_output):
|
||||
m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
|
||||
return grad_output * m
|
||||
|
||||
|
||||
class HardSigmoidJitAutoFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return hard_sigmoid_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
return hard_sigmoid_jit_bwd(x, grad_output)
|
||||
|
||||
|
||||
def hard_sigmoid_me(x, inplace: bool = False):
|
||||
return HardSigmoidJitAutoFn.apply(x)
|
||||
|
||||
|
||||
class HardSigmoidMe(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSigmoidMe, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return HardSigmoidJitAutoFn.apply(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_swish_jit_fwd(x):
|
||||
return x * (x + 3).clamp(min=0, max=6).div(6.)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_swish_jit_bwd(x, grad_output):
|
||||
m = torch.ones_like(x) * (x >= 3.)
|
||||
m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
|
||||
return grad_output * m
|
||||
|
||||
|
||||
class HardSwishJitAutoFn(torch.autograd.Function):
|
||||
"""A memory efficient, jit-scripted HardSwish activation"""
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return hard_swish_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
return hard_swish_jit_bwd(x, grad_output)
|
||||
|
||||
|
||||
def hard_swish_me(x, inplace=False):
|
||||
return HardSwishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
class HardSwishMe(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSwishMe, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return HardSwishJitAutoFn.apply(x)
|
||||
@@ -0,0 +1,123 @@
|
||||
""" Global layer config state
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
|
||||
__all__ = [
|
||||
'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs',
|
||||
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
|
||||
]
|
||||
|
||||
# Set to True if prefer to have layers with no jit optimization (includes activations)
|
||||
_NO_JIT = False
|
||||
|
||||
# Set to True if prefer to have activation layers with no jit optimization
|
||||
# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
|
||||
# the jit flags so far are activations. This will change as more layers are updated and/or added.
|
||||
_NO_ACTIVATION_JIT = False
|
||||
|
||||
# Set to True if exporting a model with Same padding via ONNX
|
||||
_EXPORTABLE = False
|
||||
|
||||
# Set to True if wanting to use torch.jit.script on a model
|
||||
_SCRIPTABLE = False
|
||||
|
||||
|
||||
def is_no_jit():
|
||||
return _NO_JIT
|
||||
|
||||
|
||||
class set_no_jit:
|
||||
def __init__(self, mode: bool) -> None:
|
||||
global _NO_JIT
|
||||
self.prev = _NO_JIT
|
||||
_NO_JIT = mode
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args: Any) -> bool:
|
||||
global _NO_JIT
|
||||
_NO_JIT = self.prev
|
||||
return False
|
||||
|
||||
|
||||
def is_exportable():
|
||||
return _EXPORTABLE
|
||||
|
||||
|
||||
class set_exportable:
|
||||
def __init__(self, mode: bool) -> None:
|
||||
global _EXPORTABLE
|
||||
self.prev = _EXPORTABLE
|
||||
_EXPORTABLE = mode
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args: Any) -> bool:
|
||||
global _EXPORTABLE
|
||||
_EXPORTABLE = self.prev
|
||||
return False
|
||||
|
||||
|
||||
def is_scriptable():
|
||||
return _SCRIPTABLE
|
||||
|
||||
|
||||
class set_scriptable:
|
||||
def __init__(self, mode: bool) -> None:
|
||||
global _SCRIPTABLE
|
||||
self.prev = _SCRIPTABLE
|
||||
_SCRIPTABLE = mode
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args: Any) -> bool:
|
||||
global _SCRIPTABLE
|
||||
_SCRIPTABLE = self.prev
|
||||
return False
|
||||
|
||||
|
||||
class set_layer_config:
|
||||
""" Layer config context manager that allows setting all layer config flags at once.
|
||||
If a flag arg is None, it will not change the current value.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
scriptable: Optional[bool] = None,
|
||||
exportable: Optional[bool] = None,
|
||||
no_jit: Optional[bool] = None,
|
||||
no_activation_jit: Optional[bool] = None):
|
||||
global _SCRIPTABLE
|
||||
global _EXPORTABLE
|
||||
global _NO_JIT
|
||||
global _NO_ACTIVATION_JIT
|
||||
self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
|
||||
if scriptable is not None:
|
||||
_SCRIPTABLE = scriptable
|
||||
if exportable is not None:
|
||||
_EXPORTABLE = exportable
|
||||
if no_jit is not None:
|
||||
_NO_JIT = no_jit
|
||||
if no_activation_jit is not None:
|
||||
_NO_ACTIVATION_JIT = no_activation_jit
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args: Any) -> bool:
|
||||
global _SCRIPTABLE
|
||||
global _EXPORTABLE
|
||||
global _NO_JIT
|
||||
global _NO_ACTIVATION_JIT
|
||||
_SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
|
||||
return False
|
||||
|
||||
|
||||
def layer_config_kwargs(kwargs):
|
||||
""" Consume config kwargs and return contextmgr obj """
|
||||
return set_layer_config(
|
||||
scriptable=kwargs.pop('scriptable', None),
|
||||
exportable=kwargs.pop('exportable', None),
|
||||
no_jit=kwargs.pop('no_jit', None))
|
||||
@@ -0,0 +1,304 @@
|
||||
""" Conv2D w/ SAME padding, CondConv, MixedConv
|
||||
|
||||
A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and
|
||||
MobileNetV3 models that maintain weight compatibility with original Tensorflow models.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import collections.abc
|
||||
import math
|
||||
from functools import partial
|
||||
from itertools import repeat
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .config import *
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
return parse
|
||||
|
||||
|
||||
_single = _ntuple(1)
|
||||
_pair = _ntuple(2)
|
||||
_triple = _ntuple(3)
|
||||
_quadruple = _ntuple(4)
|
||||
|
||||
|
||||
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
|
||||
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
|
||||
|
||||
|
||||
def _get_padding(kernel_size, stride=1, dilation=1, **_):
|
||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||
return padding
|
||||
|
||||
|
||||
def _calc_same_pad(i: int, k: int, s: int, d: int):
|
||||
return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
||||
|
||||
|
||||
def _same_pad_arg(input_size, kernel_size, stride, dilation):
|
||||
ih, iw = input_size
|
||||
kh, kw = kernel_size
|
||||
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
|
||||
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
|
||||
return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
|
||||
|
||||
|
||||
def _split_channels(num_chan, num_groups):
|
||||
split = [num_chan // num_groups for _ in range(num_groups)]
|
||||
split[0] += num_chan - sum(split)
|
||||
return split
|
||||
|
||||
|
||||
def conv2d_same(
|
||||
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
|
||||
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
|
||||
ih, iw = x.size()[-2:]
|
||||
kh, kw = weight.size()[-2:]
|
||||
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
|
||||
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
|
||||
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
||||
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
|
||||
|
||||
|
||||
class Conv2dSame(nn.Conv2d):
|
||||
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1, bias=True):
|
||||
super(Conv2dSame, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
||||
|
||||
def forward(self, x):
|
||||
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class Conv2dSameExport(nn.Conv2d):
|
||||
""" ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
||||
|
||||
NOTE: This does not currently work with torch.jit.script
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1, bias=True):
|
||||
super(Conv2dSameExport, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
||||
self.pad = None
|
||||
self.pad_input_size = (0, 0)
|
||||
|
||||
def forward(self, x):
|
||||
input_size = x.size()[-2:]
|
||||
if self.pad is None:
|
||||
pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
|
||||
self.pad = nn.ZeroPad2d(pad_arg)
|
||||
self.pad_input_size = input_size
|
||||
|
||||
if self.pad is not None:
|
||||
x = self.pad(x)
|
||||
return F.conv2d(
|
||||
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
def get_padding_value(padding, kernel_size, **kwargs):
|
||||
dynamic = False
|
||||
if isinstance(padding, str):
|
||||
# for any string padding, the padding will be calculated for you, one of three ways
|
||||
padding = padding.lower()
|
||||
if padding == 'same':
|
||||
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
||||
if _is_static_pad(kernel_size, **kwargs):
|
||||
# static case, no extra overhead
|
||||
padding = _get_padding(kernel_size, **kwargs)
|
||||
else:
|
||||
# dynamic padding
|
||||
padding = 0
|
||||
dynamic = True
|
||||
elif padding == 'valid':
|
||||
# 'VALID' padding, same as padding=0
|
||||
padding = 0
|
||||
else:
|
||||
# Default to PyTorch style 'same'-ish symmetric padding
|
||||
padding = _get_padding(kernel_size, **kwargs)
|
||||
return padding, dynamic
|
||||
|
||||
|
||||
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
||||
padding = kwargs.pop('padding', '')
|
||||
kwargs.setdefault('bias', False)
|
||||
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
|
||||
if is_dynamic:
|
||||
if is_exportable():
|
||||
assert not is_scriptable()
|
||||
return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
|
||||
else:
|
||||
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
||||
else:
|
||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
||||
|
||||
|
||||
class MixedConv2d(nn.ModuleDict):
|
||||
""" Mixed Grouped Convolution
|
||||
Based on MDConv and GroupedConv in MixNet impl:
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
|
||||
super(MixedConv2d, self).__init__()
|
||||
|
||||
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
|
||||
num_groups = len(kernel_size)
|
||||
in_splits = _split_channels(in_channels, num_groups)
|
||||
out_splits = _split_channels(out_channels, num_groups)
|
||||
self.in_channels = sum(in_splits)
|
||||
self.out_channels = sum(out_splits)
|
||||
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
|
||||
conv_groups = out_ch if depthwise else 1
|
||||
self.add_module(
|
||||
str(idx),
|
||||
create_conv2d_pad(
|
||||
in_ch, out_ch, k, stride=stride,
|
||||
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
|
||||
)
|
||||
self.splits = in_splits
|
||||
|
||||
def forward(self, x):
|
||||
x_split = torch.split(x, self.splits, 1)
|
||||
x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())]
|
||||
x = torch.cat(x_out, 1)
|
||||
return x
|
||||
|
||||
|
||||
def get_condconv_initializer(initializer, num_experts, expert_shape):
|
||||
def condconv_initializer(weight):
|
||||
"""CondConv initializer function."""
|
||||
num_params = np.prod(expert_shape)
|
||||
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
|
||||
weight.shape[1] != num_params):
|
||||
raise (ValueError(
|
||||
'CondConv variables must have shape [num_experts, num_params]'))
|
||||
for i in range(num_experts):
|
||||
initializer(weight[i].view(expert_shape))
|
||||
return condconv_initializer
|
||||
|
||||
|
||||
class CondConv2d(nn.Module):
|
||||
""" Conditional Convolution
|
||||
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
|
||||
|
||||
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
|
||||
https://github.com/pytorch/pytorch/issues/17983
|
||||
"""
|
||||
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
|
||||
super(CondConv2d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = _pair(kernel_size)
|
||||
self.stride = _pair(stride)
|
||||
padding_val, is_padding_dynamic = get_padding_value(
|
||||
padding, kernel_size, stride=stride, dilation=dilation)
|
||||
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
|
||||
self.padding = _pair(padding_val)
|
||||
self.dilation = _pair(dilation)
|
||||
self.groups = groups
|
||||
self.num_experts = num_experts
|
||||
|
||||
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
||||
weight_num_param = 1
|
||||
for wd in self.weight_shape:
|
||||
weight_num_param *= wd
|
||||
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
|
||||
|
||||
if bias:
|
||||
self.bias_shape = (self.out_channels,)
|
||||
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
init_weight = get_condconv_initializer(
|
||||
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
|
||||
init_weight(self.weight)
|
||||
if self.bias is not None:
|
||||
fan_in = np.prod(self.weight_shape[1:])
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
init_bias = get_condconv_initializer(
|
||||
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
|
||||
init_bias(self.bias)
|
||||
|
||||
def forward(self, x, routing_weights):
|
||||
B, C, H, W = x.shape
|
||||
weight = torch.matmul(routing_weights, self.weight)
|
||||
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
||||
weight = weight.view(new_weight_shape)
|
||||
bias = None
|
||||
if self.bias is not None:
|
||||
bias = torch.matmul(routing_weights, self.bias)
|
||||
bias = bias.view(B * self.out_channels)
|
||||
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
|
||||
x = x.view(1, B * C, H, W)
|
||||
if self.dynamic_padding:
|
||||
out = conv2d_same(
|
||||
x, weight, bias, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups * B)
|
||||
else:
|
||||
out = F.conv2d(
|
||||
x, weight, bias, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups * B)
|
||||
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
|
||||
|
||||
# Literal port (from TF definition)
|
||||
# x = torch.split(x, 1, 0)
|
||||
# weight = torch.split(weight, 1, 0)
|
||||
# if self.bias is not None:
|
||||
# bias = torch.matmul(routing_weights, self.bias)
|
||||
# bias = torch.split(bias, 1, 0)
|
||||
# else:
|
||||
# bias = [None] * B
|
||||
# out = []
|
||||
# for xi, wi, bi in zip(x, weight, bias):
|
||||
# wi = wi.view(*self.weight_shape)
|
||||
# if bi is not None:
|
||||
# bi = bi.view(*self.bias_shape)
|
||||
# out.append(self.conv_fn(
|
||||
# xi, wi, bi, stride=self.stride, padding=self.padding,
|
||||
# dilation=self.dilation, groups=self.groups))
|
||||
# out = torch.cat(out, 0)
|
||||
return out
|
||||
|
||||
|
||||
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
|
||||
assert 'groups' not in kwargs # only use 'depthwise' bool arg
|
||||
if isinstance(kernel_size, list):
|
||||
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
|
||||
# We're going to use only lists for defining the MixedConv2d kernel groups,
|
||||
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
|
||||
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
|
||||
else:
|
||||
depthwise = kwargs.pop('depthwise', False)
|
||||
groups = out_chs if depthwise else 1
|
||||
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
|
||||
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||
else:
|
||||
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||
return m
|
||||
@@ -0,0 +1,683 @@
|
||||
""" EfficientNet / MobileNetV3 Blocks and Builder
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
from .conv2d_layers import *
|
||||
from geffnet.activations import *
|
||||
|
||||
__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible',
|
||||
'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv',
|
||||
'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def',
|
||||
'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'
|
||||
]
|
||||
|
||||
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
|
||||
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
|
||||
# NOTE: momentum varies btw .99 and .9997 depending on source
|
||||
# .99 in official TF TPU impl
|
||||
# .9997 (/w .999 in search space) for paper
|
||||
#
|
||||
# PyTorch defaults are momentum = .1, eps = 1e-5
|
||||
#
|
||||
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
|
||||
BN_EPS_TF_DEFAULT = 1e-3
|
||||
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
|
||||
|
||||
|
||||
def get_bn_args_tf():
|
||||
return _BN_ARGS_TF.copy()
|
||||
|
||||
|
||||
def resolve_bn_args(kwargs):
|
||||
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
|
||||
bn_momentum = kwargs.pop('bn_momentum', None)
|
||||
if bn_momentum is not None:
|
||||
bn_args['momentum'] = bn_momentum
|
||||
bn_eps = kwargs.pop('bn_eps', None)
|
||||
if bn_eps is not None:
|
||||
bn_args['eps'] = bn_eps
|
||||
return bn_args
|
||||
|
||||
|
||||
_SE_ARGS_DEFAULT = dict(
|
||||
gate_fn=sigmoid,
|
||||
act_layer=None, # None == use containing block's activation layer
|
||||
reduce_mid=False,
|
||||
divisor=1)
|
||||
|
||||
|
||||
def resolve_se_args(kwargs, in_chs, act_layer=None):
|
||||
se_kwargs = kwargs.copy() if kwargs is not None else {}
|
||||
# fill in args that aren't specified with the defaults
|
||||
for k, v in _SE_ARGS_DEFAULT.items():
|
||||
se_kwargs.setdefault(k, v)
|
||||
# some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
|
||||
if not se_kwargs.pop('reduce_mid'):
|
||||
se_kwargs['reduced_base_chs'] = in_chs
|
||||
# act_layer override, if it remains None, the containing block's act_layer will be used
|
||||
if se_kwargs['act_layer'] is None:
|
||||
assert act_layer is not None
|
||||
se_kwargs['act_layer'] = act_layer
|
||||
return se_kwargs
|
||||
|
||||
|
||||
def resolve_act_layer(kwargs, default='relu'):
|
||||
act_layer = kwargs.pop('act_layer', default)
|
||||
if isinstance(act_layer, str):
|
||||
act_layer = get_act_layer(act_layer)
|
||||
return act_layer
|
||||
|
||||
|
||||
def make_divisible(v: int, divisor: int = 8, min_value: int = None):
|
||||
min_value = min_value or divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v: # ensure round down does not go down by more than 10%.
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
if not multiplier:
|
||||
return channels
|
||||
channels *= multiplier
|
||||
return make_divisible(channels, divisor, channel_min)
|
||||
|
||||
|
||||
def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
|
||||
"""Apply drop connect."""
|
||||
if not training:
|
||||
return inputs
|
||||
|
||||
keep_prob = 1 - drop_connect_rate
|
||||
random_tensor = keep_prob + torch.rand(
|
||||
(inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = inputs.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class SqueezeExcite(nn.Module):
|
||||
|
||||
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1):
|
||||
super(SqueezeExcite, self).__init__()
|
||||
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
|
||||
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
|
||||
self.gate_fn = gate_fn
|
||||
|
||||
def forward(self, x):
|
||||
x_se = x.mean((2, 3), keepdim=True)
|
||||
x_se = self.conv_reduce(x_se)
|
||||
x_se = self.act1(x_se)
|
||||
x_se = self.conv_expand(x_se)
|
||||
x = x * self.gate_fn(x_se)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBnAct(nn.Module):
|
||||
def __init__(self, in_chs, out_chs, kernel_size,
|
||||
stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
super(ConvBnAct, self).__init__()
|
||||
assert stride in [1, 2]
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
|
||||
self.bn1 = norm_layer(out_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
return x
|
||||
|
||||
|
||||
class DepthwiseSeparableConv(nn.Module):
|
||||
""" DepthwiseSeparable block
|
||||
Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
|
||||
factor of 1.0. This is an alternative to having a IR with optional first pw conv.
|
||||
"""
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
|
||||
super(DepthwiseSeparableConv, self).__init__()
|
||||
assert stride in [1, 2]
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
self.conv_dw = select_conv2d(
|
||||
in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
|
||||
self.bn1 = norm_layer(in_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if se_ratio is not None and se_ratio > 0.:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
else:
|
||||
self.se = nn.Identity()
|
||||
|
||||
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
|
||||
self.bn2 = norm_layer(out_chs, **norm_kwargs)
|
||||
self.act2 = act_layer(inplace=True) if pw_act else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
x = self.se(x)
|
||||
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
return x
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
""" Inverted residual block w/ optional SE"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
conv_kwargs=None, drop_connect_rate=0.):
|
||||
super(InvertedResidual, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
conv_kwargs = conv_kwargs or {}
|
||||
mid_chs: int = make_divisible(in_chs * exp_ratio)
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
# Point-wise expansion
|
||||
self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
|
||||
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Depth-wise convolution
|
||||
self.conv_dw = select_conv2d(
|
||||
mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs)
|
||||
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if se_ratio is not None and se_ratio > 0.:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
else:
|
||||
self.se = nn.Identity() # for jit.script compat
|
||||
|
||||
# Point-wise linear projection
|
||||
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
|
||||
self.bn3 = norm_layer(out_chs, **norm_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
# Point-wise expansion
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
# Depth-wise convolution
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
x = self.se(x)
|
||||
|
||||
# Point-wise linear projection
|
||||
x = self.conv_pwl(x)
|
||||
x = self.bn3(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
return x
|
||||
|
||||
|
||||
class CondConvResidual(InvertedResidual):
|
||||
""" Inverted residual block w/ CondConv routing"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
num_experts=0, drop_connect_rate=0.):
|
||||
|
||||
self.num_experts = num_experts
|
||||
conv_kwargs = dict(num_experts=self.num_experts)
|
||||
|
||||
super(CondConvResidual, self).__init__(
|
||||
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type,
|
||||
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
|
||||
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
|
||||
drop_connect_rate=drop_connect_rate)
|
||||
|
||||
self.routing_fn = nn.Linear(in_chs, self.num_experts)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
# CondConv routing
|
||||
pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
|
||||
routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
|
||||
|
||||
# Point-wise expansion
|
||||
x = self.conv_pw(x, routing_weights)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
# Depth-wise convolution
|
||||
x = self.conv_dw(x, routing_weights)
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
x = self.se(x)
|
||||
|
||||
# Point-wise linear projection
|
||||
x = self.conv_pwl(x, routing_weights)
|
||||
x = self.bn3(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
return x
|
||||
|
||||
|
||||
class EdgeResidual(nn.Module):
|
||||
""" EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
|
||||
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
|
||||
super(EdgeResidual, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio)
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
# Expansion convolution
|
||||
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
|
||||
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if se_ratio is not None and se_ratio > 0.:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
else:
|
||||
self.se = nn.Identity()
|
||||
|
||||
# Point-wise linear projection
|
||||
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
|
||||
self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
# Expansion convolution
|
||||
x = self.conv_exp(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
x = self.se(x)
|
||||
|
||||
# Point-wise linear projection
|
||||
x = self.conv_pwl(x)
|
||||
x = self.bn2(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EfficientNetBuilder:
|
||||
""" Build Trunk Blocks for Efficient/Mobile Networks
|
||||
|
||||
This ended up being somewhat of a cross between
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
|
||||
and
|
||||
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
pad_type='', act_layer=None, se_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
|
||||
self.channel_multiplier = channel_multiplier
|
||||
self.channel_divisor = channel_divisor
|
||||
self.channel_min = channel_min
|
||||
self.pad_type = pad_type
|
||||
self.act_layer = act_layer
|
||||
self.se_kwargs = se_kwargs
|
||||
self.norm_layer = norm_layer
|
||||
self.norm_kwargs = norm_kwargs
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
# updated during build
|
||||
self.in_chs = None
|
||||
self.block_idx = 0
|
||||
self.block_count = 0
|
||||
|
||||
def _round_channels(self, chs):
|
||||
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
|
||||
|
||||
def _make_block(self, ba):
|
||||
bt = ba.pop('block_type')
|
||||
ba['in_chs'] = self.in_chs
|
||||
ba['out_chs'] = self._round_channels(ba['out_chs'])
|
||||
if 'fake_in_chs' in ba and ba['fake_in_chs']:
|
||||
# FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU
|
||||
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
|
||||
ba['norm_layer'] = self.norm_layer
|
||||
ba['norm_kwargs'] = self.norm_kwargs
|
||||
ba['pad_type'] = self.pad_type
|
||||
# block act fn overrides the model default
|
||||
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
|
||||
assert ba['act_layer'] is not None
|
||||
if bt == 'ir':
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
if ba.get('num_experts', 0) > 0:
|
||||
block = CondConvResidual(**ba)
|
||||
else:
|
||||
block = InvertedResidual(**ba)
|
||||
elif bt == 'ds' or bt == 'dsa':
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
block = DepthwiseSeparableConv(**ba)
|
||||
elif bt == 'er':
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
block = EdgeResidual(**ba)
|
||||
elif bt == 'cn':
|
||||
block = ConvBnAct(**ba)
|
||||
else:
|
||||
assert False, 'Uknkown block type (%s) while building model.' % bt
|
||||
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
|
||||
return block
|
||||
|
||||
def _make_stack(self, stack_args):
|
||||
blocks = []
|
||||
# each stack (stage) contains a list of block arguments
|
||||
for i, ba in enumerate(stack_args):
|
||||
if i >= 1:
|
||||
# only the first block in any stack can have a stride > 1
|
||||
ba['stride'] = 1
|
||||
block = self._make_block(ba)
|
||||
blocks.append(block)
|
||||
self.block_idx += 1 # incr global idx (across all stacks)
|
||||
return nn.Sequential(*blocks)
|
||||
|
||||
def __call__(self, in_chs, block_args):
|
||||
""" Build the blocks
|
||||
Args:
|
||||
in_chs: Number of input-channels passed to first block
|
||||
block_args: A list of lists, outer list defines stages, inner
|
||||
list contains strings defining block configuration(s)
|
||||
Return:
|
||||
List of block stacks (each stack wrapped in nn.Sequential)
|
||||
"""
|
||||
self.in_chs = in_chs
|
||||
self.block_count = sum([len(x) for x in block_args])
|
||||
self.block_idx = 0
|
||||
blocks = []
|
||||
# outer list of block_args defines the stacks ('stages' by some conventions)
|
||||
for stack_idx, stack in enumerate(block_args):
|
||||
assert isinstance(stack, list)
|
||||
stack = self._make_stack(stack)
|
||||
blocks.append(stack)
|
||||
return blocks
|
||||
|
||||
|
||||
def _parse_ksize(ss):
|
||||
if ss.isdigit():
|
||||
return int(ss)
|
||||
else:
|
||||
return [int(k) for k in ss.split('.')]
|
||||
|
||||
|
||||
def _decode_block_str(block_str):
|
||||
""" Decode block definition string
|
||||
|
||||
Gets a list of block arg (dicts) through a string notation of arguments.
|
||||
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
|
||||
|
||||
All args can exist in any order with the exception of the leading string which
|
||||
is assumed to indicate the block type.
|
||||
|
||||
leading string - block type (
|
||||
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
|
||||
r - number of repeat blocks,
|
||||
k - kernel size,
|
||||
s - strides (1-9),
|
||||
e - expansion ratio,
|
||||
c - output channels,
|
||||
se - squeeze/excitation ratio
|
||||
n - activation fn ('re', 'r6', 'hs', or 'sw')
|
||||
Args:
|
||||
block_str: a string representation of block arguments.
|
||||
Returns:
|
||||
A list of block args (dicts)
|
||||
Raises:
|
||||
ValueError: if the string def not properly specified (TODO)
|
||||
"""
|
||||
assert isinstance(block_str, str)
|
||||
ops = block_str.split('_')
|
||||
block_type = ops[0] # take the block type off the front
|
||||
ops = ops[1:]
|
||||
options = {}
|
||||
noskip = False
|
||||
for op in ops:
|
||||
# string options being checked on individual basis, combine if they grow
|
||||
if op == 'noskip':
|
||||
noskip = True
|
||||
elif op.startswith('n'):
|
||||
# activation fn
|
||||
key = op[0]
|
||||
v = op[1:]
|
||||
if v == 're':
|
||||
value = get_act_layer('relu')
|
||||
elif v == 'r6':
|
||||
value = get_act_layer('relu6')
|
||||
elif v == 'hs':
|
||||
value = get_act_layer('hard_swish')
|
||||
elif v == 'sw':
|
||||
value = get_act_layer('swish')
|
||||
else:
|
||||
continue
|
||||
options[key] = value
|
||||
else:
|
||||
# all numeric options
|
||||
splits = re.split(r'(\d.*)', op)
|
||||
if len(splits) >= 2:
|
||||
key, value = splits[:2]
|
||||
options[key] = value
|
||||
|
||||
# if act_layer is None, the model default (passed to model init) will be used
|
||||
act_layer = options['n'] if 'n' in options else None
|
||||
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
|
||||
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
|
||||
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
|
||||
|
||||
num_repeat = int(options['r'])
|
||||
# each type of block has different valid arguments, fill accordingly
|
||||
if block_type == 'ir':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
exp_kernel_size=exp_kernel_size,
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
noskip=noskip,
|
||||
)
|
||||
if 'cc' in options:
|
||||
block_args['num_experts'] = int(options['cc'])
|
||||
elif block_type == 'ds' or block_type == 'dsa':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
pw_act=block_type == 'dsa',
|
||||
noskip=block_type == 'dsa' or noskip,
|
||||
)
|
||||
elif block_type == 'er':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
exp_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
fake_in_chs=fake_in_chs,
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
noskip=noskip,
|
||||
)
|
||||
elif block_type == 'cn':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
kernel_size=int(options['k']),
|
||||
out_chs=int(options['c']),
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
)
|
||||
else:
|
||||
assert False, 'Unknown block type (%s)' % block_type
|
||||
|
||||
return block_args, num_repeat
|
||||
|
||||
|
||||
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
""" Per-stage depth scaling
|
||||
Scales the block repeats in each stage. This depth scaling impl maintains
|
||||
compatibility with the EfficientNet scaling method, while allowing sensible
|
||||
scaling for other models that may have multiple block arg definitions in each stage.
|
||||
"""
|
||||
|
||||
# We scale the total repeat count for each stage, there may be multiple
|
||||
# block arg defs per stage so we need to sum.
|
||||
num_repeat = sum(repeats)
|
||||
if depth_trunc == 'round':
|
||||
# Truncating to int by rounding allows stages with few repeats to remain
|
||||
# proportionally smaller for longer. This is a good choice when stage definitions
|
||||
# include single repeat stages that we'd prefer to keep that way as long as possible
|
||||
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
|
||||
else:
|
||||
# The default for EfficientNet truncates repeats to int via 'ceil'.
|
||||
# Any multiplier > 1.0 will result in an increased depth for every stage.
|
||||
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
|
||||
|
||||
# Proportionally distribute repeat count scaling to each block definition in the stage.
|
||||
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
|
||||
# The first block makes less sense to repeat in most of the arch definitions.
|
||||
repeats_scaled = []
|
||||
for r in repeats[::-1]:
|
||||
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
|
||||
repeats_scaled.append(rs)
|
||||
num_repeat -= r
|
||||
num_repeat_scaled -= rs
|
||||
repeats_scaled = repeats_scaled[::-1]
|
||||
|
||||
# Apply the calculated scaling to each block arg in the stage
|
||||
sa_scaled = []
|
||||
for ba, rep in zip(stack_args, repeats_scaled):
|
||||
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
|
||||
return sa_scaled
|
||||
|
||||
|
||||
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
|
||||
arch_args = []
|
||||
for stack_idx, block_strings in enumerate(arch_def):
|
||||
assert isinstance(block_strings, list)
|
||||
stack_args = []
|
||||
repeats = []
|
||||
for block_str in block_strings:
|
||||
assert isinstance(block_str, str)
|
||||
ba, rep = _decode_block_str(block_str)
|
||||
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
|
||||
ba['num_experts'] *= experts_multiplier
|
||||
stack_args.append(ba)
|
||||
repeats.append(rep)
|
||||
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
|
||||
else:
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
|
||||
return arch_args
|
||||
|
||||
|
||||
def initialize_weight_goog(m, n='', fix_group_fanout=True):
|
||||
# weight init as per Tensorflow Official impl
|
||||
# https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
|
||||
if isinstance(m, CondConv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
if fix_group_fanout:
|
||||
fan_out //= m.groups
|
||||
init_weight_fn = get_condconv_initializer(
|
||||
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
|
||||
init_weight_fn(m.weight)
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
if fix_group_fanout:
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1.0)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
fan_out = m.weight.size(0) # fan-out
|
||||
fan_in = 0
|
||||
if 'routing_fn' in n:
|
||||
fan_in = m.weight.size(1)
|
||||
init_range = 1.0 / math.sqrt(fan_in + fan_out)
|
||||
m.weight.data.uniform_(-init_range, init_range)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
def initialize_weight_default(m, n=''):
|
||||
if isinstance(m, CondConv2d):
|
||||
init_fn = get_condconv_initializer(partial(
|
||||
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
|
||||
init_fn(m.weight)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1.0)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,71 @@
|
||||
""" Checkpoint loading / state_dict helpers
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path):
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
print("=> Loading checkpoint '{}'".format(checkpoint_path))
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in checkpoint['state_dict'].items():
|
||||
if k.startswith('module'):
|
||||
name = k[7:] # remove `module.`
|
||||
else:
|
||||
name = k
|
||||
new_state_dict[name] = v
|
||||
model.load_state_dict(new_state_dict)
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
|
||||
else:
|
||||
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
||||
def load_pretrained(model, url, filter_fn=None, strict=True):
|
||||
if not url:
|
||||
print("=> Warning: Pretrained model URL is empty, using random initialization.")
|
||||
return
|
||||
|
||||
state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu')
|
||||
|
||||
input_conv = 'conv_stem'
|
||||
classifier = 'classifier'
|
||||
in_chans = getattr(model, input_conv).weight.shape[1]
|
||||
num_classes = getattr(model, classifier).weight.shape[0]
|
||||
|
||||
input_conv_weight = input_conv + '.weight'
|
||||
pretrained_in_chans = state_dict[input_conv_weight].shape[1]
|
||||
if in_chans != pretrained_in_chans:
|
||||
if in_chans == 1:
|
||||
print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
|
||||
input_conv_weight, pretrained_in_chans))
|
||||
conv1_weight = state_dict[input_conv_weight]
|
||||
state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True)
|
||||
else:
|
||||
print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
|
||||
input_conv_weight, pretrained_in_chans))
|
||||
del state_dict[input_conv_weight]
|
||||
strict = False
|
||||
|
||||
classifier_weight = classifier + '.weight'
|
||||
pretrained_num_classes = state_dict[classifier_weight].shape[0]
|
||||
if num_classes != pretrained_num_classes:
|
||||
print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes))
|
||||
del state_dict[classifier_weight]
|
||||
del state_dict[classifier + '.bias']
|
||||
strict = False
|
||||
|
||||
if filter_fn is not None:
|
||||
state_dict = filter_fn(state_dict)
|
||||
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
@@ -0,0 +1,364 @@
|
||||
""" MobileNet-V3
|
||||
|
||||
A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
|
||||
|
||||
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .activations import get_act_fn, get_act_layer, HardSwish
|
||||
from .config import layer_config_kwargs
|
||||
from .conv2d_layers import select_conv2d
|
||||
from .helpers import load_pretrained
|
||||
from .efficientnet_builder import *
|
||||
|
||||
__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
|
||||
'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',
|
||||
'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100',
|
||||
'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100']
|
||||
|
||||
model_urls = {
|
||||
'mobilenetv3_rw':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
|
||||
'mobilenetv3_large_075': None,
|
||||
'mobilenetv3_large_100':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
|
||||
'mobilenetv3_large_minimal_100': None,
|
||||
'mobilenetv3_small_075': None,
|
||||
'mobilenetv3_small_100': None,
|
||||
'mobilenetv3_small_minimal_100': None,
|
||||
'tf_mobilenetv3_large_075':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
|
||||
'tf_mobilenetv3_large_100':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
|
||||
'tf_mobilenetv3_large_minimal_100':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
|
||||
'tf_mobilenetv3_small_075':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
|
||||
'tf_mobilenetv3_small_100':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
|
||||
'tf_mobilenetv3_small_minimal_100':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
|
||||
}
|
||||
|
||||
|
||||
class MobileNetV3(nn.Module):
|
||||
""" MobileNet-V3
|
||||
|
||||
A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the
|
||||
head convolution without a final batch-norm layer before the classifier.
|
||||
|
||||
Paper: https://arxiv.org/abs/1905.02244
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
|
||||
channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0.,
|
||||
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
|
||||
super(MobileNetV3, self).__init__()
|
||||
self.drop_rate = drop_rate
|
||||
|
||||
stem_size = round_channels(stem_size, channel_multiplier)
|
||||
self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
|
||||
self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
in_chs = stem_size
|
||||
|
||||
builder = EfficientNetBuilder(
|
||||
channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate)
|
||||
self.blocks = nn.Sequential(*builder(in_chs, block_args))
|
||||
in_chs = builder.in_chs
|
||||
|
||||
self.global_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
self.classifier = nn.Linear(num_features, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if weight_init == 'goog':
|
||||
initialize_weight_goog(m)
|
||||
else:
|
||||
initialize_weight_default(m)
|
||||
|
||||
def as_sequential(self):
|
||||
layers = [self.conv_stem, self.bn1, self.act1]
|
||||
layers.extend(self.blocks)
|
||||
layers.extend([
|
||||
self.global_pool, self.conv_head, self.act2,
|
||||
nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def features(self, x):
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
x = self.blocks(x)
|
||||
x = self.global_pool(x)
|
||||
x = self.conv_head(x)
|
||||
x = self.act2(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = x.flatten(1)
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
|
||||
|
||||
|
||||
def _create_model(model_kwargs, variant, pretrained=False):
|
||||
as_sequential = model_kwargs.pop('as_sequential', False)
|
||||
model = MobileNetV3(**model_kwargs)
|
||||
if pretrained and model_urls[variant]:
|
||||
load_pretrained(model, model_urls[variant])
|
||||
if as_sequential:
|
||||
model = model.as_sequential()
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates a MobileNet-V3 model (RW variant).
|
||||
|
||||
Paper: https://arxiv.org/abs/1905.02244
|
||||
|
||||
This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the
|
||||
eventual Tensorflow reference impl but has a few differences:
|
||||
1. This model has no bias on the head convolution
|
||||
2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet
|
||||
3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer
|
||||
from their parent block
|
||||
4. This model does not enforce divisible by 8 limitation on the SE reduction channel count
|
||||
|
||||
Overall the changes are fairly minor and result in a very small parameter count difference and no
|
||||
top-1/5
|
||||
|
||||
Args:
|
||||
channel_multiplier: multiplier to number of channels per layer.
|
||||
"""
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
|
||||
# stage 5, 14x14in
|
||||
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'], # hard-swish
|
||||
]
|
||||
with layer_config_kwargs(kwargs):
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def),
|
||||
head_bias=False, # one of my mistakes
|
||||
channel_multiplier=channel_multiplier,
|
||||
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
|
||||
se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True),
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, variant, pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates a MobileNet-V3 large/small/minimal models.
|
||||
|
||||
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
|
||||
Paper: https://arxiv.org/abs/1905.02244
|
||||
|
||||
Args:
|
||||
channel_multiplier: multiplier to number of channels per layer.
|
||||
"""
|
||||
if 'small' in variant:
|
||||
num_features = 1024
|
||||
if 'minimal' in variant:
|
||||
act_layer = 'relu'
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s2_e1_c16'],
|
||||
# stage 1, 56x56 in
|
||||
['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
|
||||
# stage 2, 28x28 in
|
||||
['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
|
||||
# stage 3, 14x14 in
|
||||
['ir_r2_k3_s1_e3_c48'],
|
||||
# stage 4, 14x14in
|
||||
['ir_r3_k3_s2_e6_c96'],
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c576'],
|
||||
]
|
||||
else:
|
||||
act_layer = 'hard_swish'
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
|
||||
# stage 1, 56x56 in
|
||||
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
|
||||
# stage 2, 28x28 in
|
||||
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
|
||||
# stage 3, 14x14 in
|
||||
['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c576'], # hard-swish
|
||||
]
|
||||
else:
|
||||
num_features = 1280
|
||||
if 'minimal' in variant:
|
||||
act_layer = 'relu'
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16'],
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k3_s2_e3_c40'],
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
|
||||
# stage 4, 14x14in
|
||||
['ir_r2_k3_s1_e6_c112'],
|
||||
# stage 5, 14x14in
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'],
|
||||
]
|
||||
else:
|
||||
act_layer = 'hard_swish'
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16_nre'], # relu
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
|
||||
# stage 5, 14x14in
|
||||
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'], # hard-swish
|
||||
]
|
||||
with layer_config_kwargs(kwargs):
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def),
|
||||
num_features=num_features,
|
||||
stem_size=16,
|
||||
channel_multiplier=channel_multiplier,
|
||||
act_layer=resolve_act_layer(kwargs, act_layer),
|
||||
se_kwargs=dict(
|
||||
act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8),
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, variant, pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_rw(pretrained=False, **kwargs):
|
||||
""" MobileNet-V3 RW
|
||||
Attn: See note in gen function for this variant.
|
||||
"""
|
||||
# NOTE for train set drop_rate=0.2
|
||||
if pretrained:
|
||||
# pretrained model trained with non-default BN epsilon
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_large_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large 0.75"""
|
||||
# NOTE for train set drop_rate=0.2
|
||||
model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_large_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large 1.0 """
|
||||
# NOTE for train set drop_rate=0.2
|
||||
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large (Minimalistic) 1.0 """
|
||||
# NOTE for train set drop_rate=0.2
|
||||
model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_small_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small 0.75 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_small_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small 1.0 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small (Minimalistic) 1.0 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large 0.75. Tensorflow compat variant. """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large 1.0. Tensorflow compat variant. """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small 0.75. Tensorflow compat variant. """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small 1.0. Tensorflow compat variant."""
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
@@ -0,0 +1,27 @@
|
||||
from .config import set_layer_config
|
||||
from .helpers import load_checkpoint
|
||||
|
||||
from .gen_efficientnet import *
|
||||
from .mobilenetv3 import *
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name='mnasnet_100',
|
||||
pretrained=None,
|
||||
num_classes=1000,
|
||||
in_chans=3,
|
||||
checkpoint_path='',
|
||||
**kwargs):
|
||||
|
||||
model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs)
|
||||
|
||||
if model_name in globals():
|
||||
create_fn = globals()[model_name]
|
||||
model = create_fn(**model_kwargs)
|
||||
else:
|
||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||
|
||||
if checkpoint_path and not pretrained:
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
|
||||
return model
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '1.0.2'
|
||||
@@ -0,0 +1,84 @@
|
||||
dependencies = ['torch', 'math']
|
||||
|
||||
from geffnet import efficientnet_b0
|
||||
from geffnet import efficientnet_b1
|
||||
from geffnet import efficientnet_b2
|
||||
from geffnet import efficientnet_b3
|
||||
|
||||
from geffnet import efficientnet_es
|
||||
|
||||
from geffnet import efficientnet_lite0
|
||||
|
||||
from geffnet import mixnet_s
|
||||
from geffnet import mixnet_m
|
||||
from geffnet import mixnet_l
|
||||
from geffnet import mixnet_xl
|
||||
|
||||
from geffnet import mobilenetv2_100
|
||||
from geffnet import mobilenetv2_110d
|
||||
from geffnet import mobilenetv2_120d
|
||||
from geffnet import mobilenetv2_140
|
||||
|
||||
from geffnet import mobilenetv3_large_100
|
||||
from geffnet import mobilenetv3_rw
|
||||
from geffnet import mnasnet_a1
|
||||
from geffnet import mnasnet_b1
|
||||
from geffnet import fbnetc_100
|
||||
from geffnet import spnasnet_100
|
||||
|
||||
from geffnet import tf_efficientnet_b0
|
||||
from geffnet import tf_efficientnet_b1
|
||||
from geffnet import tf_efficientnet_b2
|
||||
from geffnet import tf_efficientnet_b3
|
||||
from geffnet import tf_efficientnet_b4
|
||||
from geffnet import tf_efficientnet_b5
|
||||
from geffnet import tf_efficientnet_b6
|
||||
from geffnet import tf_efficientnet_b7
|
||||
from geffnet import tf_efficientnet_b8
|
||||
|
||||
from geffnet import tf_efficientnet_b0_ap
|
||||
from geffnet import tf_efficientnet_b1_ap
|
||||
from geffnet import tf_efficientnet_b2_ap
|
||||
from geffnet import tf_efficientnet_b3_ap
|
||||
from geffnet import tf_efficientnet_b4_ap
|
||||
from geffnet import tf_efficientnet_b5_ap
|
||||
from geffnet import tf_efficientnet_b6_ap
|
||||
from geffnet import tf_efficientnet_b7_ap
|
||||
from geffnet import tf_efficientnet_b8_ap
|
||||
|
||||
from geffnet import tf_efficientnet_b0_ns
|
||||
from geffnet import tf_efficientnet_b1_ns
|
||||
from geffnet import tf_efficientnet_b2_ns
|
||||
from geffnet import tf_efficientnet_b3_ns
|
||||
from geffnet import tf_efficientnet_b4_ns
|
||||
from geffnet import tf_efficientnet_b5_ns
|
||||
from geffnet import tf_efficientnet_b6_ns
|
||||
from geffnet import tf_efficientnet_b7_ns
|
||||
from geffnet import tf_efficientnet_l2_ns_475
|
||||
from geffnet import tf_efficientnet_l2_ns
|
||||
|
||||
from geffnet import tf_efficientnet_es
|
||||
from geffnet import tf_efficientnet_em
|
||||
from geffnet import tf_efficientnet_el
|
||||
|
||||
from geffnet import tf_efficientnet_cc_b0_4e
|
||||
from geffnet import tf_efficientnet_cc_b0_8e
|
||||
from geffnet import tf_efficientnet_cc_b1_8e
|
||||
|
||||
from geffnet import tf_efficientnet_lite0
|
||||
from geffnet import tf_efficientnet_lite1
|
||||
from geffnet import tf_efficientnet_lite2
|
||||
from geffnet import tf_efficientnet_lite3
|
||||
from geffnet import tf_efficientnet_lite4
|
||||
|
||||
from geffnet import tf_mixnet_s
|
||||
from geffnet import tf_mixnet_m
|
||||
from geffnet import tf_mixnet_l
|
||||
|
||||
from geffnet import tf_mobilenetv3_large_075
|
||||
from geffnet import tf_mobilenetv3_large_100
|
||||
from geffnet import tf_mobilenetv3_large_minimal_100
|
||||
from geffnet import tf_mobilenetv3_small_075
|
||||
from geffnet import tf_mobilenetv3_small_100
|
||||
from geffnet import tf_mobilenetv3_small_minimal_100
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
""" ONNX export script
|
||||
|
||||
Export PyTorch models as ONNX graphs.
|
||||
|
||||
This export script originally started as an adaptation of code snippets found at
|
||||
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
|
||||
|
||||
The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph
|
||||
for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible
|
||||
with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback
|
||||
flags are currently required.
|
||||
|
||||
Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for
|
||||
caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime.
|
||||
|
||||
Most new release of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models.
|
||||
Please do your research and search ONNX and PyTorch issue tracker before asking me. Thanks.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
import onnx
|
||||
import geffnet
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
|
||||
parser.add_argument('output', metavar='ONNX_FILE',
|
||||
help='output model filename')
|
||||
parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100',
|
||||
help='model architecture (default: mobilenetv3_large_100)')
|
||||
parser.add_argument('--opset', type=int, default=10,
|
||||
help='ONNX opset to use (default: 10)')
|
||||
parser.add_argument('--keep-init', action='store_true', default=False,
|
||||
help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.')
|
||||
parser.add_argument('--aten-fallback', action='store_true', default=False,
|
||||
help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.')
|
||||
parser.add_argument('--dynamic-size', action='store_true', default=False,
|
||||
help='Export model width dynamic width/height. Not recommended for "tf" models with SAME padding.')
|
||||
parser.add_argument('-b', '--batch-size', default=1, type=int,
|
||||
metavar='N', help='mini-batch size (default: 1)')
|
||||
parser.add_argument('--img-size', default=None, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override mean pixel value of dataset')
|
||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||
help='Override std deviation of of dataset')
|
||||
parser.add_argument('--num-classes', type=int, default=1000,
|
||||
help='Number classes in dataset')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
help='path to checkpoint (default: none)')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
args.pretrained = True
|
||||
if args.checkpoint:
|
||||
args.pretrained = False
|
||||
|
||||
print("==> Creating PyTorch {} model".format(args.model))
|
||||
# NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
|
||||
# for models using SAME padding
|
||||
model = geffnet.create_model(
|
||||
args.model,
|
||||
num_classes=args.num_classes,
|
||||
in_chans=3,
|
||||
pretrained=args.pretrained,
|
||||
checkpoint_path=args.checkpoint,
|
||||
exportable=True)
|
||||
|
||||
model.eval()
|
||||
|
||||
example_input = torch.randn((args.batch_size, 3, args.img_size or 224, args.img_size or 224), requires_grad=True)
|
||||
|
||||
# Run model once before export trace, sets padding for models with Conv2dSameExport. This means
|
||||
# that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for
|
||||
# the input img_size specified in this script.
|
||||
# Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
|
||||
# issues in the tracing of the dynamic padding or errors attempting to export the model after jit
|
||||
# scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
|
||||
model(example_input)
|
||||
|
||||
print("==> Exporting model to ONNX format at '{}'".format(args.output))
|
||||
input_names = ["input0"]
|
||||
output_names = ["output0"]
|
||||
dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
|
||||
if args.dynamic_size:
|
||||
dynamic_axes['input0'][2] = 'height'
|
||||
dynamic_axes['input0'][3] = 'width'
|
||||
if args.aten_fallback:
|
||||
export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
else:
|
||||
export_type = torch.onnx.OperatorExportTypes.ONNX
|
||||
|
||||
torch_out = torch.onnx._export(
|
||||
model, example_input, args.output, export_params=True, verbose=True, input_names=input_names,
|
||||
output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes,
|
||||
opset_version=args.opset, operator_export_type=export_type)
|
||||
|
||||
print("==> Loading and checking exported model from '{}'".format(args.output))
|
||||
onnx_model = onnx.load(args.output)
|
||||
onnx.checker.check_model(onnx_model) # assuming throw on error
|
||||
print("==> Passed")
|
||||
|
||||
if args.keep_init and args.aten_fallback:
|
||||
import caffe2.python.onnx.backend as onnx_caffe2
|
||||
# Caffe2 loading only works properly in newer PyTorch/ONNX combos when
|
||||
# keep_initializers_as_inputs and aten_fallback are set to True.
|
||||
print("==> Loading model into Caffe2 backend and comparing forward pass.".format(args.output))
|
||||
caffe2_backend = onnx_caffe2.prepare(onnx_model)
|
||||
B = {onnx_model.graph.input[0].name: x.data.numpy()}
|
||||
c2_out = caffe2_backend.run(B)[0]
|
||||
np.testing.assert_almost_equal(torch_out.data.numpy(), c2_out, decimal=5)
|
||||
print("==> Passed")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,84 @@
|
||||
""" ONNX optimization script
|
||||
|
||||
Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc.
|
||||
|
||||
NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 1.6 and ONNX 1.7),
|
||||
it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline).
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
import onnx
|
||||
from onnx import optimizer
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Optimize ONNX model")
|
||||
|
||||
parser.add_argument("model", help="The ONNX model")
|
||||
parser.add_argument("--output", required=True, help="The optimized model output filename")
|
||||
|
||||
|
||||
def traverse_graph(graph, prefix=''):
|
||||
content = []
|
||||
indent = prefix + ' '
|
||||
graphs = []
|
||||
num_nodes = 0
|
||||
for node in graph.node:
|
||||
pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True)
|
||||
assert isinstance(gs, list)
|
||||
content.append(pn)
|
||||
graphs.extend(gs)
|
||||
num_nodes += 1
|
||||
for g in graphs:
|
||||
g_count, g_str = traverse_graph(g)
|
||||
content.append('\n' + g_str)
|
||||
num_nodes += g_count
|
||||
return num_nodes, '\n'.join(content)
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
onnx_model = onnx.load(args.model)
|
||||
num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph)
|
||||
|
||||
# Optimizer passes to perform
|
||||
passes = [
|
||||
#'eliminate_deadend',
|
||||
'eliminate_identity',
|
||||
'eliminate_nop_dropout',
|
||||
'eliminate_nop_pad',
|
||||
'eliminate_nop_transpose',
|
||||
'eliminate_unused_initializer',
|
||||
'extract_constant_to_initializer',
|
||||
'fuse_add_bias_into_conv',
|
||||
'fuse_bn_into_conv',
|
||||
'fuse_consecutive_concats',
|
||||
'fuse_consecutive_reduce_unsqueeze',
|
||||
'fuse_consecutive_squeezes',
|
||||
'fuse_consecutive_transposes',
|
||||
#'fuse_matmul_add_bias_into_gemm',
|
||||
'fuse_pad_into_conv',
|
||||
#'fuse_transpose_into_gemm',
|
||||
#'lift_lexical_references',
|
||||
]
|
||||
|
||||
# Apply the optimization on the original serialized model
|
||||
# WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing
|
||||
# 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401
|
||||
# It may be better to rely on onnxruntime optimizations, see onnx_validate.py script.
|
||||
warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX."
|
||||
"Try onnxruntime optimization if this doesn't work.")
|
||||
optimized_model = optimizer.optimize(onnx_model, passes)
|
||||
|
||||
num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph)
|
||||
print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str))
|
||||
print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes))
|
||||
|
||||
# Save the ONNX model
|
||||
onnx.save(optimized_model, args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,27 @@
|
||||
import argparse
|
||||
|
||||
import onnx
|
||||
from caffe2.python.onnx.backend import Caffe2Backend
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert ONNX to Caffe2")
|
||||
|
||||
parser.add_argument("model", help="The ONNX model")
|
||||
parser.add_argument("--c2-prefix", required=True,
|
||||
help="The output file prefix for the caffe2 model init and predict file. ")
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
onnx_model = onnx.load(args.model)
|
||||
caffe2_init, caffe2_predict = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
|
||||
caffe2_init_str = caffe2_init.SerializeToString()
|
||||
with open(args.c2_prefix + '.init.pb', "wb") as f:
|
||||
f.write(caffe2_init_str)
|
||||
caffe2_predict_str = caffe2_predict.SerializeToString()
|
||||
with open(args.c2_prefix + '.predict.pb', "wb") as f:
|
||||
f.write(caffe2_predict_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,112 @@
|
||||
""" ONNX-runtime validation script
|
||||
|
||||
This script was created to verify accuracy and performance of exported ONNX
|
||||
models running with the onnxruntime. It utilizes the PyTorch dataloader/processing
|
||||
pipeline for a fair comparison against the originals.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
from data import create_loader, resolve_data_config, Dataset
|
||||
from utils import AverageMeter
|
||||
import time
|
||||
|
||||
parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
help='path to dataset')
|
||||
parser.add_argument('--onnx-input', default='', type=str, metavar='PATH',
|
||||
help='path to onnx model/weights file')
|
||||
parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
|
||||
help='path to output optimized onnx graph')
|
||||
parser.add_argument('--profile', action='store_true', default=False,
|
||||
help='Enable profiler output.')
|
||||
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||
metavar='N', help='mini-batch size (default: 256)')
|
||||
parser.add_argument('--img-size', default=None, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override mean pixel value of dataset')
|
||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||
help='Override std deviation of of dataset')
|
||||
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
|
||||
help='Override default crop pct of 0.875')
|
||||
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
|
||||
help='Image resize interpolation type (overrides model)')
|
||||
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
|
||||
help='use tensorflow mnasnet preporcessing')
|
||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
args.gpu_id = 0
|
||||
|
||||
# Set graph optimization level
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
if args.profile:
|
||||
sess_options.enable_profiling = True
|
||||
if args.onnx_output_opt:
|
||||
sess_options.optimized_model_filepath = args.onnx_output_opt
|
||||
|
||||
session = onnxruntime.InferenceSession(args.onnx_input, sess_options)
|
||||
|
||||
data_config = resolve_data_config(None, args)
|
||||
loader = create_loader(
|
||||
Dataset(args.data, load_bytes=args.tf_preprocessing),
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.batch_size,
|
||||
use_prefetcher=False,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=args.workers,
|
||||
crop_pct=data_config['crop_pct'],
|
||||
tensorflow_preprocessing=args.tf_preprocessing)
|
||||
|
||||
input_name = session.get_inputs()[0].name
|
||||
|
||||
batch_time = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
end = time.time()
|
||||
for i, (input, target) in enumerate(loader):
|
||||
# run the net and return prediction
|
||||
output = session.run([], {input_name: input.data.numpy()})
|
||||
output = output[0]
|
||||
|
||||
# measure accuracy and record loss
|
||||
prec1, prec5 = accuracy_np(output, target.numpy())
|
||||
top1.update(prec1.item(), input.size(0))
|
||||
top5.update(prec5.item(), input.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
print('Test: [{0}/{1}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
|
||||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
|
||||
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
|
||||
i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
|
||||
ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5))
|
||||
|
||||
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
|
||||
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
|
||||
|
||||
|
||||
def accuracy_np(output, target):
|
||||
max_indices = np.argsort(output, axis=1)[:, ::-1]
|
||||
top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
|
||||
top1 = 100 * np.equal(max_indices[:, 0], target).mean()
|
||||
return top1, top5
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,2 @@
|
||||
torch>=1.2.0
|
||||
torchvision>=0.4.0
|
||||
@@ -0,0 +1,47 @@
|
||||
""" Setup
|
||||
"""
|
||||
from setuptools import setup, find_packages
|
||||
from codecs import open
|
||||
from os import path
|
||||
|
||||
here = path.abspath(path.dirname(__file__))
|
||||
|
||||
# Get the long description from the README file
|
||||
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
|
||||
long_description = f.read()
|
||||
|
||||
exec(open('geffnet/version.py').read())
|
||||
setup(
|
||||
name='geffnet',
|
||||
version=__version__,
|
||||
description='(Generic) EfficientNets for PyTorch',
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
url='https://github.com/rwightman/gen-efficientnet-pytorch',
|
||||
author='Ross Wightman',
|
||||
author_email='hello@rwightman.com',
|
||||
classifiers=[
|
||||
# How mature is this project? Common values are
|
||||
# 3 - Alpha
|
||||
# 4 - Beta
|
||||
# 5 - Production/Stable
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Topic :: Scientific/Engineering',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development',
|
||||
'Topic :: Software Development :: Libraries',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||
],
|
||||
|
||||
# Note that this is a string of words separated by whitespace, not a list.
|
||||
keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet',
|
||||
packages=find_packages(exclude=['data']),
|
||||
install_requires=['torch >= 1.4', 'torchvision'],
|
||||
python_requires='>=3.6',
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def get_outdir(path, *paths, inc=False):
|
||||
outdir = os.path.join(path, *paths)
|
||||
if not os.path.exists(outdir):
|
||||
os.makedirs(outdir)
|
||||
elif inc:
|
||||
count = 1
|
||||
outdir_inc = outdir + '-' + str(count)
|
||||
while os.path.exists(outdir_inc):
|
||||
count = count + 1
|
||||
outdir_inc = outdir + '-' + str(count)
|
||||
assert count < 100
|
||||
outdir = outdir_inc
|
||||
os.makedirs(outdir)
|
||||
return outdir
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
from contextlib import suppress
|
||||
|
||||
import geffnet
|
||||
from data import Dataset, create_loader, resolve_data_config
|
||||
from utils import accuracy, AverageMeter
|
||||
|
||||
has_native_amp = False
|
||||
try:
|
||||
if getattr(torch.cuda.amp, 'autocast') is not None:
|
||||
has_native_amp = True
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
help='path to dataset')
|
||||
parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet1_00',
|
||||
help='model architecture (default: dpn92)')
|
||||
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||
metavar='N', help='mini-batch size (default: 256)')
|
||||
parser.add_argument('--img-size', default=None, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override mean pixel value of dataset')
|
||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||
help='Override std deviation of of dataset')
|
||||
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
|
||||
help='Override default crop pct of 0.875')
|
||||
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
|
||||
help='Image resize interpolation type (overrides model)')
|
||||
parser.add_argument('--num-classes', type=int, default=1000,
|
||||
help='Number classes in dataset')
|
||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
||||
help='use pre-trained model')
|
||||
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
help='convert model torchscript for inference')
|
||||
parser.add_argument('--num-gpu', type=int, default=1,
|
||||
help='Number of GPUS to use')
|
||||
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
|
||||
help='use tensorflow mnasnet preporcessing')
|
||||
parser.add_argument('--no-cuda', dest='no_cuda', action='store_true',
|
||||
help='')
|
||||
parser.add_argument('--channels-last', action='store_true', default=False,
|
||||
help='Use channels_last memory layout')
|
||||
parser.add_argument('--amp', action='store_true', default=False,
|
||||
help='Use native Torch AMP mixed precision.')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.checkpoint and not args.pretrained:
|
||||
args.pretrained = True
|
||||
|
||||
amp_autocast = suppress # do nothing
|
||||
if args.amp:
|
||||
if not has_native_amp:
|
||||
print("Native Torch AMP is not available (requires torch >= 1.6), using FP32.")
|
||||
else:
|
||||
amp_autocast = torch.cuda.amp.autocast
|
||||
|
||||
# create model
|
||||
model = geffnet.create_model(
|
||||
args.model,
|
||||
num_classes=args.num_classes,
|
||||
in_chans=3,
|
||||
pretrained=args.pretrained,
|
||||
checkpoint_path=args.checkpoint,
|
||||
scriptable=args.torchscript)
|
||||
|
||||
if args.channels_last:
|
||||
model = model.to(memory_format=torch.channels_last)
|
||||
|
||||
if args.torchscript:
|
||||
torch.jit.optimized_execution(True)
|
||||
model = torch.jit.script(model)
|
||||
|
||||
print('Model %s created, param count: %d' %
|
||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||
|
||||
data_config = resolve_data_config(model, args)
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
if not args.no_cuda:
|
||||
if args.num_gpu > 1:
|
||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||
else:
|
||||
model = model.cuda()
|
||||
criterion = criterion.cuda()
|
||||
|
||||
loader = create_loader(
|
||||
Dataset(args.data, load_bytes=args.tf_preprocessing),
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.batch_size,
|
||||
use_prefetcher=not args.no_cuda,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=args.workers,
|
||||
crop_pct=data_config['crop_pct'],
|
||||
tensorflow_preprocessing=args.tf_preprocessing)
|
||||
|
||||
batch_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
|
||||
model.eval()
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
for i, (input, target) in enumerate(loader):
|
||||
if not args.no_cuda:
|
||||
target = target.cuda()
|
||||
input = input.cuda()
|
||||
if args.channels_last:
|
||||
input = input.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
# compute output
|
||||
with amp_autocast():
|
||||
output = model(input)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# measure accuracy and record loss
|
||||
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
|
||||
losses.update(loss.item(), input.size(0))
|
||||
top1.update(prec1.item(), input.size(0))
|
||||
top5.update(prec5.item(), input.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
print('Test: [{0}/{1}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t'
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
|
||||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
|
||||
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
|
||||
i, len(loader), batch_time=batch_time,
|
||||
rate_avg=input.size(0) / batch_time.avg,
|
||||
loss=losses, top1=top1, top5=top5))
|
||||
|
||||
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
|
||||
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
basemodel_name = 'tf_efficientnet_b5_ap'
|
||||
print('Loading base model ()...'.format(basemodel_name), end='')
|
||||
repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
|
||||
basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local')
|
||||
print('Done.')
|
||||
|
||||
# Remove last layer
|
||||
print('Removing last two layers (global_pool & classifier).')
|
||||
basemodel.global_pool = nn.Identity()
|
||||
basemodel.classifier = nn.Identity()
|
||||
|
||||
self.original_model = basemodel
|
||||
|
||||
def forward(self, x):
|
||||
features = [x]
|
||||
for k, v in self.original_model._modules.items():
|
||||
if (k == 'blocks'):
|
||||
for ki, vi in v._modules.items():
|
||||
features.append(vi(features[-1]))
|
||||
else:
|
||||
features.append(v(features[-1]))
|
||||
return features
|
||||
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
########################################################################################################################
|
||||
|
||||
|
||||
# Upsample + BatchNorm
|
||||
class UpSampleBN(nn.Module):
|
||||
def __init__(self, skip_input, output_features):
|
||||
super(UpSampleBN, self).__init__()
|
||||
|
||||
self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(output_features),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(output_features),
|
||||
nn.LeakyReLU())
|
||||
|
||||
def forward(self, x, concat_with):
|
||||
up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
|
||||
f = torch.cat([up_x, concat_with], dim=1)
|
||||
return self._net(f)
|
||||
|
||||
|
||||
# Upsample + GroupNorm + Weight Standardization
|
||||
class UpSampleGN(nn.Module):
|
||||
def __init__(self, skip_input, output_features):
|
||||
super(UpSampleGN, self).__init__()
|
||||
|
||||
self._net = nn.Sequential(Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
|
||||
nn.GroupNorm(8, output_features),
|
||||
nn.LeakyReLU(),
|
||||
Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
|
||||
nn.GroupNorm(8, output_features),
|
||||
nn.LeakyReLU())
|
||||
|
||||
def forward(self, x, concat_with):
|
||||
up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
|
||||
f = torch.cat([up_x, concat_with], dim=1)
|
||||
return self._net(f)
|
||||
|
||||
|
||||
# Conv2d with weight standardization
|
||||
class Conv2d(nn.Conv2d):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1, bias=True):
|
||||
super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
|
||||
padding, dilation, groups, bias)
|
||||
|
||||
def forward(self, x):
|
||||
weight = self.weight
|
||||
weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
|
||||
keepdim=True).mean(dim=3, keepdim=True)
|
||||
weight = weight - weight_mean
|
||||
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
|
||||
weight = weight / std.expand_as(weight)
|
||||
return F.conv2d(x, weight, self.bias, self.stride,
|
||||
self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
# normalize
|
||||
def norm_normalize(norm_out):
|
||||
min_kappa = 0.01
|
||||
norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
|
||||
norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
|
||||
kappa = F.elu(kappa) + 1.0 + min_kappa
|
||||
final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
|
||||
return final_out
|
||||
|
||||
|
||||
# uncertainty-guided sampling (only used during training)
|
||||
@torch.no_grad()
|
||||
def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
|
||||
device = init_normal.device
|
||||
B, _, H, W = init_normal.shape
|
||||
N = int(sampling_ratio * H * W)
|
||||
beta = beta
|
||||
|
||||
# uncertainty map
|
||||
uncertainty_map = -1 * init_normal[:, 3, :, :] # B, H, W
|
||||
|
||||
# gt_invalid_mask (B, H, W)
|
||||
if gt_norm_mask is not None:
|
||||
gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
|
||||
gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
|
||||
uncertainty_map[gt_invalid_mask] = -1e4
|
||||
|
||||
# (B, H*W)
|
||||
_, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
|
||||
|
||||
# importance sampling
|
||||
if int(beta * N) > 0:
|
||||
importance = idx[:, :int(beta * N)] # B, beta*N
|
||||
|
||||
# remaining
|
||||
remaining = idx[:, int(beta * N):] # B, H*W - beta*N
|
||||
|
||||
# coverage
|
||||
num_coverage = N - int(beta * N)
|
||||
|
||||
if num_coverage <= 0:
|
||||
samples = importance
|
||||
else:
|
||||
coverage_list = []
|
||||
for i in range(B):
|
||||
idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
|
||||
coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
|
||||
coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
|
||||
samples = torch.cat((importance, coverage), dim=1) # B, N
|
||||
|
||||
else:
|
||||
# remaining
|
||||
remaining = idx[:, :] # B, H*W
|
||||
|
||||
# coverage
|
||||
num_coverage = N
|
||||
|
||||
coverage_list = []
|
||||
for i in range(B):
|
||||
idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
|
||||
coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
|
||||
coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
|
||||
samples = coverage
|
||||
|
||||
# point coordinates
|
||||
rows_int = samples // W # 0 for first row, H-1 for last row
|
||||
rows_float = rows_int / float(H-1) # 0 to 1.0
|
||||
rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
|
||||
|
||||
cols_int = samples % W # 0 for first column, W-1 for last column
|
||||
cols_float = cols_int / float(W-1) # 0 to 1.0
|
||||
cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
|
||||
|
||||
point_coords = torch.zeros(B, 1, N, 2)
|
||||
point_coords[:, 0, :, 0] = cols_float # x coord
|
||||
point_coords[:, 0, :, 1] = rows_float # y coord
|
||||
point_coords = point_coords.to(device)
|
||||
return point_coords, rows_int, cols_int
|
||||
79
invokeai/backend/image_util/pidi/__init__.py
Normal file
79
invokeai/backend/image_util/pidi/__init__.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Adapted from https://github.com/huggingface/controlnet_aux
|
||||
|
||||
import pathlib
|
||||
|
||||
import cv2
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
|
||||
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
|
||||
|
||||
|
||||
class PIDINetDetector:
|
||||
"""Simple wrapper around a PiDiNet model for edge detection."""
|
||||
|
||||
hf_repo_id = "lllyasviel/Annotators"
|
||||
hf_filename = "table5_pidinet.pth"
|
||||
|
||||
@classmethod
|
||||
def get_model_url(cls) -> str:
|
||||
"""Get the URL to download the model from the Hugging Face Hub."""
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path: pathlib.Path) -> PiDiNet:
|
||||
"""Load the model from a file."""
|
||||
|
||||
model = pidinet()
|
||||
model.load_state_dict({k.replace("module.", ""): v for k, v in torch.load(model_path)["state_dict"].items()})
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def __init__(self, model: PiDiNet) -> None:
|
||||
self.model = model
|
||||
|
||||
def to(self, device: torch.device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def run(
|
||||
self, image: Image.Image, quantize_edges: bool = False, scribble: bool = False, apply_filter: bool = False
|
||||
) -> Image.Image:
|
||||
"""Processes an image and returns the detected edges."""
|
||||
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_img = pil_to_np(image)
|
||||
np_img = normalize_image_channel_count(np_img)
|
||||
|
||||
assert np_img.ndim == 3
|
||||
|
||||
bgr_img = np_img[:, :, ::-1].copy()
|
||||
|
||||
with torch.no_grad():
|
||||
image_pidi = torch.from_numpy(bgr_img).float().to(device)
|
||||
image_pidi = image_pidi / 255.0
|
||||
image_pidi = rearrange(image_pidi, "h w c -> 1 c h w")
|
||||
edge = self.model(image_pidi)[-1]
|
||||
edge = edge.cpu().numpy()
|
||||
if apply_filter:
|
||||
edge = edge > 0.5
|
||||
if quantize_edges:
|
||||
edge = safe_step(edge)
|
||||
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
|
||||
|
||||
detected_map = edge[0, 0]
|
||||
|
||||
if scribble:
|
||||
detected_map = nms(detected_map, 127, 3.0)
|
||||
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
|
||||
detected_map[detected_map > 4] = 255
|
||||
detected_map[detected_map < 255] = 0
|
||||
|
||||
output_img = np_to_pil(detected_map)
|
||||
|
||||
return output_img
|
||||
681
invokeai/backend/image_util/pidi/model.py
Normal file
681
invokeai/backend/image_util/pidi/model.py
Normal file
@@ -0,0 +1,681 @@
|
||||
"""
|
||||
Author: Zhuo Su, Wenzhe Liu
|
||||
Date: Feb 18, 2021
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def img2tensor(imgs, bgr2rgb=True, float32=True):
|
||||
"""Numpy array to tensor.
|
||||
|
||||
Args:
|
||||
imgs (list[ndarray] | ndarray): Input images.
|
||||
bgr2rgb (bool): Whether to change bgr to rgb.
|
||||
float32 (bool): Whether to change to float32.
|
||||
|
||||
Returns:
|
||||
list[tensor] | tensor: Tensor images. If returned results only have
|
||||
one element, just return tensor.
|
||||
"""
|
||||
|
||||
def _totensor(img, bgr2rgb, float32):
|
||||
if img.shape[2] == 3 and bgr2rgb:
|
||||
if img.dtype == 'float64':
|
||||
img = img.astype('float32')
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = torch.from_numpy(img.transpose(2, 0, 1))
|
||||
if float32:
|
||||
img = img.float()
|
||||
return img
|
||||
|
||||
if isinstance(imgs, list):
|
||||
return [_totensor(img, bgr2rgb, float32) for img in imgs]
|
||||
else:
|
||||
return _totensor(imgs, bgr2rgb, float32)
|
||||
|
||||
nets = {
|
||||
'baseline': {
|
||||
'layer0': 'cv',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cv',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cv',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cv',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'c-v15': {
|
||||
'layer0': 'cd',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cv',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cv',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cv',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'a-v15': {
|
||||
'layer0': 'ad',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cv',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cv',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cv',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'r-v15': {
|
||||
'layer0': 'rd',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cv',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cv',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cv',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'cvvv4': {
|
||||
'layer0': 'cd',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cd',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cd',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cd',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'avvv4': {
|
||||
'layer0': 'ad',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'ad',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'ad',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'ad',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'rvvv4': {
|
||||
'layer0': 'rd',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'rd',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'rd',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'rd',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'cccv4': {
|
||||
'layer0': 'cd',
|
||||
'layer1': 'cd',
|
||||
'layer2': 'cd',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cd',
|
||||
'layer5': 'cd',
|
||||
'layer6': 'cd',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cd',
|
||||
'layer9': 'cd',
|
||||
'layer10': 'cd',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cd',
|
||||
'layer13': 'cd',
|
||||
'layer14': 'cd',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'aaav4': {
|
||||
'layer0': 'ad',
|
||||
'layer1': 'ad',
|
||||
'layer2': 'ad',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'ad',
|
||||
'layer5': 'ad',
|
||||
'layer6': 'ad',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'ad',
|
||||
'layer9': 'ad',
|
||||
'layer10': 'ad',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'ad',
|
||||
'layer13': 'ad',
|
||||
'layer14': 'ad',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'rrrv4': {
|
||||
'layer0': 'rd',
|
||||
'layer1': 'rd',
|
||||
'layer2': 'rd',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'rd',
|
||||
'layer5': 'rd',
|
||||
'layer6': 'rd',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'rd',
|
||||
'layer9': 'rd',
|
||||
'layer10': 'rd',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'rd',
|
||||
'layer13': 'rd',
|
||||
'layer14': 'rd',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'c16': {
|
||||
'layer0': 'cd',
|
||||
'layer1': 'cd',
|
||||
'layer2': 'cd',
|
||||
'layer3': 'cd',
|
||||
'layer4': 'cd',
|
||||
'layer5': 'cd',
|
||||
'layer6': 'cd',
|
||||
'layer7': 'cd',
|
||||
'layer8': 'cd',
|
||||
'layer9': 'cd',
|
||||
'layer10': 'cd',
|
||||
'layer11': 'cd',
|
||||
'layer12': 'cd',
|
||||
'layer13': 'cd',
|
||||
'layer14': 'cd',
|
||||
'layer15': 'cd',
|
||||
},
|
||||
'a16': {
|
||||
'layer0': 'ad',
|
||||
'layer1': 'ad',
|
||||
'layer2': 'ad',
|
||||
'layer3': 'ad',
|
||||
'layer4': 'ad',
|
||||
'layer5': 'ad',
|
||||
'layer6': 'ad',
|
||||
'layer7': 'ad',
|
||||
'layer8': 'ad',
|
||||
'layer9': 'ad',
|
||||
'layer10': 'ad',
|
||||
'layer11': 'ad',
|
||||
'layer12': 'ad',
|
||||
'layer13': 'ad',
|
||||
'layer14': 'ad',
|
||||
'layer15': 'ad',
|
||||
},
|
||||
'r16': {
|
||||
'layer0': 'rd',
|
||||
'layer1': 'rd',
|
||||
'layer2': 'rd',
|
||||
'layer3': 'rd',
|
||||
'layer4': 'rd',
|
||||
'layer5': 'rd',
|
||||
'layer6': 'rd',
|
||||
'layer7': 'rd',
|
||||
'layer8': 'rd',
|
||||
'layer9': 'rd',
|
||||
'layer10': 'rd',
|
||||
'layer11': 'rd',
|
||||
'layer12': 'rd',
|
||||
'layer13': 'rd',
|
||||
'layer14': 'rd',
|
||||
'layer15': 'rd',
|
||||
},
|
||||
'carv4': {
|
||||
'layer0': 'cd',
|
||||
'layer1': 'ad',
|
||||
'layer2': 'rd',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cd',
|
||||
'layer5': 'ad',
|
||||
'layer6': 'rd',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cd',
|
||||
'layer9': 'ad',
|
||||
'layer10': 'rd',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cd',
|
||||
'layer13': 'ad',
|
||||
'layer14': 'rd',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
}
|
||||
|
||||
def createConvFunc(op_type):
|
||||
assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
|
||||
if op_type == 'cv':
|
||||
return F.conv2d
|
||||
|
||||
if op_type == 'cd':
|
||||
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
|
||||
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
|
||||
assert padding == dilation, 'padding for cd_conv set wrong'
|
||||
|
||||
weights_c = weights.sum(dim=[2, 3], keepdim=True)
|
||||
yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
|
||||
y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
return y - yc
|
||||
return func
|
||||
elif op_type == 'ad':
|
||||
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
|
||||
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
|
||||
assert padding == dilation, 'padding for ad_conv set wrong'
|
||||
|
||||
shape = weights.shape
|
||||
weights = weights.view(shape[0], shape[1], -1)
|
||||
weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise
|
||||
y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
return y
|
||||
return func
|
||||
elif op_type == 'rd':
|
||||
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
|
||||
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
|
||||
padding = 2 * dilation
|
||||
|
||||
shape = weights.shape
|
||||
if weights.is_cuda:
|
||||
buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
|
||||
else:
|
||||
buffer = torch.zeros(shape[0], shape[1], 5 * 5).to(weights.device)
|
||||
weights = weights.view(shape[0], shape[1], -1)
|
||||
buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
|
||||
buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
|
||||
buffer[:, :, 12] = 0
|
||||
buffer = buffer.view(shape[0], shape[1], 5, 5)
|
||||
y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
return y
|
||||
return func
|
||||
else:
|
||||
print('impossible to be here unless you force that')
|
||||
return None
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
|
||||
super(Conv2d, self).__init__()
|
||||
if in_channels % groups != 0:
|
||||
raise ValueError('in_channels must be divisible by groups')
|
||||
if out_channels % groups != 0:
|
||||
raise ValueError('out_channels must be divisible by groups')
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.reset_parameters()
|
||||
self.pdc = pdc
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
if self.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(self.bias, -bound, bound)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
class CSAM(nn.Module):
|
||||
"""
|
||||
Compact Spatial Attention Module
|
||||
"""
|
||||
def __init__(self, channels):
|
||||
super(CSAM, self).__init__()
|
||||
|
||||
mid_channels = 4
|
||||
self.relu1 = nn.ReLU()
|
||||
self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
|
||||
self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
nn.init.constant_(self.conv1.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.relu1(x)
|
||||
y = self.conv1(y)
|
||||
y = self.conv2(y)
|
||||
y = self.sigmoid(y)
|
||||
|
||||
return x * y
|
||||
|
||||
class CDCM(nn.Module):
|
||||
"""
|
||||
Compact Dilation Convolution based Module
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(CDCM, self).__init__()
|
||||
|
||||
self.relu1 = nn.ReLU()
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
|
||||
self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
|
||||
self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
|
||||
self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
|
||||
self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
|
||||
nn.init.constant_(self.conv1.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu1(x)
|
||||
x = self.conv1(x)
|
||||
x1 = self.conv2_1(x)
|
||||
x2 = self.conv2_2(x)
|
||||
x3 = self.conv2_3(x)
|
||||
x4 = self.conv2_4(x)
|
||||
return x1 + x2 + x3 + x4
|
||||
|
||||
|
||||
class MapReduce(nn.Module):
|
||||
"""
|
||||
Reduce feature maps into a single edge map
|
||||
"""
|
||||
def __init__(self, channels):
|
||||
super(MapReduce, self).__init__()
|
||||
self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
|
||||
nn.init.constant_(self.conv.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class PDCBlock(nn.Module):
|
||||
def __init__(self, pdc, inplane, ouplane, stride=1):
|
||||
super(PDCBlock, self).__init__()
|
||||
self.stride=stride
|
||||
|
||||
self.stride=stride
|
||||
if self.stride > 1:
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
|
||||
self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
|
||||
self.relu2 = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride > 1:
|
||||
x = self.pool(x)
|
||||
y = self.conv1(x)
|
||||
y = self.relu2(y)
|
||||
y = self.conv2(y)
|
||||
if self.stride > 1:
|
||||
x = self.shortcut(x)
|
||||
y = y + x
|
||||
return y
|
||||
|
||||
class PDCBlock_converted(nn.Module):
|
||||
"""
|
||||
CPDC, APDC can be converted to vanilla 3x3 convolution
|
||||
RPDC can be converted to vanilla 5x5 convolution
|
||||
"""
|
||||
def __init__(self, pdc, inplane, ouplane, stride=1):
|
||||
super(PDCBlock_converted, self).__init__()
|
||||
self.stride=stride
|
||||
|
||||
if self.stride > 1:
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
|
||||
if pdc == 'rd':
|
||||
self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
|
||||
self.relu2 = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride > 1:
|
||||
x = self.pool(x)
|
||||
y = self.conv1(x)
|
||||
y = self.relu2(y)
|
||||
y = self.conv2(y)
|
||||
if self.stride > 1:
|
||||
x = self.shortcut(x)
|
||||
y = y + x
|
||||
return y
|
||||
|
||||
class PiDiNet(nn.Module):
|
||||
def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
|
||||
super(PiDiNet, self).__init__()
|
||||
self.sa = sa
|
||||
if dil is not None:
|
||||
assert isinstance(dil, int), 'dil should be an int'
|
||||
self.dil = dil
|
||||
|
||||
self.fuseplanes = []
|
||||
|
||||
self.inplane = inplane
|
||||
if convert:
|
||||
if pdcs[0] == 'rd':
|
||||
init_kernel_size = 5
|
||||
init_padding = 2
|
||||
else:
|
||||
init_kernel_size = 3
|
||||
init_padding = 1
|
||||
self.init_block = nn.Conv2d(3, self.inplane,
|
||||
kernel_size=init_kernel_size, padding=init_padding, bias=False)
|
||||
block_class = PDCBlock_converted
|
||||
else:
|
||||
self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
|
||||
block_class = PDCBlock
|
||||
|
||||
self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
|
||||
self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
|
||||
self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
|
||||
self.fuseplanes.append(self.inplane) # C
|
||||
|
||||
inplane = self.inplane
|
||||
self.inplane = self.inplane * 2
|
||||
self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
|
||||
self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
|
||||
self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
|
||||
self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
|
||||
self.fuseplanes.append(self.inplane) # 2C
|
||||
|
||||
inplane = self.inplane
|
||||
self.inplane = self.inplane * 2
|
||||
self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
|
||||
self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
|
||||
self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
|
||||
self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
|
||||
self.fuseplanes.append(self.inplane) # 4C
|
||||
|
||||
self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
|
||||
self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
|
||||
self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
|
||||
self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
|
||||
self.fuseplanes.append(self.inplane) # 4C
|
||||
|
||||
self.conv_reduces = nn.ModuleList()
|
||||
if self.sa and self.dil is not None:
|
||||
self.attentions = nn.ModuleList()
|
||||
self.dilations = nn.ModuleList()
|
||||
for i in range(4):
|
||||
self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
|
||||
self.attentions.append(CSAM(self.dil))
|
||||
self.conv_reduces.append(MapReduce(self.dil))
|
||||
elif self.sa:
|
||||
self.attentions = nn.ModuleList()
|
||||
for i in range(4):
|
||||
self.attentions.append(CSAM(self.fuseplanes[i]))
|
||||
self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
|
||||
elif self.dil is not None:
|
||||
self.dilations = nn.ModuleList()
|
||||
for i in range(4):
|
||||
self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
|
||||
self.conv_reduces.append(MapReduce(self.dil))
|
||||
else:
|
||||
for i in range(4):
|
||||
self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
|
||||
|
||||
self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
|
||||
nn.init.constant_(self.classifier.weight, 0.25)
|
||||
nn.init.constant_(self.classifier.bias, 0)
|
||||
|
||||
# print('initialization done')
|
||||
|
||||
def get_weights(self):
|
||||
conv_weights = []
|
||||
bn_weights = []
|
||||
relu_weights = []
|
||||
for pname, p in self.named_parameters():
|
||||
if 'bn' in pname:
|
||||
bn_weights.append(p)
|
||||
elif 'relu' in pname:
|
||||
relu_weights.append(p)
|
||||
else:
|
||||
conv_weights.append(p)
|
||||
|
||||
return conv_weights, bn_weights, relu_weights
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.size()[2:]
|
||||
|
||||
x = self.init_block(x)
|
||||
|
||||
x1 = self.block1_1(x)
|
||||
x1 = self.block1_2(x1)
|
||||
x1 = self.block1_3(x1)
|
||||
|
||||
x2 = self.block2_1(x1)
|
||||
x2 = self.block2_2(x2)
|
||||
x2 = self.block2_3(x2)
|
||||
x2 = self.block2_4(x2)
|
||||
|
||||
x3 = self.block3_1(x2)
|
||||
x3 = self.block3_2(x3)
|
||||
x3 = self.block3_3(x3)
|
||||
x3 = self.block3_4(x3)
|
||||
|
||||
x4 = self.block4_1(x3)
|
||||
x4 = self.block4_2(x4)
|
||||
x4 = self.block4_3(x4)
|
||||
x4 = self.block4_4(x4)
|
||||
|
||||
x_fuses = []
|
||||
if self.sa and self.dil is not None:
|
||||
for i, xi in enumerate([x1, x2, x3, x4]):
|
||||
x_fuses.append(self.attentions[i](self.dilations[i](xi)))
|
||||
elif self.sa:
|
||||
for i, xi in enumerate([x1, x2, x3, x4]):
|
||||
x_fuses.append(self.attentions[i](xi))
|
||||
elif self.dil is not None:
|
||||
for i, xi in enumerate([x1, x2, x3, x4]):
|
||||
x_fuses.append(self.dilations[i](xi))
|
||||
else:
|
||||
x_fuses = [x1, x2, x3, x4]
|
||||
|
||||
e1 = self.conv_reduces[0](x_fuses[0])
|
||||
e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)
|
||||
|
||||
e2 = self.conv_reduces[1](x_fuses[1])
|
||||
e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)
|
||||
|
||||
e3 = self.conv_reduces[2](x_fuses[2])
|
||||
e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)
|
||||
|
||||
e4 = self.conv_reduces[3](x_fuses[3])
|
||||
e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)
|
||||
|
||||
outputs = [e1, e2, e3, e4]
|
||||
|
||||
output = self.classifier(torch.cat(outputs, dim=1))
|
||||
#if not self.training:
|
||||
# return torch.sigmoid(output)
|
||||
|
||||
outputs.append(output)
|
||||
outputs = [torch.sigmoid(r) for r in outputs]
|
||||
return outputs
|
||||
|
||||
def config_model(model):
|
||||
model_options = list(nets.keys())
|
||||
assert model in model_options, \
|
||||
'unrecognized model, please choose from %s' % str(model_options)
|
||||
|
||||
# print(str(nets[model]))
|
||||
|
||||
pdcs = []
|
||||
for i in range(16):
|
||||
layer_name = 'layer%d' % i
|
||||
op = nets[model][layer_name]
|
||||
pdcs.append(createConvFunc(op))
|
||||
|
||||
return pdcs
|
||||
|
||||
def pidinet():
|
||||
pdcs = config_model('carv4')
|
||||
dil = 24 #if args.dil else None
|
||||
return PiDiNet(60, pdcs, dil=dil, sa=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = pidinet()
|
||||
ckp = torch.load('table5_pidinet.pth')['state_dict']
|
||||
model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
|
||||
im = cv2.imread('examples/test_my/cat_v4.png')
|
||||
im = img2tensor(im).unsqueeze(0)/255.
|
||||
res = model(im)[-1]
|
||||
res = res>0.5
|
||||
res = res.float()
|
||||
res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8)
|
||||
print(res.shape)
|
||||
cv2.imwrite('edge.png', res)
|
||||
@@ -86,12 +86,20 @@ def np_to_pil(image: np.ndarray) -> Image.Image:
|
||||
|
||||
def pil_to_cv2(image: Image.Image) -> np.ndarray:
|
||||
"""Converts a PIL image to a CV2 image."""
|
||||
return cv2.cvtColor(np.array(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
||||
|
||||
if image.mode == "RGBA":
|
||||
return cv2.cvtColor(np.array(image, dtype=np.uint8), cv2.COLOR_RGBA2BGRA)
|
||||
else:
|
||||
return cv2.cvtColor(np.array(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
||||
|
||||
|
||||
def cv2_to_pil(image: np.ndarray) -> Image.Image:
|
||||
"""Converts a CV2 image to a PIL image."""
|
||||
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
if image.ndim == 3 and image.shape[2] == 4:
|
||||
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
|
||||
else:
|
||||
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
|
||||
def normalize_image_channel_count(image: np.ndarray) -> np.ndarray:
|
||||
@@ -217,3 +225,23 @@ def safe_step(x: np.ndarray, step: int = 2) -> np.ndarray:
|
||||
y = x.astype(np.float32) * float(step + 1)
|
||||
y = y.astype(np.int32).astype(np.float32) / float(step)
|
||||
return y
|
||||
|
||||
|
||||
def resize_to_multiple(image: np.ndarray, multiple: int) -> np.ndarray:
|
||||
"""Resize an image to make its dimensions multiples of the given number."""
|
||||
|
||||
# Get the original dimensions
|
||||
height, width = image.shape[:2]
|
||||
|
||||
# Calculate the scaling factor to make the dimensions multiples of the given number
|
||||
new_width = (width // multiple) * multiple
|
||||
new_height = int((new_width / width) * height)
|
||||
|
||||
# If new_height is not a multiple, adjust it
|
||||
if new_height % multiple != 0:
|
||||
new_height = (new_height // multiple) * multiple
|
||||
|
||||
# Resize the image
|
||||
resized_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
||||
|
||||
return resized_image
|
||||
|
||||
@@ -1,672 +0,0 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
"""LoRA model support."""
|
||||
|
||||
import bisect
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
# layer_key: str
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
return self.bias
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
params = {"weight": self.get_weight(orig_module.weight)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
if bias is not None:
|
||||
params["bias"] = bias
|
||||
return params
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||
"""Log a warning if values contains unhandled keys."""
|
||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||
unknown_keys = set(values.keys()) - all_known_keys
|
||||
if unknown_keys:
|
||||
logger.warning(
|
||||
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
|
||||
)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
# mid: Optional[torch.Tensor]
|
||||
# down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
self.mid = values.get("lora_mid.weight", None)
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lora_up.weight",
|
||||
"lora_down.weight",
|
||||
"lora_mid.weight",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
# w2_a: torch.Tensor
|
||||
# w2_b: torch.Tensor
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
self.t1 = values.get("hada_t1", None)
|
||||
self.t2 = values.get("hada_t2", None)
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"hada_w1_a",
|
||||
"hada_w1_b",
|
||||
"hada_w2_a",
|
||||
"hada_w2_b",
|
||||
"hada_t1",
|
||||
"hada_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1 = values.get("lokr_w1", None)
|
||||
if self.w1 is None:
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
else:
|
||||
self.w1_b = None
|
||||
self.w1_a = None
|
||||
|
||||
self.w2 = values.get("lokr_w2", None)
|
||||
if self.w2 is None:
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
else:
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
|
||||
self.t2 = values.get("lokr_t2", None)
|
||||
|
||||
if self.w1_b is not None:
|
||||
self.rank = self.w1_b.shape[0]
|
||||
elif self.w2_b is not None:
|
||||
self.rank = self.w2_b.shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lokr_w1",
|
||||
"lokr_w1_a",
|
||||
"lokr_w1_b",
|
||||
"lokr_w2",
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
self.bias = values.get("diff_b", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"diff", "diff_b"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"weight", "on_input"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class NormLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["w_norm"]
|
||||
self.bias = values.get("b_norm", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"w_norm", "b_norm"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, AnyLoRALayer]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, AnyLoRALayer],
|
||||
):
|
||||
self._name = name
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
) -> Self:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem,
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
sd = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
sd = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(sd)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
# lora and locon
|
||||
if "lora_up.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_a" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1" in values or "lokr_w1_a" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
# norms
|
||||
elif "w_norm" in values:
|
||||
layer = NormLayer(layer_key, values)
|
||||
|
||||
else:
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
|
||||
}
|
||||
0
invokeai/backend/lora/__init__.py
Normal file
0
invokeai/backend/lora/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user