Compare commits

..

204 Commits

Author SHA1 Message Date
Ean Garvey
493f776253 add pretrain_models.csv 2023-06-15 14:52:18 -05:00
Ean Garvey
6a82667778 Add pretraining pytest script with SHARK dynamo. 2023-06-15 12:01:35 -05:00
Stefan Kapusniak
38570a9bbb Some Fixes for update to gradio 3.34.0 (#1538)
* Fixes randomize seed buttons that stopped working.
* Update now deprecated method to set initial colums for output
gallery to the newer undeprecated one.
2023-06-15 01:10:36 -07:00
dependabot[bot]
a5c882f296 Bump gradio from 3.15.0 to 3.34.0 (#1518)
Bumps [gradio](https://github.com/gradio-app/gradio) from 3.15.0 to 3.34.0.
- [Release notes](https://github.com/gradio-app/gradio/releases)
- [Changelog](https://github.com/gradio-app/gradio/blob/main/CHANGELOG.md)
- [Commits](https://github.com/gradio-app/gradio/compare/v3.15.0...v3.34.0)

---
updated-dependencies:
- dependency-name: gradio
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-14 18:13:48 -07:00
Ean Garvey
eb6d11cfed Change mlir dialects for tf tests to stablehlo. (#1535)
* Change mlir dialects for tf tests to stablehlo

* Update shark_runner.py
2023-06-14 10:43:49 -07:00
Vivek Khandelwal
46184a81ac Add Falcon pipeline (#1534) 2023-06-14 09:39:16 -07:00
PhaneeshB
149165a2f0 add multi-device mutli-precision vmfb names 2023-06-14 22:08:24 +05:30
dan
bec82a665f mega vicuna merge
single endpoint in apps/language/models/scripts/vicuna.py
removed main functions from pipelines
replaced divergent utils compile with shark_importer
adds support for different precisions
2023-06-14 19:06:29 +05:30
Ean Garvey
9551490341 Remove deprecared --iree-mhlo-demote-164-to-132 flag usage. (#1533) 2023-06-13 22:40:47 -05:00
Ean Garvey
49b3ecdbca (pytest) don't run redundant tests in cpu suite (#1532) 2023-06-13 22:40:33 -05:00
Ean Garvey
f53e3594c3 OPT Refactor (#1516)
* Change script to 1.3b model and add pytorch comparison

* fix CLI command

* Match OPT transformers model updates + numerics against latest version

* Cleanup OPT sentence completion script.

* Fix formatting and add standalone validation scripts.

* Add minimal OPT wrapper and example with import_with_fx

* Rename OPT full model wrapper.

* Cleanup test scripts for OPT.
2023-06-13 22:40:07 -05:00
Ean Garvey
5562d1dfda Fix xfails for cpu pytest cases (#1527)
Adding cpu-sync and cpu-task device configs was allowing respective tests to bypass the xfail conditional for cpu pytests marked in tank/all_models.csv. This commit updates the conditional to xfail those cases for cpu-sync and cpu-task as well.
2023-06-13 17:01:51 -07:00
Stefan Kapusniak
c7b0c2961e UI/Web Improve output gallery temp file handling (#1531)
* On startup report that cleaning up of temp files is taking place, in
case it takes a long time.
* Have the output gallery tab delete any zero length temporary files
generated by gradio < 3.32.0 for its gallery control whenever it
needs to update that control with images. This prevents such
files multiplying out of control.
2023-06-13 16:25:37 -05:00
Ean Garvey
44273b0791 Fix conditional in transform_fx() (#1530) 2023-06-13 16:24:53 -05:00
Prashant Kumar
0a4c8fcb3e Minor changes in the fx transforms. 2023-06-13 21:23:35 +05:30
Stefan Kapusniak
2fec3c8169 re-indents add_upcast in shark importer (#1523)
* The two with blocks in add_upcast appear to be underindented making
SD 1.4 break on rdna3, I've pushed them out one more tab, and then
everything appears to work again.
2023-06-12 14:41:10 -05:00
Gaurav Shukla
5e7d5930dd [vicuna] Add device and precision propagation in vicuna (#1520)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-12 12:14:43 -05:00
Prashant Kumar
b6dbd20250 Modify the fx transforms. (#1521)
- The bounds are set properly.
- The upcasting and downcasting is done for vicuna.
2023-06-12 09:40:14 -07:00
Nithin Meganathan
34f1295349 Add a model config generator (#1511)
Model config generator takes a PyTorch model as input and generates a JSON file with model layers and other propperties that define sharding on a particular hardware.
2023-06-09 15:32:00 -07:00
Phaneesh Barwaria
1980d7b2c3 Cpu device map (#1515)
* update cpu iree device

* fix vmfb paths vic unsharded
2023-06-09 11:27:02 -05:00
powderluv
2cfacc5051 fix osx torch_mlir (#1513)
* fix osx torch_mlir

* Update index.py

* Update index.py
2023-06-09 00:57:26 -07:00
Phaneesh Barwaria
436f58ddc4 cli using generate and mem fixes (#1509) 2023-06-08 13:13:32 -05:00
Phaneesh Barwaria
6b29bd17c8 Enable compilation vicuna (#1507)
* add cli for unsharded vic

* enable mlir download and compile
2023-06-07 13:08:22 -07:00
Ean Garvey
2c3485ca3e Add standalone OPT sentence completion script. (#1506) 2023-06-07 10:58:03 -07:00
Daniel Garvey
f206ecc635 reenable compilation in vicuna pipeline, add flags (#1505)
* replace vicuna.py backend with pipeline

* add some memory management to fist vicuna compile

reenable compilation
2023-06-07 09:49:27 -07:00
Stefan Kapusniak
a187e05ae6 Prevent having no cuda devices breaking the UI (#1503)
Don't break the UI when the LLM tab only wants cuda devices but there
aren't any.
2023-06-06 11:41:16 -07:00
Gaurav Shukla
8c21960486 [vicuna] Set only cuda devices in vicuna UI for now
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 22:15:20 +05:30
Gaurav Shukla
be62fce676 [vicuna] Fix vicuna chatbot (#1499)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 09:23:32 -07:00
PhaneeshB
f23b778a6c remove old vicuna scripts 2023-06-06 21:35:58 +05:30
PhaneeshB
436edf900d add vic sharded pipeline 2023-06-06 21:35:58 +05:30
Gaurav Shukla
ed58c2553f [vicuna] Integrate vicuna in shark studio
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 20:57:48 +05:30
Stefan Kapusniak
f2ca58e844 Add .csv and .json param info to output gallery (#1495) 2023-06-06 07:08:34 -07:00
Ean Garvey
1dbcc736eb [SD] (RDNA2) Enable new tuning for sd1.4 (#1498) 2023-06-06 06:48:58 -07:00
Phaneesh Barwaria
a83808ddc5 Vicuna cuda on A100 40G (#1496)
* vic chat with memory management (precompiled vmfb)

* fix vmfb path and download
2023-06-06 15:10:33 +05:30
Ean Garvey
a07fe80530 Update OPT, ResNet example scripts. (#1492)
* Update API in OPT example.

* fix resnet50 script

* Add OPT1.3b test script.
2023-06-05 20:19:35 -07:00
Ean Garvey
d0ba3ef8fa disable use_tuned on SD1.4 for rdna2 (#1490)
this is a temporary measure while we retune SD1.4 for rdna2. The current config fails during iree-compile.
2023-06-05 19:46:16 -05:00
Stefan Kapusniak
8400529c2c Fix output gallery not using shark_tmp (#1493)
This fix the gallery component of the  output gallery dumping temporary
files into the standard folders rather than shark_tmp so those files never
got cleared out on restart and would build up.
2023-06-05 16:23:49 -05:00
powderluv
7eaee9c242 update SHARK to nodai SHARK 2023-06-05 00:44:49 -07:00
powderluv
8230eebce5 Switch to CPU torch builds for shark.whl 2023-06-05 00:36:03 -07:00
Ean Garvey
6296ea4be9 fix config handling for sd1.4 on rdna2 (#1489) 2023-06-05 00:02:30 -07:00
Ean Garvey
4151ec3a8f (pytest) tag efficientnet, mobilenet as xfails on vulkan (#1488) 2023-06-04 23:22:32 -07:00
powderluv
a2467e8d43 Enable SHARK whl packages 2023-06-04 23:21:22 -07:00
Ean Garvey
e677178bcc Replace RDNA2 SD lowering configs. (#1486) 2023-06-05 00:57:43 -05:00
Anush Elangovan
7ef1bea953 XFAIL some macos tests 2023-06-04 15:27:03 -07:00
Chi_Liu
ad89bb1413 Add distilgpt2 to stablehlo in shark tank (#1481) 2023-06-02 16:44:46 -05:00
Ean Garvey
218ed78c40 Change instances of input_type='mhlo' to 'auto' (#1482) 2023-06-02 16:43:47 -05:00
Stefan Kapusniak
6046f36ab6 UI/Web: Fix upscaler stop button (mostly) (#1479)
* UI/Web: Fix upscaler stop button

* Hook the cancel_sd function up to the Stop button.
* Adds checks for SD_STATE_CANCEL in the upscaler ui inference function.
* Set and check for SD_STATE_IDLE, SD_STATE_CANCEL in the upscaler
pipeline.

* UI/Web: lint fixes for upscaler stop button fix

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-01 22:26:55 -07:00
Foxlum
5915bf7de3 Add to and tweak vulkan configuration environments. (#1475)
* Update vulkan_target_env_utils.py

* Update vulkan_target_env_utils.py

Adjust target environment capabilities.

* Update vulkan_target_env_utils.py

black linted?
2023-06-01 22:25:20 -07:00
Phaneesh Barwaria
f0a4e59758 LLM Pipeline Wrapper (#1477)
* [LLM] Add LLM pipeline

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

* add base pipeline and stableLM

* StableLM on UI - full block

* add SLM default model name

* add vicuna with pipeline

* add one token gen api for vic

* Fix stableLM bugs

* debug vic memory

* lint fix

---------

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-31 10:17:20 -07:00
Stefan Kapusniak
1ddef26af5 Web/UI: Add an Output Gallery tab for SD (#1470)
* WebUI: Adds an Output Gallery tab

Adds an new Output Gallery tab to the ui/webui with these features:

* Subdirectory select dropdown listing subdirectories at any depth below
the <output_dir>/generated_imgs directory,
* Large, full height, gallery area displaying the images in the selected
subdirectory. Shows nod logo when no images are in the selected
subdirectory.
* Slider that changes the number of columns of images that the gallery
displays from between 1 to 16 columns (defaults to 4).
* Expandable parameter info panel showing any generation parameters
saved in the file of the selected image for PNGs, alternatively the
image's EXIF data for JPEGs
* Send to buttons for txt2img, img2img, inpaint, outpaint and upscaler.
* Auto update of gallery and gallery label (to show generation status),
when a new image is generated by any of the stable diffusion tabs, and
is outputted to the currently selected subdirectory.
* Command line option for enabling and disabling the output gallery
(defaults to enabled)
* Command line option for following symlinks when getting entries
for the subdirectory list (defaults to off, as Python os.walk doesn't
check for circular references if following symlinks)

* Reformat with black

Reformat changes with black and then adjust some places where black's
formatting then needed some rephrasing of the code to make things
clearer.

* Add back transformers and sd_cancel imports

Adds back the transformers import in index.py needed for .exe
generation. Add comment so it doesn't get mistakenly removed
next time.
Adds back sd_cancel import in upscaler.py that is currently unused
but should be being used for the 'Stop' button.
2023-05-30 13:47:48 -07:00
Chi_Liu
ba8eddb12f Add GPT3/OPT to Stablehlo in shark tank (#1468)
Co-authored-by: AmosLewis <Amos_Lewsi@foxmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-05-29 21:58:39 -07:00
yzhang93
47b346d428 Modify the lowering config format for SPIRVMatmulPromoteVectorize pipeline (#1471) 2023-05-29 21:53:48 -07:00
Ean Garvey
1b4f4f5f4d Fix download path for SD1.4 Unet. (#1469) 2023-05-26 11:59:51 -07:00
Elias Joseph
73cd7e8320 added full vicuna to vicuna.py 2023-05-26 22:06:40 +05:30
Ean Garvey
19c0ae3702 Cleanup SD pipeline utils (#1466) 2023-05-25 12:50:11 -05:00
Ean Garvey
54e57f7771 Revive SD downloads from shark_tank. (#1465) 2023-05-25 12:03:21 -05:00
PhaneeshB
6d64b8e273 vic and slm common generation base 2023-05-25 20:29:41 +05:30
PhaneeshB
a8ea0326f5 correct SLM saved vmfb naming 2023-05-25 20:29:41 +05:30
PhaneeshB
58e9194553 add Lists import 2023-05-25 20:29:41 +05:30
PhaneeshB
eb360e255d remove unused imports 2023-05-25 20:29:41 +05:30
PhaneeshB
a6f88d7f72 refactor mlir compile 2023-05-25 20:29:41 +05:30
Prashant Kumar
8e571d165f Enable cpu f16 dtype tracing for the vicuna model. (#1461) 2023-05-24 09:37:57 -07:00
Ean Garvey
3cddd01b10 Update OPT tokenizer and xfail a few more large tests on macos CI (#1459)
* Update opt_torch_test.py

* Update all_models.csv
2023-05-23 14:36:57 -07:00
Chi_Liu
64c2b2d96b Add gpt2 to stablehlo support in shark tank (#1447)
- Add torch decomposition support when generating shark tank
- Add gpt2 stablehlo
2023-05-22 10:45:51 -07:00
Phaneesh Barwaria
f5ce121988 SLM on Sharkstudio (#1454)
* localize import, fix file reading, device cpu

* extract out model args
2023-05-19 11:21:08 -07:00
Ean Garvey
991f144598 Add iree hidden imports to SD spec (#1456)
* Add iree hidden imports to SD spec

* Update shark_sd_cli.spec
2023-05-19 11:19:16 -07:00
PhaneeshB
09bea17e59 fix #2 SLM in SharkStudio 2023-05-18 00:56:22 +05:30
Daniel Garvey
aefcf80b48 swap to cpu an remove hardcoded paths (#1448)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-05-17 10:53:34 -07:00
PhaneeshB
512235892e fix SLM for SharkStudio 2023-05-17 22:34:30 +05:30
PhaneeshB
6602a2f5ba add continuous output for CLI 2023-05-17 18:33:46 +05:30
Boian Petkantchin
20114deea0 In MiniLM JAX example verify MLIR result against JAX 2023-05-16 09:54:07 -07:00
Boian Petkantchin
9acf519078 Add option to skip venv creation in setup script 2023-05-16 09:54:07 -07:00
Boian Petkantchin
bdf37b5311 If device/backend is unknown pass it to IREE verbatim 2023-05-16 09:54:07 -07:00
powderluv
8ee2ac89f8 Rename sharded_vicuna_fp32_web.py to vicuna_web.py 2023-05-16 09:41:35 -07:00
powderluv
60cb48be2e Rename sharded_vicuna_fp32.py to vicuna.py 2023-05-16 09:40:51 -07:00
powderluv
86a215b063 Delete sharded_vicunia.py 2023-05-16 09:37:39 -07:00
powderluv
d6e3a9a236 Delete standalone_vicuna.py 2023-05-16 09:37:26 -07:00
Chi_Liu
a0097a1ead Add mlir_type for torch_model_list.csv (#1428)
- Enable stablehlo/tosa mlir output for torch model
- Add BERT stablehlo support
2023-05-15 10:23:54 -07:00
Ean Garvey
a9bae00606 Fix vulkan device selection at compile time and adapt to IREE python changes. (#1407)
* Add support for vulkan device selection at compile time.

* Don't convert device ID to int and fix .exe imports
2023-05-12 23:31:50 -07:00
Daniel Garvey
4731c1a835 prevent loading tokenizer on import (#1432)
also adds sentencepiece dep for exe
moved vicuna imports to after an if statement
in general we should avoid importing files that load whole models as
global variables
2023-05-12 19:11:45 -07:00
Ean Garvey
4c07e47e8c Specify a few models for expected failure on CUDA CI. (#1430) 2023-05-12 17:03:37 -05:00
Gaurav Shukla
e0cc2871bb [SD] Yield 2 tokens at a time in vicuna
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 23:49:01 +05:30
Gaurav Shukla
649f39408b [SD] Fix vicuna response
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 18:06:21 +05:30
Gaurav Shukla
c142297d73 [SD] Fix gradio to 3.22.0 version
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com
2023-05-11 18:05:55 +05:30
Gaurav Shukla
9e07360b00 [SD] Standalone vicuna with web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 17:23:44 +05:30
Gaurav Shukla
7b74c86e42 [SD] Fix SAMPLE_INPUT_LEN import issue
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 15:41:43 +05:30
Eliasj42
fa833f8366 fixed spacing issue with chat-bot (#1417)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-10 16:07:50 -07:00
Gaurav Shukla
fcb059aa38 [SD] Integrate vicuna in the web (#1410) 2023-05-10 11:30:22 -07:00
PhaneeshB
517c670f82 vicuna chat cli 2023-05-10 22:55:06 +05:30
Eliasj42
59df14f18b added vicuna demo (#1408)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-09 21:18:20 -07:00
Ean Garvey
6c95ac0f37 Revert dialect registration in model annotator (#1406)
Matches https://github.com/nod-ai/SHARK-Runtime/pull/58
2023-05-09 11:50:19 -07:00
Daniel Garvey
7a4a51ae73 vulkan vic f16 (#1404)
Co-authored-by: dan <dan@nod-labs.com>
2023-05-08 16:46:53 -07:00
powderluv
d816cc015e Revert "added standalone vicuna script (#1399)" (#1402)
This reverts commit 0e4a8ca240.
2023-05-05 16:08:05 -07:00
Eliasj42
54ce3d48ca added standalone vicuna script (#1401)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-05 18:05:52 -05:00
Eliasj42
0e4a8ca240 added standalone vicuna script (#1399)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-05 15:46:05 -07:00
Daniel Garvey
6ca1298675 maximizes window size for webview launch (#1394) 2023-05-04 20:43:06 -07:00
jinchen62
bbef7a6464 Redesign model manager webui (#1391) 2023-05-04 20:41:29 -07:00
Ean Garvey
cdf2d61d53 Remove imports from iree.compiler.transforms from model annotator. (#1392) 2023-05-04 20:40:19 -07:00
Ean Garvey
6c14847d1f xfail some large tests on macOS builder and switch to hash updates. (#1341)
* Update test-models.yml

* Disable large tests on macOS builder
2023-05-04 19:47:03 -05:00
Gaurav Shukla
68ecdd2a73 [SD] Add LoRA as experimental tab
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-04 22:30:25 +05:30
Gaurav Shukla
3f4d444d18 [SD] Fix stable LM chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-04 22:30:25 +05:30
m68k-fr
e473d0375b [Web] Models folders cleanup (#1365) 2023-05-03 16:13:20 -05:00
Ean Garvey
e38d96850f Fix input image loading in img2img rest API (#1388) 2023-05-03 15:51:00 -05:00
Gaurav Shukla
fed63dfd4b [SD] Add stableLM chatbot (#1383)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-05-03 15:37:20 -05:00
Boian Petkantchin
eba4d06405 In MiniLM JAX example do not hardcode device (#1385)
* In MiniLM JAX example do not hardcode device

* In MiniLM JAX example don't use bytecode MLIR

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-05-03 10:34:42 -07:00
Boian Petkantchin
4cfba153d2 Add example JAX MiniLM inference (#1380)
* Do not hardcode the name of the VM module in get_iree_module

* Add example JAX MiniLM inference

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-05-02 15:03:54 -07:00
jinchen62
307c05f38d Convert original vae to diffusers (#1382) 2023-05-02 01:27:28 -07:00
jinchen62
696df349cb Fix curl issue (#1369) 2023-04-28 09:31:14 -07:00
jinchen62
cb54cb1348 Add model manager tab for SD webui (#1368) 2023-04-28 02:43:40 -07:00
Daniel Garvey
9bdb86637d add tkinter launch for webui (#1364) 2023-04-27 19:17:55 -05:00
jinchen62
fb6f26517f Fix webui note (#1367) 2023-04-27 16:14:43 -07:00
Chi_Liu
aa8ada9da9 Add support for torch to stablehlo and tosa in shark_importer (#1360) 2023-04-27 08:09:45 -07:00
powderluv
1db906a373 Revert "Add model manager tab for webui (#1359)" (#1362)
This reverts commit 9d1d1617d8.
2023-04-26 22:25:26 -07:00
jinchen62
9d1d1617d8 Add model manager tab for webui (#1359) 2023-04-26 13:38:18 -07:00
jinchen62
7112789cb8 Add support of using civitai model download url (#1357) 2023-04-25 23:39:52 -07:00
jinchen62
d6b8be2849 Add drawing canvas for img2img stencil scribble (#1355) 2023-04-25 14:41:01 -07:00
powderluv
822171277c Revert "[SD] Add FastChat as part of SD WebUI (#1349)" (#1350)
This reverts commit a5ae9d9f02.
2023-04-24 15:22:25 -07:00
Abhishek Varma
a5ae9d9f02 [SD] Add FastChat as part of SD WebUI (#1349)
-- This commit includes FastChat as part of SD WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-24 11:12:58 -07:00
powderluv
09e3f63d5b Fix pascal (#1346)
* Add fp32 for upscaler VAE

* Plumb Pascal vulkan support
2023-04-23 20:28:25 -07:00
powderluv
d60a5a9396 Add fp32 for upscaler VAE (#1345) 2023-04-23 15:27:55 -07:00
m68k-fr
90df0ee365 [Web] Gallery set to a 768px reference for high-end desktop users (#1344) 2023-04-23 11:48:06 -07:00
nirvedhmeshram
133c1bcadd add device to scheduler model names (#1338) 2023-04-22 20:13:56 -05:00
powderluv
caadbe14e9 Revert VAE to use im2col (#1339) 2023-04-22 15:23:41 -07:00
Ean Garvey
5f5823ccd9 Fix inference object imports for SD apps. (#1334) 2023-04-21 13:40:48 -05:00
Vivek Khandelwal
d2f7e03b7e Add StableLM model (#1331) 2023-04-21 09:51:02 -07:00
Gaurav Shukla
0b01bbe479 [SD] Add txt2img/upscaler/inpaint/outpaint Rest API (#1325)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-21 09:06:06 -07:00
yzhang93
25c5fc44ae Modify tuner.py to take vulkan target triple flag (#1328) 2023-04-20 14:31:32 -07:00
Daniel Garvey
7330729c92 enable sd pytest (#1322) 2023-04-19 22:11:30 -05:00
Ean Garvey
ce16cd5431 Create local shark_tank if needed for tuning configs. (#1321)
Now that --clear_all successfully deletes local shark_tank cache, we need to make sure it exists before trying to use it.
2023-04-19 11:44:21 -05:00
Ean Garvey
598dc5f79d Don't dump image data on img2img api call. (#1320) 2023-04-19 21:24:46 +05:30
Abhishek Varma
1f8e332cbe [SD] Fix img2img API bug for custom_vae argument (#1319)
-- https://github.com/nod-ai/SHARK/pull/1314 misses to add `custom_vae`
   parameter to img2img_if's invocation within img2img_api.
-- This commit adds a fix to the same.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-19 10:39:52 -05:00
Abhishek Varma
17b9632659 [SD] Adapted SHARK's v1 img2img API for SdPaint + updated Stencil model ID (#1318) 2023-04-19 06:29:36 -07:00
jinchen62
bda92a54ab Fix custom vae path (#1317) 2023-04-18 20:50:43 -07:00
jinchen62
747ed383b1 Add custom vae dropdown in webui (#1314) 2023-04-18 17:24:02 -07:00
Ean Garvey
1afe07c296 Disable winograd on VAE with rdna2 and fix unet tuning. (#1313)
* Disable winograd on VAE with rdna2 and fix unet tuning.

* Fix batch size 1 downloads and clear_all on windows.
2023-04-18 15:55:10 -05:00
jinchen62
b70919b38d Fix memory leak with ondemand (#1312)
support ondemand for outpainting and multi batch_count
2023-04-18 13:03:16 -05:00
m68k-fr
4e513d647f Update list of scheduler available for inferences (#1298) 2023-04-17 22:37:00 -05:00
jinchen62
94cd2a0fed Fix outpainting config (#1310) 2023-04-17 10:48:52 -07:00
Kyle Herndon
606029c01c Fix LoRA device format bug and allow LoRA to resume from a previous training 2023-04-17 13:19:46 +05:30
powderluv
1aa85222e9 Add AMD W7900 target triple (#1304)
This maps to RDNA3
2023-04-16 00:14:21 -07:00
m68k-fr
1b3f468c04 [Web] Style Fixes for Gradio V3.25.0 (#1300) 2023-04-13 18:40:42 -05:00
m68k-fr
35de7e27fa [Web] remove txt2img ui dependencies from png import metadata (#1275) 2023-04-12 07:32:47 -10:00
yzhang93
467f900759 Add auto-tuner to SD apps (#1291) 2023-04-12 09:21:17 -07:00
Ean Garvey
0bd9d582c7 Add documentation for using SHARK with AI-Render (#1296) 2023-04-12 03:09:34 -10:00
jinchen62
428cfe8dae Fix low vram mode issues (#1295)
- add ondemand back to img2img
- workaround memory leak for batch count
2023-04-11 17:59:09 -07:00
Ean Garvey
f17915bedc Fix batch size appending to model name. (#1294)
* Update shark_downloader.py

* Update shark_downloader.py
2023-04-11 15:34:25 -05:00
Gaurav Shukla
1b49b5149a [SD] Add Img2Img rest API
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-11 23:06:58 +05:30
jinchen62
3002793301 Unload clip on demand and workaround memory leak (#1283) 2023-04-10 16:59:03 -07:00
Phaneesh Barwaria
d25ef5529f Add fix for vae fp32 Upscalar (#1284)
- fixes size mismatch error for upscalar vae
2023-04-07 14:36:40 -05:00
Ean Garvey
308856a947 Touch unet if base cfg needed for SD pipeline init (#1281) 2023-04-05 03:02:29 -05:00
m68k-fr
151b4e142f [SD] Fix encoder error for model_max_length not beeing 77 (#1278)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-04-04 22:39:29 -07:00
Ean Garvey
e5a69a7c36 pin diffusers to e47459c (#1279) 2023-04-04 18:29:21 -07:00
m68k-fr
450b6cafc4 [SD] Add weight emphasis to prompts encoder (#1276) 2023-04-04 09:47:04 -07:00
Daniel Garvey
237d26baa2 update model db to reflect changes (#1277)
* remove 1/1 tqdm progress bar

* update model_db to reflect changes
2023-04-04 11:46:55 -05:00
Daniel Garvey
67d6ee1104 remove 1/1 tqdm progress bar (#1274) 2023-04-03 22:30:09 -05:00
Ean Garvey
98b069488e Add tank_version.json (#1272) 2023-04-03 18:36:23 -07:00
jinchen62
e0f227643a Fix webui circular import issue (#1271) 2023-04-03 16:00:10 -07:00
jinchen62
a0af3bb0cb xload and unload models (#1242) 2023-04-03 14:42:18 -07:00
powderluv
2cd61a5b96 strip source map (#1270) 2023-04-03 14:41:32 -07:00
Gaurav Shukla
f49d41a807 [SD] Add Stable diffusion text2image rest API (#1265)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-03 12:02:24 -07:00
Ean Garvey
2191fc8952 Separate pytest benchmark modes and fix model updates for SHARK downloader / pytest. (#1264)
* Only xfail windows models in CI

* downloader: make model updates more robust.

* Separate baseline and native benchmarks in pytest.

* Fix native benchmarks

* Fix torchvision model utils.
2023-04-03 08:24:21 -07:00
PhaneeshB
aea7796e60 add gradio client to spec 2023-04-03 18:57:19 +05:30
Abhishek Varma
a376619f1e [SD] Improve vmfb caching algo and retry mechanism (#1248)
-- This commit gets rid of the all-or-nothing vmfb caching mechanism
   and improves the retry mechanism by providing lower-level granularity
   for compiling each model units.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-03-31 09:38:14 -07:00
powderluv
02d52bb626 Add Intel ARC A770 target triple (#1263)
This just enables the plumbing. It generates black images.
2023-03-29 14:49:05 -07:00
Abhishek Varma
3b63645f79 [SD] Fix custom model path for WebUI (#1260)
-- This commit fixes custom model path for WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-29 09:48:11 -07:00
Ean Garvey
d6f740b998 allow pytest to retry getting model artifacts + disable autotuning for pytorch benchmarks (#1257)
* Adds a few xfails to enable macOS builder

* Convert string batch sizes to ints where needed.

* allow pytest to retry getting model artifacts

* Reduce attempts and add assert msg.
2023-03-28 23:38:45 -05:00
Daniel Garvey
594c6b8ea2 fix ckpt dir (#1258) 2023-03-28 14:31:01 -07:00
Ean Garvey
96b1560da5 Make batch size configurable via pytest and fix sharktank generation. (#1227)
* Fix sharktank generation and add batch_size pytest option for torch.

* Disable torch dynamo until py3.11 supported

* Compile torchmodel without dynamo if torch.compile fails

* Use release versions of TF/Keras for importer.

* Pin torchvision and remove debug prints.

* Remove duplicates from torch model list.

* Update generate_sharktank.py

* xfail a few models that fail sharktank generation/ numerics
2023-03-28 14:33:39 -05:00
Abhishek Varma
0ef6a0e234 [SD] Fix Stencil scribble crash by updating image resize (#1255)
-- This commit updates Stencil resize feature to cap the size of
   images within [128,768] as supported by the SD pipeline.
-- This solves the issue of scribble crashing on larger image.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-28 10:13:11 -07:00
Gaurav Shukla
641d535f44 [SD] Fix device path issue for cpu (#1256)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-28 10:09:49 -07:00
Daniel Garvey
5bb7846227 single entry point exe for all cli apps (#1158)
usage:
add --app="img2img" (or "inpaint" "outpaint" "txt2img")
2023-03-28 11:15:21 -05:00
yzhang93
8f84258fb8 Fix check for use_tuned conditions (#1252) 2023-03-27 11:21:25 -07:00
Ean Garvey
7619e76bbd Disable and xfail some models that fail validation/compilation. (#1251)
* Rollback T5 models for torch as the inputs give some issues that aren't trivial to resolve
* xfail efficientnet-b0 on torch+cuda -- see CUDA requesting shared memory size larger than allowed size openxla/iree#12771
2023-03-27 12:42:53 -05:00
Daniel Garvey
9267eadbfa disable openjourney gen for nightly (#1249) 2023-03-27 11:55:34 -05:00
Phaneesh Barwaria
431132b8ee Fix img2img mode switch (#1247)
* add updated scheduler value in global config

* clear scheduler global variable with others
2023-03-27 07:01:22 -07:00
cstueckrath
fb35e13e7a fix Python version detection bug (#1246)
* fix Python version detection bug

* Update setup_venv.ps1
2023-03-27 07:00:40 -07:00
yzhang93
17a67897d1 Add SD v2.1 768x768 tuned model (#1244)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-24 10:39:15 -07:00
Gaurav Shukla
da449b73aa [SD] Disable lora training tab for now (#1241)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-24 09:16:24 -07:00
Kyle Herndon
0b0526699a Fix incorrect device argument initialization for LoRA training by extracting the device type and number and formatting it for pytorch (#1237)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
2023-03-24 01:10:50 -07:00
Boian Petkantchin
4fac46f7bb In models testing fix paths to be relative to the script dir not cwd (#1128)
authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-03-22 15:26:52 -05:00
Daniel Garvey
49925950f1 fix false positives (#1193) 2023-03-22 15:25:39 -05:00
Thomas
807947c0c8 Remove deprecated cli option iree-hal-cuda-disable-loop-nounroll-wa (#1235) 2023-03-22 12:05:15 -05:00
Abhishek Varma
593428bda4 [SD] Fix for transformers/__init__.py issue in PyInstaller (#1233)
-- This commit fixes the transformers/__init__.py issue in PyInstaller.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-22 08:43:53 -07:00
Abhishek Varma
cede9b4fec [SD] Fix custom_vae as a required parameter in inpaint (#1232) 2023-03-22 04:30:17 -07:00
Prashant Kumar
c2360303f0 Add the int8 quantized model. 2023-03-22 16:28:13 +05:30
jinchen62
420366c1b8 Move schedulers to global obj (#1225) 2023-03-21 22:40:43 -07:00
Ean Garvey
d31bae488c Set iree-input-type to tm_tensor for SD (#1228) 2023-03-21 19:07:31 -07:00
Kyle Herndon
c23fcf3748 Fix incorrect device argument initialization for LoRA training (#1231)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-21 19:07:18 -07:00
jinchen62
7dbbb1726a Fix SD obj not defined if fail to get models from pretrained (#1222) 2023-03-21 07:55:17 -07:00
Abhishek Varma
8b8cc7fd33 [SD] Update LoRA inference to handle various checkpoints (#1215) 2023-03-21 06:52:20 -07:00
Ean Garvey
e3c96a2b9d Move sentencepiece to importer requirements. (#1218) 2023-03-21 00:39:57 -05:00
Ean Garvey
5e3f50647d Set --vulkan_large_heap_block_size default to 2gb. (#1220) 2023-03-20 21:07:09 -07:00
gpetters94
7899e1803a Add fix for attention slicing fp16 (#1217) 2023-03-20 19:11:29 -07:00
mariecwhite
d105246b9c Fix t5 models 2023-03-21 10:39:59 +11:00
mariecwhite
90c958bca2 Add T5-base and T5-large Torch and TF Models (#1116) 2023-03-20 17:32:50 -05:00
mariecwhite
f99903e023 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:22:05 +11:00
mariecwhite
c6f44ef1b3 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:14:45 +11:00
mariecwhite
8dcd4d5aeb Make batch size configurable 2023-03-20 18:03:17 -04:00
Phoenix Meadowlark
d319f4684e Add peak memory reporting for IREE, TF and PyTorch (#1216) 2023-03-20 15:40:49 -05:00
Ean Garvey
54d7b6d83e Generate model artifacts in pytests if they don't exist in the cloud. (#1121)
* Add gen_shark_files fn to shark_downloader for OTF artifact generation

* add generate_sharktank as a tank/ python module.

* Fix some paths in tank generation.
2023-03-20 12:13:19 -05:00
m68k-fr
4a622532e5 [Web] Stop images (#1212) 2023-03-19 14:37:30 -07:00
cstueckrath
650b2ada58 add pytorch_lightning to requirements (#1211)
* add pytorch_lightning to requirements

this will additionally add lightning-utilities and torchmetrics

* Update shark_sd.spec

* Update shark_sd_cli.spec
2023-03-19 12:29:54 -07:00
m68k-fr
f87f8949f3 [Web] CSS fix for gradio V3.22.1 (#1210) 2023-03-19 06:13:59 -07:00
m68k-fr
7dc9bf8148 [Web] Move "stop Batch" button to "Advanced Options" toggle (#1209) 2023-03-18 20:54:42 -07:00
123 changed files with 10945 additions and 3247 deletions

View File

@@ -50,27 +50,13 @@ jobs:
shell: powershell
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
python process_skipfiles.py
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
# GHA windows VM OOMs so disable for now
#- name: Build and validate the SHARK Runtime package
# shell: powershell
# run: |
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
#- uses: actions/upload-artifact@v2
# with:
# path: dist/*
mv ./dist/shark_sd.exe ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1
@@ -78,7 +64,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./dist/*
assets_path: ./dist/nodai*
#asset_content_type: application/vnd.microsoft.portable-executable
- name: Publish Release

View File

@@ -112,7 +112,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -120,9 +120,9 @@ jobs:
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
@@ -137,7 +137,7 @@ jobs:
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan --update_tank
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
@@ -145,17 +145,19 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest -k vulkan -s
pytest -k vulkan -s --ci
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
python build_tools/stable_diffusion_testing.py --device=vulkan

View File

@@ -114,12 +114,12 @@ source shark.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
```shell
python3.11 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc

View File

@@ -0,0 +1,210 @@
import torch
import torch_mlir
from transformers import (
AutoTokenizer,
StoppingCriteria,
)
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
def shouldStop(tokens):
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if tokens[0][-1] == stop_id:
return True
return False
MAX_SEQUENCE_LENGTH = 256
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
def compile_stableLM(
model,
model_inputs,
model_name,
model_vmfb_name,
device="cuda",
precision="fp32",
):
from shark.shark_inference import SharkInference
# device = "cuda" # "cpu"
# TODO: vmfb and mlir name should include precision and device
vmfb_path = (
Path(model_name + f"_{device}.vmfb")
if model_vmfb_name is None
else Path(model_vmfb_name)
)
shark_module = get_vmfb_from_path(
vmfb_path, device, mlir_dialect="tm_tensor"
)
if shark_module is not None:
return shark_module
mlir_path = Path(model_name + ".mlir")
print(
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
)
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
)
print("Saved vmfb at ", str(path))
return shark_module
class StableLMModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
combine_input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
output = self.model(**combine_input_dict)
return output.logits
# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
def get_tokenizer():
model_path = "stabilityai/stablelm-tuned-alpha-3b"
tok = AutoTokenizer.from_pretrained(model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
print("Sucessfully loaded the tokenizer to the memory")
return tok
# sharkStableLM = compile_stableLM
# (
# None,
# tuple([input_ids, attention_mask]),
# "stableLM_linalg_f32_seqLen256",
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
# )
def generate(
new_text,
max_new_tokens,
sharkStableLM,
tokenizer=None,
):
if tokenizer is None:
tokenizer = get_tokenizer()
# Construct the input message string for the model by
# concatenating the current system message and conversation history
# Tokenize the messages string
# sharkStableLM = compile_stableLM
# (
# None,
# tuple([input_ids, attention_mask]),
# "stableLM_linalg_f32_seqLen256",
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
# )
words_list = []
for i in range(max_new_tokens):
# numWords = len(new_text.split())
# if(numWords>220):
# break
params = {
"new_text": new_text,
}
generated_token_op = generate_new_token(
sharkStableLM, tokenizer, params
)
detok = generated_token_op["detok"]
stop_generation = generated_token_op["stop_generation"]
if stop_generation:
break
print(detok, end="", flush=True)
words_list.append(detok)
if detok == "":
break
new_text = new_text + detok
return words_list
def generate_new_token(shark_model, tokenizer, params):
new_text = params["new_text"]
model_inputs = tokenizer(
[new_text],
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
return_tensors="pt",
)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
output = shark_model(
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
)
output = torch.from_numpy(output)
next_toks = torch.topk(output, 1)
stop_generation = False
if shouldStop(next_toks.indices):
stop_generation = True
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
detok = tokenizer.decode(
new_token,
skip_special_tokens=True,
)
ret_dict = {
"new_token": new_token,
"detok": detok,
"stop_generation": stop_generation,
}
return ret_dict

View File

@@ -0,0 +1,122 @@
import argparse
from pathlib import Path
from apps.language_models.src.pipelines import vicuna_pipeline as vp
from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
import torch
parser = argparse.ArgumentParser(
prog="vicuna runner",
description="runs a vicuna model",
)
parser.add_argument(
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
"--first_vicuna_vmfb_path", default=None, help="path to first vicuna vmfb"
)
parser.add_argument(
"-s",
"--sharded",
default=False,
action=argparse.BooleanOptionalAction,
help="Run model as sharded",
)
# TODO: sharded config
parser.add_argument(
"--second_vicuna_vmfb_path",
default=None,
help="path to second vicuna vmfb",
)
parser.add_argument(
"--first_vicuna_mlir_path",
default=None,
help="path to first vicuna mlir file",
)
parser.add_argument(
"--second_vicuna_mlir_path",
default=None,
help="path to second vicuna mlir",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
parser.add_argument(
"--cli",
default=False,
action=argparse.BooleanOptionalAction,
help="Run model in cli mode",
)
if __name__ == "__main__":
args, unknown = parser.parse_known_args()
vic = None
if not args.sharded:
first_vic_mlir_path = (
Path(f"first_vicuna_{args.precision}.mlir")
if args.first_vicuna_mlir_path is None
else Path(args.first_vicuna_mlir_path)
)
second_vic_mlir_path = (
Path(f"second_vicuna_{args.precision}.mlir")
if args.second_vicuna_mlir_path is None
else Path(args.second_vicuna_mlir_path)
)
first_vic_vmfb_path = (
Path(
f"first_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
)
if args.first_vicuna_vmfb_path is None
else Path(args.first_vicuna_vmfb_path)
)
second_vic_vmfb_path = (
Path(
f"second_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
)
if args.second_vicuna_vmfb_path is None
else Path(args.second_vicuna_vmfb_path)
)
vic = vp.Vicuna(
"vicuna",
device=args.device,
precision=args.precision,
first_vicuna_mlir_path=first_vic_mlir_path,
second_vicuna_mlir_path=second_vic_mlir_path,
first_vicuna_vmfb_path=first_vic_vmfb_path,
second_vicuna_vmfb_path=second_vic_vmfb_path,
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
)
else:
vic = vsp.Vicuna(
"vicuna",
device=args.device,
precision=args.precision,
)
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"
import gc
while True:
# TODO: Add break condition from user input
user_prompt = input("User: ")
prompt_history = (
prompt_history + "USER:\n" + user_prompt + prologue_prompt
)
prompt = prompt_history.strip()
res_str = vic.generate(prompt, cli=True)
torch.cuda.empty_cache()
gc.collect()
print(
"\n-----\nAssistant: Here's the complete formatted reply:\n",
res_str,
)
prompt_history += f"\n{res_str}\n"

View File

@@ -0,0 +1,22 @@
import torch
class FalconModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": None,
"use_cache": True,
}
output = self.model(
**input_dict,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)[0]
return output[:, -1, :]

View File

@@ -0,0 +1,15 @@
import torch
class StableLMModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
combine_input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
output = self.model(**combine_input_dict)
return output.logits

View File

@@ -0,0 +1,261 @@
import torch
from transformers import AutoModelForCausalLM
class FirstVicuna(torch.nn.Module):
def __init__(self, model_path):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
def forward(self, input_ids):
op = self.model(input_ids=input_ids, use_cache=True)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class SecondVicuna(torch.nn.Module):
def __init__(self, model_path):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
):
# input_ids = input_tuple[0]
# input_tuple = torch.unbind(pkv, dim=0)
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class CombinedModel(torch.nn.Module):
def __init__(
self,
first_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
second_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
):
super().__init__()
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
def forward(self, input_ids):
first_output = self.first_vicuna(input_ids=input_ids, use_cache=True)
logits = first_output[0]
pkv = first_output[1:]
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
second_output = self.second_vicuna(secondVicunaInput)
return second_output

View File

@@ -0,0 +1,178 @@
import torch
class FirstVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states, attention_mask, position_ids):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class SecondVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
past_key_value0,
past_key_value1,
),
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class CompiledFirstVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class CompiledSecondVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
if is_first:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers0
)
return self.model.forward(input_ids, attention_mask=attention_mask)
else:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers1
)
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)

View File

@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod
class SharkLLMBase(ABC):
def __init__(
self, model_name, hf_model_path=None, max_num_tokens=512
) -> None:
self.model_name = model_name
self.hf_model_path = hf_model_path
self.max_num_tokens = max_num_tokens
self.shark_model = None
self.device = "cpu"
self.precision = "fp32"
@classmethod
@abstractmethod
def compile(self):
pass
@classmethod
@abstractmethod
def generate(self, prompt):
pass
@classmethod
@abstractmethod
def generate_new_token(self, params):
pass
@classmethod
@abstractmethod
def get_tokenizer(self):
pass
@classmethod
@abstractmethod
def get_src_model(self):
pass
def load_init_from_config(self):
pass

View File

@@ -0,0 +1,473 @@
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
)
from io import BytesIO
from pathlib import Path
from contextlib import redirect_stdout
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
StoppingCriteriaList,
)
import copy
import re
import torch
import torch_mlir
import os
import argparse
parser = argparse.ArgumentParser(
prog="falcon runner",
description="runs a falcon model",
)
parser.add_argument(
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
"--falcon_vmfb_path", default=None, help="path to falcon's vmfb"
)
parser.add_argument(
"--falcon_mlir_path",
default=None,
help="path to falcon's mlir file",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
parser.add_argument(
"--cli",
default=True,
action=argparse.BooleanOptionalAction,
help="Run model in cli mode",
)
class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=Path("falcon.mlir"),
falcon_vmfb_path=Path("falcon.vmfb"),
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.src_model = self.get_src_model()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, trust_remote_code=True
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
return tokenizer
def get_src_model(self):
print("Loading src model")
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return falcon_model
def compile_falcon(self):
vmfb = get_vmfb_from_path(self.falcon_vmfb_path, self.device, "linalg")
if vmfb is not None:
return vmfb
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
if args.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
if not mlir_generated:
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 100)
)
compilation_attention_mask = torch.ones(
1, 100, dtype=torch.int64
)
falconCompileInput = (
compilation_input_ids,
compilation_attention_mask,
)
model = FalconModel(self.src_model)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
falconCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del model
print(f"[DEBUG] generating torch mlir")
module = torch_mlir.compile(
ts_graph,
[*falconCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
print(f"[DEBUG] writing mlir to file")
with open(f"{self.model_name}.mlir", "wb") as f_:
with redirect_stdout(f_):
print(module.operation.get_asm())
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile(self):
if (
not self.falcon_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.max_padding_length,
add_special_tokens=False,
return_tensors="pt",
)
model_inputs["prompt_text"] = prompt
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs.get("attention_mask", None)
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
attention_mask = None
in_b = 1
else:
in_b = input_ids.shape[0]
generate_kwargs = {
"max_length": self.max_num_tokens,
"do_sample": True,
"top_k": 10,
"num_return_sequences": 1,
"eos_token_id": 11,
}
generate_kwargs["input_ids"] = input_ids
generate_kwargs["attention_mask"] = attention_mask
generation_config_ = GenerationConfig.from_model_config(
self.src_model.config
)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = StoppingCriteriaList()
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = self.src_model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
self.logits_processor = self.src_model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids.shape[-1],
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.src_model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
self.logits_warper = self.src_model._get_logits_warper(
generation_config
)
(
self.input_ids,
self.model_kwargs,
) = self.src_model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id) if eos_token_id is not None else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
input_ids.shape[0], dtype=torch.long, device=input_ids.device
)
all_text = prompt
for i in range(self.max_num_tokens - 1):
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
all_text = all_text + new_word
print(f"{new_word}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(input_ids, self.scores)
):
break
torch.cuda.empty_cache()
gc.collect()
return all_text
def generate_new_token(self):
model_inputs = self.src_model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = torch.from_numpy(
self.shark_model(
"forward",
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
if __name__ == "__main__":
args = parser.parse_args()
falcon_mlir_path = (
Path("falcon.mlir")
if args.falcon_mlir_path is None
else Path(args.falcon_mlir_path)
)
falcon_vmfb_path = (
Path("falcon.vmfb")
if args.falcon_vmfb_path is None
else Path(args.falcon_vmfb_path)
)
falcon = Falcon(
"falcon",
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
falcon_vmfb_path=falcon_vmfb_path,
)
import gc
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
continue_execution = True
while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? True or False?: "
)
if use_default_prompt:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)
res_str = falcon.generate(prompt)
torch.cuda.empty_cache()
gc.collect()
print(
"\n\n-----\nHere's the complete formatted result: \n\n",
res_str,
)
continue_execution = input(
"\nDo you wish to run script one more time? True or False?: "
)

View File

@@ -0,0 +1,185 @@
import torch
import torch_mlir
from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.stablelm_model import (
StableLMModel,
)
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
class SharkStableLM(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
max_num_tokens=512,
device="cuda",
precision="fp32",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
self.precision = precision
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
def shouldStop(self, tokens):
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if tokens[0][-1] == stop_id:
return True
return False
def get_src_model(self):
model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, torch_dtype=torch.float32
)
return model
def get_model_inputs(self):
input_ids = torch.randint(3, (1, self.max_sequence_len))
attention_mask = torch.randint(3, (1, self.max_sequence_len))
return input_ids, attention_mask
def compile(self):
tmp_model_name = (
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
)
# device = "cuda" # "cpu"
# TODO: vmfb and mlir name should include precision and device
model_vmfb_name = None
vmfb_path = (
Path(tmp_model_name + f"_{self.device}.vmfb")
if model_vmfb_name is None
else Path(model_vmfb_name)
)
shark_module = get_vmfb_from_path(
vmfb_path, self.device, mlir_dialect="tm_tensor"
)
if shark_module is not None:
return shark_module
mlir_path = Path(tmp_model_name + ".mlir")
print(
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
model = StableLMModel(self.get_src_model())
model_inputs = self.get_model_inputs()
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(tmp_model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
)
print("Saved vmfb at ", str(path))
return shark_module
def get_tokenizer(self):
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
return tok
def generate(self, prompt):
words_list = []
for i in range(self.max_num_tokens):
params = {
"new_text": prompt,
}
generated_token_op = self.generate_new_token(params)
detok = generated_token_op["detok"]
stop_generation = generated_token_op["stop_generation"]
if stop_generation:
break
print(detok, end="", flush=True) # this is for CLI and DEBUG
words_list.append(detok)
if detok == "":
break
prompt = prompt + detok
return words_list
def generate_new_token(self, params):
new_text = params["new_text"]
model_inputs = self.tokenizer(
[new_text],
padding="max_length",
max_length=self.max_sequence_len,
truncation=True,
return_tensors="pt",
)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
output = self.shark_model(
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
)
output = torch.from_numpy(output)
next_toks = torch.topk(output, 1)
stop_generation = False
if self.shouldStop(next_toks.indices):
stop_generation = True
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
detok = self.tokenizer.decode(
new_token,
skip_special_tokens=True,
)
ret_dict = {
"new_token": new_token,
"detok": detok,
"stop_generation": stop_generation,
}
return ret_dict
# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""

View File

@@ -0,0 +1,559 @@
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
)
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import torch
import torch_mlir
import os
class Vicuna(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
max_num_tokens=512,
device="cuda",
precision="fp32",
first_vicuna_mlir_path=Path("first_vicuna.mlir"),
second_vicuna_mlir_path=Path("second_vicuna.mlir"),
first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
load_mlir_from_shark_tank=True,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
self.first_vicuna_mlir_path = first_vicuna_mlir_path
self.second_vicuna_mlir_path = second_vicuna_mlir_path
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, use_fast=False
)
return tokenizer
def get_src_model(self):
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return vicuna_model
def compile_first_vicuna(self):
vmfb = get_vmfb_from_path(
self.first_vicuna_vmfb_path, self.device, "tm_tensor"
)
if vmfb is not None:
return vmfb
# Compilation path needs some more work before it is functional
print(
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.first_vicuna_mlir_path.exists():
with open(self.first_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/first_vicuna.mlir",
self.first_vicuna_mlir_path.absolute(),
single_file=True,
)
if self.first_vicuna_mlir_path.exists():
with open(self.first_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
if not mlir_generated:
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = self.tokenizer(
compilation_prompt
).input_ids
compilation_input_ids = torch.tensor(
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(self.hf_model_path)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del model
print(f"[DEBUG] generating torch mlir")
firstVicunaCompileInput = list(firstVicunaCompileInput)
firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like(
firstVicunaCompileInput[0], dynamic_axes=[1]
)
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
def remove_constant_dim(line):
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim)", line
)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
return line
module = str(module)
new_lines = []
print(f"[DEBUG] rewriting torch_mlir file")
for line in module.splitlines():
line = remove_constant_dim(line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append(
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
)
if (
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
in line
):
continue
new_lines.append(line)
module = "\n".join(new_lines)
print(f"[DEBUG] converting to bytecode")
del new_lines
module = module.encode("UTF-8")
module = BytesIO(module)
bytecode = module.read()
del module
print(f"[DEBUG] writing mlir to file")
f_ = open(self.first_vicuna_mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
path = shark_module.save_module(
self.first_vicuna_vmfb_path.parent.absolute(),
self.first_vicuna_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
print("Saved first vic vmfb at vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile_second_vicuna(self):
vmfb = get_vmfb_from_path(
self.second_vicuna_vmfb_path, self.device, "tm_tensor"
)
if vmfb is not None:
return vmfb
# Compilation path needs some more work before it is functional
print(
f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.second_vicuna_mlir_path.exists():
with open(self.second_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
self.second_vicuna_mlir_path.absolute(),
single_file=True,
)
if self.second_vicuna_mlir_path.exists():
with open(self.second_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
if not mlir_generated:
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(self.hf_model_path)
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[
i
] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
def remove_constant_dim(line):
if "c19_i64" in line:
line = re.sub("c19_i64", "dim_i64", line)
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim)", line
)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dimp1)", line
)
if " 20," in line:
line = re.sub(" 20,", " %dimp1,", line)
return line
module_str = str(module)
new_lines = []
for line in module_str.splitlines():
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128xf32>"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
)
continue
if "%c2 = arith.constant 2 : index" in line:
continue
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append(
"%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
)
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
)
continue
line = remove_constant_dim(line)
new_lines.append(line)
module_str = "\n".join(new_lines)
bytecode = module_str.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(self.second_vicuna_mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
path = shark_module.save_module(
self.second_vicuna_vmfb_path.parent.absolute(),
self.second_vicuna_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
print("Saved vmfb at ", str(path))
shark_module.load_module(self.second_vicuna_vmfb_path)
# self.shark_module = shark_module
return shark_module
def compile(self):
# Cannot load both the models in the memory at once
# due to memory constraints, hence on demand compilation
# is being used until the space is enough for both models
# Testing : DO NOT Download Vmfbs if not found. Modify later
# download vmfbs for A100
if (
not self.first_vicuna_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/vicuna/unsharded/first_vicuna.vmfb",
self.first_vicuna_vmfb_path.absolute(),
single_file=True,
)
else:
# get first vic
# TODO: Remove after testing to avoid memory overload
# fvic_shark_model = self.compile_first_vicuna()
pass
if (
not self.second_vicuna_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/vicuna/unsharded/second_vicuna.vmfb",
self.second_vicuna_vmfb_path.absolute(),
single_file=True,
)
else:
# get second vic
# TODO: Remove after testing to avoid memory overload
# svic_shark_model = self.compile_second_vicuna()
pass
# get first vic
# fvic_shark_model = self.compile_first_vicuna()
# get second vic
# svic_shark_model = self.compile_second_vicuna()
# return tuple of shark_modules
# return fvic_shark_model, svic_shark_model
return None
# return tuple of shark_modules once mem is supported
# return fvic_shark_model, svic_shark_model
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
import gc
res = []
res_tokens = []
params = {
"prompt": prompt,
"is_first": True,
"fv": self.compile_first_vicuna(),
}
generated_token_op = self.generate_new_token(params=params)
token = generated_token_op["token"]
logits = generated_token_op["logits"]
pkv = generated_token_op["pkv"]
detok = generated_token_op["detok"]
res.append(detok)
res_tokens.append(token)
if cli:
print(f"Assistant: {detok}", end=" ", flush=True)
# Clear First Vic from Memory (main and cuda)
del params
torch.cuda.empty_cache()
gc.collect()
sec_vic = self.compile_second_vicuna()
for _ in range(self.max_num_tokens - 2):
params = {
"prompt": None,
"is_first": False,
"logits": logits,
"pkv": pkv,
"sv": sec_vic,
}
generated_token_op = self.generate_new_token(params=params)
token = generated_token_op["token"]
logits = generated_token_op["logits"]
pkv = generated_token_op["pkv"]
detok = generated_token_op["detok"]
if token == 2:
break
res_tokens.append(token)
if detok == "<0x0A>":
res.append("\n")
if cli:
print("\n", end="", flush=True)
else:
res.append(detok)
if cli:
print(f"{detok}", end=" ", flush=True)
del sec_vic, pkv, logits
torch.cuda.empty_cache()
gc.collect()
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
res_tokens[i] = int(res_tokens[i][0])
res_str = self.tokenizer.decode(res_tokens)
# print(f"[DEBUG] final output : \n{res_str}")
return res_str
def generate_new_token(self, params, debug=False):
def forward_first(first_vic, prompt, cache_outputs=False):
input_ids = self.tokenizer(prompt).input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
firstVicunaInput = (input_ids,)
assert first_vic is not None
output_first_vicuna = first_vic("forward", firstVicunaInput)
output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:])
logits_first_vicuna = torch.tensor(output_first_vicuna[0])
if cache_outputs:
torch.save(
logits_first_vicuna, "logits_first_vicuna_tensor.pt"
)
torch.save(
output_first_vicuna_tensor, "output_first_vicuna_tensor.pt"
)
token = torch.argmax(
torch.tensor(logits_first_vicuna)[:, -1, :], dim=1
)
return token, logits_first_vicuna, output_first_vicuna_tensor
def forward_second(sec_vic, inputs=None, load_inputs=False):
if inputs is not None:
logits = inputs[0]
pkv = inputs[1:]
elif load_inputs:
pkv = torch.load("output_first_vicuna_tensor.pt")
pkv = tuple(torch.tensor(x) for x in pkv)
logits = torch.load("logits_first_vicuna_tensor.pt")
else:
print(
"Either inputs must be given, or load_inputs must be true"
)
return None
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
secondVicunaOutput = sec_vic("forward", secondVicunaInput)
new_pkv = secondVicunaOutput[1:]
new_logits = secondVicunaOutput[0]
new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1)
return new_token, new_logits, new_pkv
is_first = params["is_first"]
if is_first:
prompt = params["prompt"]
fv = params["fv"]
token, logits, pkv = forward_first(
fv, # self.shark_model[0],
prompt=prompt,
cache_outputs=False,
)
else:
_logits = params["logits"]
_pkv = params["pkv"]
inputs = (_logits,) + tuple(_pkv)
sv = params["sv"]
token, logits, pkv = forward_second(
sv, # self.shark_model[1],
inputs=inputs,
load_inputs=False,
)
detok = self.tokenizer.decode(token)
if debug:
print(
f"[DEBUG] is_first: {is_first} |"
f" token : {token} | detok : {detok}"
)
ret_dict = {
"token": token,
"logits": logits,
"pkv": pkv,
"detok": detok,
}
return ret_dict
def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.
pass

View File

@@ -0,0 +1,408 @@
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
CompiledFirstVicunaLayer,
CompiledSecondVicunaLayer,
ShardedVicunaModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from shark.shark_importer import import_with_fx
from io import BytesIO
from pathlib import Path
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from torch_mlir import TensorPlaceholder
import re
import torch
import torch_mlir
import os
class Vicuna(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
max_num_tokens=512,
device="cuda",
precision="fp32",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
self.precision = precision
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, use_fast=False
)
return tokenizer
def get_src_model(self):
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return vicuna_model
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def write_in_dynamic_inputs1(self, module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
if "dim_42 =" in line:
continue
if f"%c{dynamic_input_size}_i64 =" in line:
new_lines.append(
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
)
new_lines.append(
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
)
continue
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim_42)", line
)
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim_42\)",
"tensor.empty(%dim_42, %dim_42)",
line,
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def compile_vicuna_layer(
self,
vicuna_layer,
hidden_states,
attention_mask,
position_ids,
past_key_value0=None,
past_key_value1=None,
):
if past_key_value0 is None and past_key_value1 is None:
model_inputs = (hidden_states, attention_mask, position_ids)
else:
model_inputs = (
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
)
mlir_bytecode = import_with_fx(
vicuna_layer,
model_inputs,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
return mlir_bytecode
def compile_to_vmfb(self, inputs, layers, is_first=True):
mlirs, modules = [], []
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
if is_first:
mlir_path = Path(f"{idx}_0.mlir")
vmfb_path = Path(f"{idx}_0.vmfb")
else:
mlir_path = Path(f"{idx}_1.mlir")
vmfb_path = Path(f"{idx}_1.vmfb")
if vmfb_path.exists():
continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states_placeholder = TensorPlaceholder.like(
inputs[0], dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
inputs[1], dynamic_axes=[3]
)
position_ids_placeholder = TensorPlaceholder.like(
inputs[2], dynamic_axes=[1]
)
if not is_first:
pkv0_placeholder = TensorPlaceholder.like(
inputs[3], dynamic_axes=[2]
)
pkv1_placeholder = TensorPlaceholder.like(
inputs[4], dynamic_axes=[2]
)
print(f"Compiling layer {idx} mlir")
if is_first:
ts_g = self.compile_vicuna_layer(
layer, inputs[0], inputs[1], inputs[2]
)
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
else:
ts_g = self.compile_vicuna_layer(
layer,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4],
)
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
if is_first:
module = self.write_in_dynamic_inputs0(str(module), 137)
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
else:
module = self.write_in_dynamic_inputs1(str(module), 138)
if idx in [0, 5, 6, 7]:
module_str = module
module_str = module_str.splitlines()
new_lines = []
for line in module_str:
if len(line) < 1000:
new_lines.append(line)
else:
new_lines.append(line[:999])
module_str = "\n".join(new_lines)
f1_ = open(f"{idx}_1_test.mlir", "w+")
f1_.write(module_str)
f1_.close()
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
mlirs.append(bytecode)
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
if is_first:
vmfb_path = Path(f"{idx}_0.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_0",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
else:
vmfb_path = Path(f"{idx}_1.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_1",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
def get_sharded_model(self):
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
# please don't change it
SAMPLE_INPUT_LEN = 137
vicuna_model = self.get_src_model()
placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
)
placeholder_input1 = (
torch.zeros([1, 1, 4096]),
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
torch.zeros([1, 1], dtype=torch.int64),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
layers0 = [
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
]
_, modules0 = self.compile_to_vmfb(
placeholder_input0, layers0, is_first=True
)
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
layers1 = [
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
]
_, modules1 = self.compile_to_vmfb(
placeholder_input1, layers1, is_first=False
)
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
sharded_model = ShardedVicunaModel(
vicuna_model, shark_layers0, shark_layers1
)
return sharded_model
def compile(self):
return self.get_sharded_model()
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
tokens_generated = []
_past_key_values = None
_token = None
detoks_generated = []
for iteration in range(self.max_num_tokens):
params = {
"prompt": prompt,
"is_first": iteration == 0,
"token": _token,
"past_key_values": _past_key_values,
}
generated_token_op = self.generate_new_token(params=params)
_token = generated_token_op["token"]
_past_key_values = generated_token_op["past_key_values"]
_detok = generated_token_op["detok"]
if _token == 2:
break
detoks_generated.append(_detok)
tokens_generated.append(_token)
for i in range(len(tokens_generated)):
if type(tokens_generated[i]) != int:
tokens_generated[i] = int(tokens_generated[i][0])
result_output = self.tokenizer.decode(tokens_generated)
return result_output
def generate_new_token(self, params):
is_first = params["is_first"]
if is_first:
prompt = params["prompt"]
input_ids = self.tokenizer(prompt).input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
output = self.shark_model.forward(input_ids, is_first=is_first)
else:
token = params["token"]
past_key_values = params["past_key_values"]
input_ids = [token]
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
output = self.shark_model.forward(
input_ids, past_key_values=past_key_values, is_first=is_first
)
_logits = output["logits"]
_past_key_values = output["past_key_values"]
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
_detok = self.tokenizer.decode(_token)
ret_dict = {
"token": _token,
"detok": _detok,
"past_key_values": _past_key_values,
}
print(f" token : {_token} | detok : {_detok}")
return ret_dict
def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.
pass

View File

@@ -0,0 +1,25 @@
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from pathlib import Path
# expects a Path / str as arg
# returns None if path not found or SharkInference module
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
if not isinstance(vmfb_path, Path):
vmfb_path = Path(vmfb_path)
from shark.shark_inference import SharkInference
if not vmfb_path.exists():
return None
print("Loading vmfb from: ", vmfb_path)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
return shark_module

View File

@@ -10,7 +10,7 @@ Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=mhlo for tf models
# use iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb

View File

@@ -1,6 +1 @@
from apps.stable_diffusion.scripts.txt2img import txt2img_inf
from apps.stable_diffusion.scripts.img2img import img2img_inf
from apps.stable_diffusion.scripts.inpaint import inpaint_inf
from apps.stable_diffusion.scripts.outpaint import outpaint_inf
from apps.stable_diffusion.scripts.upscaler import upscaler_inf
from apps.stable_diffusion.scripts.train_lora_word import lora_train

View File

@@ -2,276 +2,22 @@ import sys
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
StencilPipeline,
resize_stencil,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# Both width and height should be > 384 and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
width, height = image.size
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 384:
n_size = 384
if width == min_size:
width = n_size
height = n_size / aspect_ratio
else:
height = n_size
width = n_size * aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
new_image = image.resize((n_width, n_height))
return new_image, n_width, n_height
# Exposed to UI.
def img2img_inf(
prompt: str,
negative_prompt: str,
init_image,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.strength = strength
args.scheduler = scheduler
args.img_path = "not none"
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB")
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
)
args.scheduler = "PNDM"
else:
sys.exit(
"Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
)
cpu_scheduling = not args.scheduler.startswith("Shark")
args.precision = precision
dtype = torch.float32 if precision == "fp32" else torch.half
new_config_obj = Config(
"img2img",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=use_stencil,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
if use_stencil is not None:
args.use_tuned = False
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
)
else:
global_obj.set_sd_obj(
Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"STRENGTH": strength}
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
batch_size,
height,
width,
steps,
strength,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, strength={args.strength}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -288,16 +34,11 @@ if __name__ == "__main__":
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, args.width, args.height = resize_stencil(image)
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
)
args.scheduler = "PNDM"
else:
sys.exit(
"Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
@@ -324,6 +65,7 @@ if __name__ == "__main__":
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
@@ -342,6 +84,7 @@ if __name__ == "__main__":
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
start_time = time.time()
@@ -377,3 +120,7 @@ if __name__ == "__main__":
extra_info = {"STRENGTH": args.strength}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -10,196 +11,10 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def inpaint_inf(
prompt: str,
negative_prompt: str,
image_dict,
height: int,
width: int,
inpaint_full_res: bool,
inpaint_full_res_padding: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.mask_path = "not none"
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"inpaint",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -229,6 +44,7 @@ if __name__ == "__main__":
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
@@ -236,10 +52,10 @@ if __name__ == "__main__":
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
@@ -282,3 +98,7 @@ if __name__ == "__main__":
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,19 @@
from apps.stable_diffusion.src import args
from apps.stable_diffusion.scripts import (
img2img,
txt2img,
# inpaint,
# outpaint,
)
if __name__ == "__main__":
if args.app == "txt2img":
txt2img.main()
elif args.app == "img2img":
img2img.main()
# elif args.app == "inpaint":
# inpaint.main()
# elif args.app == "outpaint":
# outpaint.main()
else:
print(f"args.app value is {args.app} but this isn't supported")

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
@@ -12,203 +13,7 @@ from apps.stable_diffusion.src import (
)
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def outpaint_inf(
prompt: str,
negative_prompt: str,
init_image,
pixels: int,
mask_blur: int,
directions: list,
noise_q: float,
color_variation: float,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"outpaint",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
left = True if "left" in directions else False
right = True if "right" in directions else False
top = True if "up" in directions else False
bottom = True if "down" in directions else False
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
init_image,
pixels,
mask_blur,
left,
right,
top,
bottom,
noise_q,
color_variation,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -243,6 +48,7 @@ if __name__ == "__main__":
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
@@ -307,3 +113,7 @@ if __name__ == "__main__":
}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -73,6 +73,7 @@ from apps.stable_diffusion.src import (
set_init_device_flags,
clear_all,
)
from apps.stable_diffusion.src.utils import update_lora_weight
# Setup the dataset
@@ -159,7 +160,19 @@ class LoraDataset(Dataset):
return example
schedulers = None
def torch_device(device):
device_tokens = device.split("=>")
if len(device_tokens) == 1:
device_str = device_tokens[0].strip()
else:
device_str = device_tokens[1].strip()
device_type_tokens = device_str.split("://")
if device_type_tokens[0] == "metal":
device_type_tokens[0] = "vulkan"
if len(device_type_tokens) > 1:
return device_type_tokens[0] + ":" + device_type_tokens[1]
else:
return device_type_tokens[0]
########## Setting up the model ##########
@@ -180,6 +193,7 @@ def lora_train(
max_length: int,
training_images_dir: str,
lora_save_dir: str,
use_lora: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -187,8 +201,6 @@ def lora_train(
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
print(
"Note LoRA training is not compatible with the latest torch-mlir branch"
)
@@ -227,7 +239,8 @@ def lora_train(
args.max_length = max_length
args.height = height
args.width = width
args.device = device
args.device = torch_device(device)
args.use_lora = use_lora
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(
@@ -252,29 +265,33 @@ def lora_train(
unet.to(args.device)
text_encoder.to(args.device)
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[
block_id
]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if use_lora != "":
update_lora_weight(unet, args.use_lora, "unet")
else:
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[
block_id
]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
unet.set_attn_processor(lora_attn_procs)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
class VaeModel(torch.nn.Module):
@@ -671,4 +688,5 @@ if __name__ == "__main__":
args.max_length,
args.training_images_dir,
args.lora_save_dir,
args.use_lora,
)

View File

@@ -0,0 +1,126 @@
import os
from pathlib import Path
from shark_tuner.codegen_tuner import SharkCodegenTuner
from shark_tuner.iree_utils import (
dump_dispatches,
create_context,
export_module_to_mlir_file,
)
from shark_tuner.model_annotation import model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.utils import set_init_device_flags
from apps.stable_diffusion.src.utils.sd_annotation import (
get_device_args,
load_winograd_configs,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
max_len=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,
)
if args.annotation_model == "unet":
mlir_module = sd_model.unet()
model_name = sd_model.model_name["unet"]
elif args.annotation_model == "vae":
mlir_module = sd_model.vae()
model_name = sd_model.model_name["vae"]
else:
raise ValueError(
f"{args.annotation_model} is not supported for tuning."
)
return mlir_module, model_name
def main():
args.use_tuned = False
set_init_device_flags()
mlir_module, model_name = load_mlir_module()
# Get device and device specific arguments
device, device_spec_args = get_device_args()
device_spec = ""
vulkan_target_triple = ""
if device_spec_args:
device_spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
vulkan_target_triple = device_spec
device_spec = device_spec.split("-")[0]
# Add winograd annotation for vulkan device
use_winograd = (
True
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
else False
)
winograd_config = (
load_winograd_configs()
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
else ""
)
with create_context() as ctx:
input_module = model_annotation(
ctx,
input_contents=mlir_module,
config_path=winograd_config,
search_op="conv",
winograd=use_winograd,
)
# Dump model dispatches
generates_dir = Path.home() / "tmp"
if not os.path.exists(generates_dir):
os.makedirs(generates_dir)
dump_mlir = generates_dir / "temp.mlir"
dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches"
export_module_to_mlir_file(input_module, dump_mlir)
dump_dispatches(
dump_mlir,
device,
dispatch_dir,
vulkan_target_triple,
use_winograd=use_winograd,
)
# Tune each dispatch
dtype = "f16" if args.precision == "fp16" else "f32"
config_filename = f"{model_name}_{device_spec}_configs.json"
for f_path in os.listdir(dispatch_dir):
if not f_path.endswith(".mlir"):
continue
model_dir = os.path.join(dispatch_dir, f_path)
tuner = SharkCodegenTuner(
model_dir,
device,
"random",
args.num_iters,
args.tuned_config_dir,
dtype,
args.search_op,
batch_size=1,
config_filename=config_filename,
use_dispatch=True,
vulkan_target_triple=vulkan_target_triple,
)
tuner.tune()
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,5 @@
import torch
import transformers
import time
from apps.stable_diffusion.src import (
args,
@@ -11,186 +12,7 @@ from apps.stable_diffusion.src import (
)
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += (
f"\nsteps={steps}, guidance_scale={guidance_scale}, seed={seeds}"
)
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
# text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -200,7 +22,6 @@ if __name__ == "__main__":
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
use_lora = args.use_lora
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
@@ -216,7 +37,9 @@ if __name__ == "__main__":
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
@@ -256,3 +79,7 @@ if __name__ == "__main__":
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
@@ -12,187 +13,6 @@ from apps.stable_diffusion.src import (
)
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def upscaler_inf(
prompt: str,
negative_prompt: str,
init_image,
height: int,
width: int,
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB").resize((height, width))
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
args.height = 128
args.width = 128
new_config_obj = Config(
"upscaler",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
args.height,
args.width,
device,
use_lora=None,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_size = batch_size
args.max_length = max_length
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.get_sd_obj().low_res_scheduler = schedulers["DDPM"]
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"NOISE LEVEL": noise_level}
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
low_res_img = image
high_res_img = Image.new("RGB", (height * 4, width * 4))
for i in range(0, width, 128):
for j in range(0, height, 128):
box = (j, i, j + 128, i + 128)
upscaled_image = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
low_res_img.crop(box),
batch_size,
args.height,
args.width,
steps,
noise_level,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
save_output_img(high_res_img, img_seed, extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
if args.clear_all:
clear_all()
@@ -232,7 +52,9 @@ if __name__ == "__main__":
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ddpm_scheduler=schedulers["DDPM"],
ondemand=args.ondemand,
)
start_time = time.time()

View File

@@ -21,12 +21,17 @@ datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += collect_data_files('sentencepiece')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
@@ -42,6 +47,7 @@ block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['web/index.py'],

View File

@@ -22,8 +22,10 @@ datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
@@ -40,9 +42,10 @@ block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['scripts/txt2img.py'],
['scripts/main.py'],
pathex=['.'],
binaries=binaries,
datas=datas,

View File

@@ -5,6 +5,7 @@ from apps.stable_diffusion.src.utils import (
get_available_devices,
clear_all,
save_output_img,
resize_stencil,
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,

View File

@@ -1,9 +1,11 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from collections import defaultdict
from pathlib import Path
import torch
import safetensors.torch
import traceback
import subprocess
import sys
import os
from apps.stable_diffusion.src.utils import (
@@ -11,13 +13,14 @@ from apps.stable_diffusion.src.utils import (
get_opt_flags,
base_models,
args,
fetch_or_delete_vmfbs,
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
get_stencil_model_id,
update_lora_weight,
)
@@ -54,29 +57,9 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
return new_shape
# Get the input info for various models i.e. "unet", "clip", "vae", "vae_encode".
def get_input_info(model_info, max_len, width, height, batch_size):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
for inp in model_info[k]:
shape = model_info[k][inp]["shape"]
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, max_len, width, height, batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
def check_compilation(model, model_name):
if not model:
raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
class SharkifyStableDiffusionModel:
@@ -99,7 +82,9 @@ class SharkifyStableDiffusionModel:
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = ""
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -107,11 +92,21 @@ class SharkifyStableDiffusionModel:
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
if not os.path.isfile(weights_path):
subprocess.run(["wget", custom_weights, "-O", weights_path])
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
@@ -145,18 +140,32 @@ class SharkifyStableDiffusionModel:
self.use_lora = use_lora
print(self.model_name)
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
self.sharktank_dir = sharktank_dir
self.generate_vmfb = generate_vmfb
def get_extended_name_for_all_model(self, mask_to_fetch):
self.inputs = dict()
self.model_to_run = ""
if self.custom_weights != "":
self.model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
else:
self.model_to_run = args.hf_model_id
self.custom_vae = self.process_custom_vae()
self.base_model_id = fetch_and_update_base_model_id(self.model_to_run)
if self.base_model_id != "" and args.ckpt_loc != "":
args.hf_model_id = self.base_model_id
self.return_mlir = return_mlir
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
index = 0
for model in sub_model_list:
if mask_to_fetch[index] == False:
index += 1
continue
sub_model = model
model_config = self.model_name
if "vae" == model:
@@ -164,6 +173,8 @@ class SharkifyStableDiffusionModel:
model_config = model_config + get_path_stem(self.custom_vae)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
@@ -176,6 +187,29 @@ class SharkifyStableDiffusionModel:
if not (height % 8 == 0 and height >= 128):
sys.exit("height should be greater than 128 and multiple of 8")
# Get the input info for a model i.e. "unet", "clip", "vae", etc.
def get_input_info_for(self, model_info):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = []
for inp in model_info:
shape = model_info[inp]["shape"]
dtype = dtype_config[model_info[inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, self.max_len, self.width, self.height, self.batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
@@ -192,16 +226,20 @@ class SharkifyStableDiffusionModel:
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = True if self.precision == "fp16" else False
shark_vae_encode = compile_through_fx(
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
shark_vae_encode, vae_encode_mlir = compile_through_fx(
vae_encode,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=self.model_name["vae_encode"],
extended_model_name=self.model_name["vae_encode"],
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae_encode",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae_encode
return shark_vae_encode, vae_encode_mlir
def get_vae(self):
class VaeModel(torch.nn.Module):
@@ -241,53 +279,31 @@ class SharkifyStableDiffusionModel:
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae = compile_through_fx(
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=self.model_name["vae"],
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae
def get_vae_upscaler(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
return x
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
shark_vae = compile_through_fx(
vae,
inputs,
use_tuned=self.use_tuned,
model_name=self.model_name["vae"],
extra_args=get_opt_flags("vae", precision="fp32"),
)
return shark_vae
return shark_vae, vae_mlir
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -295,6 +311,8 @@ class SharkifyStableDiffusionModel:
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
@@ -323,18 +341,22 @@ class SharkifyStableDiffusionModel:
unet = ControlledUnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_unet"])
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
shark_controlled_unet = compile_through_fx(
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
model_name=self.model_name["stencil_unet"],
extended_model_name=self.model_name["stencil_unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self):
class StencilControlNetModel(torch.nn.Module):
@@ -378,16 +400,20 @@ class SharkifyStableDiffusionModel:
inputs = tuple(self.inputs["stencil_adaptor"])
input_mask = [True, True, True, True]
shark_cnet = compile_through_fx(
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
model_name=self.model_name["stencil_adaptor"],
extended_model_name=self.model_name["stencil_adaptor"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_adaptor",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_cnet
return shark_cnet, cnet_mlir
def get_unet(self):
class UnetModel(torch.nn.Module):
@@ -399,7 +425,7 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
self.unet.load_attn_procs(use_lora)
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
@@ -433,10 +459,10 @@ class SharkifyStableDiffusionModel:
save_dir,
exist_ok=True,
)
shark_unet = compile_through_fx(
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
model_name=self.model_name["unet"],
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -444,8 +470,12 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet
return shark_unet, unet_mlir
def get_unet_upscaler(self):
class UnetModel(torch.nn.Module):
@@ -473,26 +503,32 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet = compile_through_fx(
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
model_name=self.model_name["unet"],
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet
return shark_unet, unet_mlir
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
def forward(self, input):
return self.text_encoder(input)[0]
@@ -504,16 +540,20 @@ class SharkifyStableDiffusionModel:
save_dir,
exist_ok=True,
)
shark_clip = compile_through_fx(
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name=self.model_name["clip"],
extended_model_name=self.model_name["clip"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip
return shark_clip, clip_mlir
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
@@ -532,134 +572,115 @@ class SharkifyStableDiffusionModel:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
# configiration.
def compile_all(self, base_model_id, need_vae_encode, need_stencil):
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
if self.is_upscaler:
return self.get_clip(), self.get_unet_upscaler(), self.get_vae_upscaler()
compiled_controlnet = None
compiled_controlled_unet = None
compiled_unet = None
if need_stencil:
compiled_controlnet = self.get_control_net()
compiled_controlled_unet = self.get_controlled_unet()
else:
compiled_unet = self.get_unet()
if self.custom_vae != "":
print("Plugging in custom Vae")
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
if need_stencil:
return compiled_clip, compiled_controlled_unet, compiled_vae, compiled_controlnet
if need_vae_encode:
compiled_vae_encode = self.get_vae_encode()
return compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode
return compiled_clip, compiled_unet, compiled_vae
def __call__(self):
# Step 1:
# -- Fetch all vmfbs for the model, if present, else delete the lot.
need_vae_encode, need_stencil = False, False
if not self.is_upscaler and args.img_path is not None:
if self.use_stencil is not None:
need_stencil = True
else:
need_vae_encode = True
# `mask_to_fetch` prepares a mask to pick a combination out of :-
# ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
mask_to_fetch = [True, True, False, True, False, False]
if need_vae_encode:
mask_to_fetch = [True, True, False, True, True, False]
elif need_stencil:
mask_to_fetch = [True, False, True, True, False, True]
self.model_name = self.get_extended_name_for_all_model(mask_to_fetch)
vmfbs = fetch_or_delete_vmfbs(self.model_name, self.precision)
if vmfbs[0]:
# -- If all vmfbs are indeed present, we also try and fetch the base
# model configuration for running SD with custom checkpoints.
if self.custom_weights != "":
args.hf_model_id = fetch_and_update_base_model_id(self.custom_weights)
if args.hf_model_id == "":
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
return vmfbs
# Step 2:
# -- If vmfbs weren't found, we try to see if the base model configuration
# for the required SD run is known to us and bypass the retry mechanism.
model_to_run = ""
if self.custom_weights != "":
model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
else:
model_to_run = args.hf_model_id
# For custom Vae user can provide either the repo-id or a checkpoint file,
# and for a checkpoint file we'd need to process it via Diffusers' script.
self.custom_vae = self.process_custom_vae()
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
if base_model_fetched != "":
print("Compiling all the models with the fetched base model configuration.")
if args.ckpt_loc != "":
args.hf_model_id = base_model_fetched
return self.compile_all(base_model_fetched, need_vae_encode, need_stencil)
# Step 3:
# -- This is the retry mechanism where the base model's configuration is not
# known to us and figure that out by trial and error.
print("Inferring base model configuration.")
for model_id in base_models:
try:
if need_vae_encode:
compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, need_vae_encode, need_stencil)
elif need_stencil:
compiled_clip, compiled_unet, compiled_vae, compiled_controlnet = self.compile_all(model_id, need_vae_encode, need_stencil)
else:
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode, need_stencil)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
if need_vae_encode:
return (
compiled_clip,
compiled_unet,
compiled_vae,
compiled_vae_encode,
)
if need_stencil:
return (
compiled_clip,
compiled_unet,
compiled_vae,
compiled_controlnet,
)
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
)
vae_checkpoint = convert_original_vae(vae_checkpoint)
finally:
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
def compile_unet_variants(self, model):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
return get_unet()
else:
return self.get_unet()
else:
return self.get_controlled_unet()
def vae_encode(self):
try:
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
if self.return_mlir:
return vae_encode_mlir
return compiled_vae_encode
except Exception as e:
sys.exit(e)
def clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(base_models["clip"])
compiled_clip, clip_mlir = self.get_clip()
check_compilation(compiled_clip, "Clip")
if self.return_mlir:
return clip_mlir
return compiled_clip
except Exception as e:
sys.exit(e)
def unet(self):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(self.model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
break
check_compilation(compiled_unet, "Unet")
if self.return_mlir:
return unet_mlir
return compiled_unet
except Exception as e:
sys.exit(e)
def vae(self):
try:
vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
self.inputs["vae"] = self.get_input_info_for(vae_input)
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")
if self.return_mlir:
return vae_mlir
return compiled_vae
except Exception as e:
sys.exit(e)
def controlnet(self):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
check_compilation(compiled_stencil_adaptor, "Stencil")
if self.return_mlir:
return controlnet_mlir
return compiled_stencil_adaptor
except Exception as e:
sys.exit(e)

View File

@@ -20,6 +20,15 @@ hf_model_variant_map = {
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
bucket_key = "gs://shark_tank/prashant_nod"
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 and args.width != 512 and args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
@@ -41,6 +50,12 @@ def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
# TODO: Get the quantize model from model_db.json
if args.use_quantize == "int8":
bk, mk, flags = get_quantize_model()
return get_shark_model(bk, mk, flags)
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"

View File

@@ -20,16 +20,15 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class Image2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -40,9 +39,30 @@ class Image2ImagePipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_image_latents(
self,
@@ -89,9 +109,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
return latents, timesteps
def encode_image(self, input_image):
self.load_vae_encode()
vae_encode_start = time.time()
latents = self.vae_encode("forward", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
if self.ondemand:
self.unload_vae_encode()
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
return latents
@@ -131,8 +154,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -161,6 +186,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -168,5 +194,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -19,16 +19,15 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class InpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -39,9 +38,30 @@ class InpaintPipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
@@ -305,9 +325,12 @@ class InpaintPipeline(StableDiffusionPipeline):
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -383,8 +406,10 @@ class InpaintPipeline(StableDiffusionPipeline):
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -428,6 +453,7 @@ class InpaintPipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -435,6 +461,8 @@ class InpaintPipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
if inpaint_full_res:
output_image = self.apply_overlay(

View File

@@ -20,16 +20,15 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
StableDiffusionPipeline,
)
import math
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class OutpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -40,9 +39,30 @@ class OutpaintPipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
@@ -123,9 +143,12 @@ class OutpaintPipeline(StableDiffusionPipeline):
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -384,8 +407,10 @@ class OutpaintPipeline(StableDiffusionPipeline):
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -506,6 +531,7 @@ class OutpaintPipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],

View File

@@ -20,16 +20,16 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class StencilPipeline(StableDiffusionPipeline):
def __init__(
self,
controlnet: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -39,9 +39,22 @@ class StencilPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.controlnet = controlnet
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
def load_controlnet(self):
if self.controlnet is not None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def prepare_latents(
self,
@@ -68,6 +81,113 @@ class StencilPipeline(StableDiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_stencil_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
self.load_controlnet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.ondemand:
self.unload_unet()
self.unload_controlnet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
@@ -108,8 +228,10 @@ class StencilPipeline(StableDiffusionPipeline):
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -134,11 +256,11 @@ class StencilPipeline(StableDiffusionPipeline):
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
controlnet=self.controlnet,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -146,5 +268,7 @@ class StencilPipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -1,5 +1,4 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
@@ -19,15 +18,12 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -39,8 +35,12 @@ class Text2ImagePipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
def prepare_latents(
self,
@@ -110,8 +110,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -128,12 +130,15 @@ class Text2ImagePipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
self.load_vae()
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -20,6 +20,8 @@ from diffusers import (
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_IDLE,
SD_STATE_CANCEL,
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import (
@@ -27,6 +29,7 @@ from apps.stable_diffusion.src.utils import (
end_profiling,
)
from PIL import Image
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def preprocess(image):
@@ -55,10 +58,6 @@ def preprocess(image):
class UpscalerPipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -80,9 +79,14 @@ class UpscalerPipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.low_res_scheduler = low_res_scheduler
self.status = SD_STATE_IDLE
def prepare_extra_step_kwargs(self, generator, eta):
accepts_eta = "eta" in set(
@@ -163,6 +167,8 @@ class UpscalerPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.status = SD_STATE_IDLE
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
@@ -208,6 +214,11 @@ class UpscalerPipeline(StableDiffusionPipeline):
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -251,8 +262,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# 4. Preprocess image
image = preprocess(image).to(dtype)
@@ -299,6 +312,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -306,5 +320,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -20,7 +20,6 @@ from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
get_vae,
get_clip,
get_unet,
@@ -30,15 +29,15 @@ from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
import sys
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
class StableDiffusionPipeline:
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -50,14 +49,88 @@ class StableDiffusionPipeline:
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.vae = None
self.text_encoder = None
self.unet = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
try:
self.tokenizer = get_tokenizer()
except:
self.load_unet()
self.unload_unet()
self.tokenizer = get_tokenizer()
def load_clip(self):
if self.text_encoder is not None:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
try:
self.text_encoder = get_clip()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.text_encoder = self.sd_model.clip()
def unload_clip(self):
del self.text_encoder
self.text_encoder = None
def load_unet(self):
if self.unet is not None:
return
if self.import_mlir or self.use_lora:
self.unet = self.sd_model.unet()
else:
try:
self.unet = get_unet()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.unet = self.sd_model.unet()
def unload_unet(self):
del self.unet
self.unet = None
def load_vae(self):
if self.vae is not None:
return
if self.import_mlir or self.use_lora:
self.vae = self.sd_model.vae()
else:
try:
self.vae = get_vae()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.vae = self.sd_model.vae()
def unload_vae(self):
del self.vae
self.vae = None
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -77,12 +150,14 @@ class StableDiffusionPipeline:
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
self.load_clip()
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
@@ -111,109 +186,6 @@ class StableDiffusionPipeline:
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def produce_stencil_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
controlnet=None,
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def produce_img_latents(
self,
latents,
@@ -226,10 +198,12 @@ class StableDiffusionPipeline:
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
@@ -275,6 +249,11 @@ class StableDiffusionPipeline:
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -308,115 +287,556 @@ class StableDiffusionPipeline:
width: int,
use_base_vae: bool,
use_tuned: bool,
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
if (
not import_mlir
and not use_lora
and cls.__name__ == "StencilPipeline"
):
sys.exit("StencilPipeline not supported with SharkTank currently.")
is_inpaint = cls.__name__ in [
"InpaintPipeline",
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
if import_mlir or use_lora:
if not import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
)
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
)
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
clip, unet, vae, vae_encode = mlir_import()
return cls(
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ in ["StencilPipeline"]:
clip, unet, vae, controlnet = mlir_import()
return cls(
controlnet, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ in ["UpscalerPipeline"]:
clip, unet, vae = mlir_import()
return cls(
vae, clip, get_tokenizer(), unet, scheduler, ddpm_scheduler
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
try:
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
return cls(
get_vae_encode(),
get_vae(),
get_clip(),
get_tokenizer(),
get_unet(),
scheduler,
)
if cls.__name__ == "StencilPipeline":
import sys
sd_model = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
use_quantize=use_quantize,
)
sys.exit(
"StencilPipeline not supported with SharkTank currently."
)
if cls.__name__ in ["UpscalerPipeline"]:
return cls(
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
scheduler,
ddpm_scheduler,
sd_model,
import_mlir,
use_lora,
ondemand,
)
except:
print("download pipeline failed, falling back to import_mlir")
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
# #####################################################
# Implements text embeddings with weights from prompts
# https://huggingface.co/AlanB/lpw_stable_diffusion_mod
# #####################################################
def encode_prompts_weight(
self,
prompt,
negative_prompt,
model_max_length,
do_classifier_free_guidance=True,
max_embeddings_multiples=1,
num_images_per_prompt=1,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
"""
# SHARK: Save model_max_length, load the clip and init inference time
self.model_max_length = model_max_length
self.load_clip()
clip_inf_start = time.time()
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
clip, unet, vae, vae_encode = mlir_import()
return cls(
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ == "StencilPipeline":
clip, unet, vae, controlnet = mlir_import()
return cls(
controlnet, vae, clip, get_tokenizer(), unet, scheduler
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt
if do_classifier_free_guidance
else None,
max_embeddings_multiples=max_embeddings_multiples,
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()
from typing import List, Optional, Union
import re
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(
pipe: StableDiffusionPipeline, prompt: List[str], max_length: int
):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens,
weights,
max_length,
bos,
eos,
no_boseos_middle=True,
chunk_length=77,
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
max_length
if no_boseos_middle
else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
tokens[i] = (
[bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
)
if no_boseos_middle:
weights[i] = (
[1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
)
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j
* (chunk_length - 2) : min(
len(weights[i]), (j + 1) * (chunk_length - 2)
)
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
# text_embedding = pipe.text_encoder(text_input_chunk)[0]
# SHARK: deplicate the text_input as Shark runner expects tokens and neg tokens
formatted_text_input_chunk = torch.cat(
[text_input_chunk, text_input_chunk]
)
text_embedding = pipe.text_encoder(
"forward", (formatted_text_input_chunk,)
)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
else:
# SHARK: deplicate the text_input as Shark runner expects tokens and neg tokens
# Convert the result to tensor
# text_embeddings = pipe.text_encoder(text_input)[0]
formatted_text_input = torch.cat([text_input, text_input])
text_embeddings = pipe.text_encoder(
"forward", (formatted_text_input,)
)[0]
text_embeddings = torch.from_numpy(text_embeddings)[None, :]
return text_embeddings
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
Args:
pipe (`StableDiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`):
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
ending token in each of the chunk in the middle.
skip_parsing (`bool`, *optional*, defaults to `False`):
Skip the parsing of brackets.
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True.
"""
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(
max_length, max([len(token) for token in uncond_tokens])
)
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(
uncond_tokens, dtype=torch.long, device="cpu"
)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(
prompt_weights, dtype=torch.float, device="cpu"
)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(
uncond_weights, dtype=torch.float, device="cpu"
)
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None

View File

@@ -40,6 +40,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
device = args.device.split(":", 1)[0].strip()
model_input = {
"euler": {
@@ -89,19 +90,19 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
def _import(self):
scaling_model = ScalingModel()
self.scaling_model = compile_through_fx(
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
self.step_model = compile_through_fx(
self.step_model, _ = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)

View File

@@ -24,12 +24,17 @@ from apps.stable_diffusion.src.utils.utils import (
get_available_devices,
get_opt_flags,
preprocessCKPT,
fetch_or_delete_vmfbs,
convert_original_vae,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
get_path_stem,
get_extended_name,
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
clear_all,
save_output_img,
get_generation_text_info,
update_lora_weight,
resize_stencil,
)

View File

@@ -1,6 +1,157 @@
{
"stabilityai/stable-diffusion-x4-upscaler": {
"unet": {
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"vae_upscaler": {
"latents" : {
"shape" : [
"1*batch_size",4,"8*height","8*width"
],
"dtype":"f32"
}
}
},
"unet": {
"stabilityai/stable-diffusion-2-1": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-2-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"runwayml/stable-diffusion-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-x4-upscaler": {
"latents": {
"shape": [
"2*batch_size",
@@ -28,141 +179,39 @@
"shape": [2],
"dtype": "i64"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"8*height","8*width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"stabilityai/stable-diffusion-2-1": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
"stencil_adaptor": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
"CompVis/stable-diffusion-v1-4": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stencil_adaptor": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
"stencil_unet": {
"stencil_unet": {
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
@@ -242,143 +291,6 @@
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"stabilityai/stable-diffusion-2-inpainting": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"runwayml/stable-diffusion-inpainting": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
}
}
}

View File

@@ -1,85 +1,19 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/sd_untuned",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
"stablediffusion/untuned":"gs://shark_tank/nightly"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/unet/fp32/length_64/untuned":"unet_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/vae/fp32/length_64/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"anythingv3/v1_4/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v1_4/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v1_4/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
"anythingv3/v1_4/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v1_4/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v1_4/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v1_4/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
"anythingv3/v1_4/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v1_4/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v1_4/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v1_4/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v1_4/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v1_4/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v1_4/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
"analogdiffusion/v1_4/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v1_4/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v1_4/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v1_4/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
"analogdiffusion/v1_4/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v1_4/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v1_4/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v1_4/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v1_4/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v1_4/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v1_4/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v1_4/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v1_4/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v1_4/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v1_4/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v1_4/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v1_4/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v1_4/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v1_4/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v1_4/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v1_4/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v1_4/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan"
}
]

View File

@@ -45,12 +45,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}

View File

@@ -70,24 +70,27 @@ def load_winograd_configs():
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
if not os.path.exists(WORKDIR):
os.mkdir(WORKDIR)
winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
def load_lower_configs():
def load_lower_configs(base_model_id=None):
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
if not base_model_id:
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
@@ -114,7 +117,16 @@ def load_lower_configs():
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
and args.width == 768
):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
@@ -212,7 +224,7 @@ def annotate_with_lower_configs(
return bytecode
def sd_model_annotation(mlir_model, model_name):
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
@@ -220,19 +232,22 @@ def sd_model_annotation(mlir_model, model_name):
winograd_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
tuned_model = mlir_model
else:
use_winograd = False
lowering_config_dir = load_lower_configs()
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)

View File

@@ -22,6 +22,12 @@ p = argparse.ArgumentParser(
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
)
p.add_argument(
"-p",
"--prompts",
@@ -340,6 +346,21 @@ p.add_argument(
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
)
p.add_argument(
"--ondemand",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
@@ -360,7 +381,7 @@ p.add_argument(
p.add_argument(
"--vulkan_large_heap_block_size",
default="4147483648",
default="2073741824",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
@@ -472,7 +493,13 @@ p.add_argument(
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="one of: [api, app, web]",
)
p.add_argument(
"--share",
@@ -488,6 +515,28 @@ p.add_argument(
help="flag for setting server port",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for enabling rest API",
)
p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the output gallery tab, and avoid exposing images under --output_dir in the UI",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether the output gallery tab in the UI should follow symlinks when listing subdirectorys under --output_dir",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################
@@ -512,6 +561,31 @@ p.add_argument(
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file",
)
##############################################################################
### SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all",
)
args, unknown = p.parse_known_args()
if args.import_debug:

View File

@@ -126,14 +126,14 @@ def controlnet_hint_conversion(
stencil_to_model_id_map = {
"canny": "lllyasviel/sd-controlnet-canny",
"depth": "lllyasviel/sd-controlnet-depth",
"canny": "lllyasviel/control_v11p_sd15_canny",
"depth": "lllyasviel/control_v11p_sd15_depth",
"hed": "lllyasviel/sd-controlnet-hed",
"mlsd": "lllyasviel/sd-controlnet-mlsd",
"normal": "lllyasviel/sd-controlnet-normal",
"openpose": "lllyasviel/sd-controlnet-openpose",
"scribble": "lllyasviel/sd-controlnet-scribble",
"seg": "lllyasviel/sd-controlnet-seg",
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
"normal": "lllyasviel/control_v11p_sd15_normalbae",
"openpose": "lllyasviel/control_v11p_sd15_openpose",
"scribble": "lllyasviel/control_v11p_sd15_scribble",
"seg": "lllyasviel/control_v11p_sd15_seg",
}

View File

@@ -3,12 +3,15 @@ import gc
import json
import re
from PIL import PngImagePlugin
from PIL import Image
from datetime import datetime as dt
from csv import DictWriter
from pathlib import Path
import numpy as np
from random import randint
import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
@@ -21,8 +24,13 @@ from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
download_from_original_stable_diffusion_ckpt,
create_vae_diffusers_config,
convert_ldm_vae_checkpoint,
)
import requests
from io import BytesIO
from omegaconf import OmegaConf
def get_extended_name(model_name):
@@ -36,6 +44,15 @@ def get_vmfb_path_name(model_name):
return vmfb_path
def _load_vmfb(shark_module, vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
vmfb_path = get_vmfb_path_name(model_name)
@@ -66,7 +83,6 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
from shark.shark_downloader import download_model
if "cuda" in args.device:
@@ -78,7 +94,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
mlir_model, device=args.device, mlir_dialect="tm_tensor"
)
return _compile_module(shark_module, model_name, extra_args)
@@ -87,7 +103,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
def compile_through_fx(
model,
inputs,
model_name,
extended_model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
@@ -95,7 +111,20 @@ def compile_through_fx(
debug=False,
generate_vmfb=True,
extra_args=[],
base_model_id=None,
model_name=None,
precision=None,
return_mlir=False,
):
if not return_mlir and model_name is not None:
vmfb_path = get_vmfb_path_name(extended_model_name)
if os.path.isfile(vmfb_path):
shark_module = SharkInference(mlir_module=None, device=args.device)
return (
_load_vmfb(shark_module, vmfb_path, model_name, precision),
None,
)
from shark.parser import shark_args
if "cuda" in args.device:
@@ -110,29 +139,28 @@ def compile_through_fx(
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=model_name,
model_name=extended_model_name,
save_dir=save_dir,
)
if use_tuned:
if "vae" in model_name.split("_")[0]:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
mlir_module = sd_model_annotation(mlir_module, model_name)
if "unet" in model_name.split("_")[0]:
args.annotation_model = "unet"
mlir_module = sd_model_annotation(
mlir_module, extended_model_name, base_model_id
)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
mlir_dialect="tm_tensor",
)
if generate_vmfb:
shark_module = SharkInference(
return (
_compile_module(shark_module, extended_model_name, extra_args),
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
del mlir_module
gc.collect()
return _compile_module(shark_module, model_name, extra_args)
del mlir_module
gc.collect()
@@ -264,8 +292,9 @@ def set_init_device_flags():
if (
args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
@@ -299,8 +328,34 @@ def set_init_device_flags():
]:
args.use_tuned = False
elif (
args.height == 768
and args.width == 768
and (
base_model_id
not in [
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
]
or "rdna" not in args.iree_vulkan_target_triple
)
):
args.use_tuned = False
elif "rdna2" in args.iree_vulkan_target_triple and (
base_model_id
not in [
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]
):
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
print(
f"Using tuned models for {base_model_id}(fp16) on device {args.device}."
)
else:
print("Tuned models are currently not supported for this setting.")
@@ -368,7 +423,7 @@ def get_available_devices():
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
available_devices.append("device => cpu")
return available_devices
@@ -425,7 +480,7 @@ def get_path_stem(path):
def get_path_to_diffusers_checkpoint(custom_weights):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
diffusers_directory_name = os.path.join("diffusers", path.stem)
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
@@ -454,7 +509,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
num_in_channels = 9 if is_inpaint else 4
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
@@ -464,44 +519,129 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
print("Loading complete")
def load_vmfb(vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module = SharkInference(mlir_module=None, device=args.device)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
def convert_original_vae(vae_checkpoint):
vae_state_dict = {}
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
vae_state_dict, vae_config
)
return converted_vae_checkpoint
# This utility returns vmfbs of Clip, Unet, Vae and Vae_encode, in case all of them
# are present; deletes them otherwise.
def fetch_or_delete_vmfbs(extended_model_name, precision="fp32"):
vmfb_path = [
get_vmfb_path_name(extended_model_name[model])
for model in extended_model_name
]
number_of_vmfbs = len(vmfb_path)
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = True
compiled_models = [None] * number_of_vmfbs
for i in range(number_of_vmfbs):
all_vmfb_present = all_vmfb_present and vmfb_present[i]
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(number_of_vmfbs):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
def processLoRA(model, use_lora, splitting_prefix):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
model_name = [model for model in extended_model_name.keys()]
for i in range(number_of_vmfbs):
compiled_models[i] = load_vmfb(
vmfb_path[i], model_name[i], precision
state_dict = torch.load(use_lora)
alpha = 0.75
visited = []
# directly update weight in model
process_unet = "te" not in splitting_prefix
for key in state_dict:
if ".alpha" in key or key in visited:
continue
curr_layer = model
if ("text" not in key and process_unet) or (
"text" in key and not process_unet
):
layer_infos = (
key.split(".")[0].split(splitting_prefix)[-1].split("_")
)
return compiled_models
else:
continue
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = (
state_dict[pair_keys[0]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
weight_down = (
state_dict[pair_keys[1]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
curr_layer.weight.data += alpha * torch.mm(
weight_up, weight_down
).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
return model
def update_lora_weight_for_unet(unet, use_lora):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume if it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
return unet
main_file_name = get_path_stem(use_lora)
if ".bin" in use_lora:
main_file_name += ".bin"
elif ".safetensors" in use_lora:
main_file_name += ".safetensors"
elif ".pt" in use_lora:
main_file_name += ".pt"
else:
sys.exit("Only .bin and .safetensors format for LoRA is supported")
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_")
def update_lora_weight(model, use_lora, model_name):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora)
try:
return processLoRA(model, use_lora, "lora_te_")
except:
return None
# `fetch_and_update_base_model_id` is a resource utility function which
@@ -557,17 +697,28 @@ def clear_all():
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
shutil.rmtree(
os.path.join(home, ".local/shark_tank"), ignore_errors=True
)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
def get_generated_imgs_path() -> Path:
return Path(
args.output_dir if args.output_dir else Path.cwd(), "generated_imgs"
)
def get_generated_imgs_todays_subdir() -> str:
return dt.now().strftime("%Y%m%d")
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info={}):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
@@ -629,3 +780,57 @@ def save_output_img(output_img, img_seed, extra_info={}):
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
return text_output
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
width, height = image.size
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 128:
n_size = 128
if width == min_size:
width = n_size
height = n_size / aspect_ratio
else:
height = n_size
width = n_size * aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
min_size = min(width, height)
if min_size > 768:
n_size = 768
if width == min_size:
height = n_size
width = n_size * aspect_ratio
else:
width = n_size
height = n_size / aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
new_image = image.resize((n_width, n_height))
return new_image, n_width, n_height

View File

@@ -1,204 +1,387 @@
from multiprocessing import Process, freeze_support
import os
import sys
import transformers # ensures inclusion in pysintaller exe generation
from apps.stable_diffusion.src import args, clear_all
import apps.stable_diffusion.web.utils.global_obj as global_obj
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
import gradio as gr
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src import args, clear_all
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
)
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create the custom model folder if it doesn't already exist
get_custom_model_path().mkdir(parents=True, exist_ok=True)
# import before IREE to avoid MLIR library issues
import torch_mlir
if args.clear_all:
clear_all()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# getting screen width and height of display
width = window.winfo_screenwidth()
height = window.winfo_screenheight()
webview.create_window(
"SHARK AI Studio", url=address, width=width, height=height
)
return os.path.join(base_path, relative_path)
webview.start(private_mode=False)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
if __name__ == "__main__":
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.ui.split(","):
from apps.stable_diffusion.web.ui import (
txt2img_api,
img2img_api,
upscaler_api,
inpaint_api,
)
from fastapi import FastAPI, APIRouter
import uvicorn
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
img2img_web,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
inpaint_web,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
outpaint_web,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
upscaler_web,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
)
# init global sd pipeline and config
global_obj._init()
# init global sd pipeline and config
global_obj.init()
app = FastAPI()
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route(
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
# )
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
app.include_router(APIRouter())
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
sys.exit(0)
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
import gradio as gr
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
)
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create custom models folders if they don't exist
create_custom_models_folders()
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
img2img_web.render()
with gr.TabItem(label="Inpainting", id=2):
inpaint_web.render()
with gr.TabItem(label="Outpainting", id=3):
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="LoRA Training", id=5):
lora_train_web.render()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
register_button_click(
dark_theme = resource_path("ui/css/sd_dark_theme.css")
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_png_info_img,
txt2img_status,
txt2img_sendto_img2img,
1,
[txt2img_gallery],
[img2img_init_image, tabs],
)
register_button_click(
txt2img_sendto_inpaint,
2,
[txt2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_outpaint,
3,
[txt2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_upscaler,
4,
[txt2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_status,
img2img_sendto_inpaint,
2,
[img2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_outpaint,
3,
[img2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_upscaler,
4,
[img2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_status,
inpaint_sendto_img2img,
1,
[inpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
inpaint_sendto_outpaint,
3,
[inpaint_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
inpaint_sendto_upscaler,
4,
[inpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_status,
outpaint_sendto_img2img,
1,
[outpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
outpaint_sendto_inpaint,
2,
[outpaint_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
outpaint_sendto_upscaler,
4,
[outpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_status,
upscaler_sendto_img2img,
1,
[upscaler_gallery],
[img2img_init_image, tabs],
)
register_button_click(
upscaler_sendto_inpaint,
2,
[upscaler_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
upscaler_sendto_outpaint,
3,
[upscaler_gallery],
[outpaint_init_image, tabs],
lora_train_web,
model_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
modelmanager_sendto_inpaint,
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
stablelm_chat,
outputgallery_web,
outputgallery_tab_select,
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
outputgallery_sendto_upscaler,
)
# init global sd pipeline and config
global_obj._init()
sd_web.queue()
sd_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
img2img_web.render()
with gr.TabItem(label="Inpainting", id=2):
inpaint_web.render()
with gr.TabItem(label="Outpainting", id=3):
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="Model Manager", id=5):
model_web.render()
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
stablelm_chat.render()
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
lora_train_web.render()
if args.output_gallery:
with gr.TabItem(label="Output Gallery", id=8) as og_tab:
outputgallery_web.render()
# extra output gallery configuration
outputgallery_tab_select(og_tab.select)
outputgallery_watch(
[
txt2img_status,
img2img_status,
inpaint_status,
outpaint_status,
upscaler_status,
]
)
# send to buttons
register_button_click(
txt2img_sendto_img2img,
1,
[txt2img_gallery],
[img2img_init_image, tabs],
)
register_button_click(
txt2img_sendto_inpaint,
2,
[txt2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_outpaint,
3,
[txt2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_upscaler,
4,
[txt2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
img2img_sendto_inpaint,
2,
[img2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_outpaint,
3,
[img2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_upscaler,
4,
[img2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
inpaint_sendto_img2img,
1,
[inpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
inpaint_sendto_outpaint,
3,
[inpaint_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
inpaint_sendto_upscaler,
4,
[inpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
outpaint_sendto_img2img,
1,
[outpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
outpaint_sendto_inpaint,
2,
[outpaint_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
outpaint_sendto_upscaler,
4,
[outpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
upscaler_sendto_img2img,
1,
[upscaler_gallery],
[img2img_init_image, tabs],
)
register_button_click(
upscaler_sendto_inpaint,
2,
[upscaler_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
upscaler_sendto_outpaint,
3,
[upscaler_gallery],
[outpaint_init_image, tabs],
)
if args.output_gallery:
register_outputgallery_button(
outputgallery_sendto_txt2img,
0,
[outputgallery_filename],
[txt2img_png_info_img, tabs],
)
register_outputgallery_button(
outputgallery_sendto_img2img,
1,
[outputgallery_filename],
[img2img_init_image, tabs],
)
register_outputgallery_button(
outputgallery_sendto_inpaint,
2,
[outputgallery_filename],
[inpaint_init_image, tabs],
)
register_outputgallery_button(
outputgallery_sendto_outpaint,
3,
[outputgallery_filename],
[outpaint_init_image, tabs],
)
register_outputgallery_button(
outputgallery_sendto_upscaler,
4,
[outputgallery_filename],
[upscaler_init_image, tabs],
)
register_modelmanager_button(
modelmanager_sendto_txt2img,
0,
[hf_models],
[txt2img_custom_model, txt2img_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_img2img,
1,
[hf_models],
[img2img_custom_model, img2img_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_inpaint,
2,
[hf_models],
[inpaint_custom_model, inpaint_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_outpaint,
3,
[hf_models],
[outpaint_custom_model, outpaint_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_upscaler,
4,
[hf_models],
[upscaler_custom_model, upscaler_hf_model_id, tabs],
)
sd_web.queue()
if args.ui == "app":
t = Process(
target=launch_app, args=[f"http://localhost:{args.server_port}"]
)
t.start()
sd_web.launch(
share=args.share,
inbrowser=args.ui == "web",
server_name="0.0.0.0",
server_port=args.server_port,
)

View File

@@ -1,41 +1,88 @@
from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_inf,
txt2img_api,
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_png_info_img,
txt2img_status,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_api,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_status,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_inf,
inpaint_api,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_status,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_inf,
outpaint_api,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_status,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_inf,
upscaler_api,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_status,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
)
from apps.stable_diffusion.web.ui.model_manager import (
model_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
modelmanager_sendto_inpaint,
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat
from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_web,
outputgallery_tab_select,
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
outputgallery_sendto_upscaler,
)

View File

@@ -7,92 +7,102 @@ Procedure to upgrade the dark theme:
*/
:root {
--color-accent-soft: var(--neutral-700);
--color-background-primary: var(--neutral-950);
--color-background-secondary: var(--neutral-900);
--color-border-accent: var(--neutral-600);
--color-border-primary: var(--neutral-700);
--text-color-code-background: var(--neutral-800);
--text-color-link-active: var(--secondary-500);
--text-color-link: var(--secondary-500);
--text-color-link-hover: var(--secondary-400);
--text-color-link-visited: var(--secondary-600);
--text-color-subdued: var(--neutral-400);
--body-background-color: var(--color-background-primary);
--body-background-fill: var(--background-fill-primary);
--body-text-color: var(--neutral-100);
--color-accent-soft: var(--neutral-700);
--background-fill-primary: var(--neutral-950);
--background-fill-secondary: var(--neutral-900);
--border-color-accent: var(--neutral-600);
--border-color-primary: var(--neutral-700);
--link-text-color-active: var(--secondary-500);
--link-text-color: var(--secondary-500);
--link-text-color-hover: var(--secondary-400);
--link-text-color-visited: var(--secondary-600);
--body-text-color-subdued: var(--neutral-400);
--shadow-spread: 1px;
--block-background: var(--neutral-800);
--block-border-color: var(--color-border-primary);
--block-border-width: 1px;
--block-info-color: var(--text-color-subdued);
--block-label-background: var(--color-background-secondary);
--block-label-border-color: var(--color-border-primary);
--block-label-border-width: 1px;
--block-label-color: var(--neutral-200);
--block-shadow: none;
--block-title-background: none;
--block-title-border-color: none;
--block-title-border-width: 0px;
--block-title-color: var(--neutral-200);
--panel-background: var(--color-background-secondary);
--panel-border-color: var(--color-border-primary);
--checkbox-background: var(--neutral-800);
--checkbox-background-focus: var(--checkbox-background);
--checkbox-background-hover: var(--checkbox-background);
--checkbox-background-selected: var(--secondary-600);
--block-background-fill: var(--neutral-800);
--block-border-color: var(--border-color-primary);
--block_border_width: None;
--block-info-text-color: var(--body-text-color-subdued);
--block-label-background-fill: var(--background-fill-secondary);
--block-label-border-color: var(--border-color-primary);
--block_label_border_width: None;
--block-label-text-color: var(--neutral-200);
--block_shadow: None;
--block_title_background_fill: None;
--block_title_border_color: None;
--block_title_border_width: None;
--block-title-text-color: var(--neutral-200);
--panel-background-fill: var(--background-fill-secondary);
--panel-border-color: var(--border-color-primary);
--panel_border_width: None;
--checkbox-background-color: var(--neutral-800);
--checkbox-background-color-focus: var(--checkbox-background-color);
--checkbox-background-color-hover: var(--checkbox-background-color);
--checkbox-background-color-selected: var(--secondary-600);
--checkbox-border-color: var(--neutral-700);
--checkbox-border-color-focus: var(--secondary-500);
--checkbox-border-color-hover: var(--neutral-600);
--checkbox-border-color-selected: var(--secondary-600);
--checkbox-label-background: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-selected: var(--checkbox-label-background);
--checkbox-label-border-color: var(--color-border-primary);
--checkbox-label-border-color-hover: var(--color-border-primary);
--checkbox-text-color: var(--body-text-color);
--checkbox-text-color-selected: var(--checkbox-text-color);
--error-background: var(--color-background-primary);
--error-border-color: var(--color-border-primary);
--error-border-width: var(--error-border-width);
--error-color: #ef4444;
--input-background: var(--neutral-800);
--input-background-focus: var(--secondary-600);
--input-background-hover: var(--input-background);
--input-border-color: var(--color-border-primary);
--checkbox-border-width: var(--input-border-width);
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
--checkbox-label-border-color: var(--border-color-primary);
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
--checkbox-label-border-width: var(--input-border-width);
--checkbox-label-text-color: var(--body-text-color);
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
--error-background-fill: var(--background-fill-primary);
--error-border-color: var(--border-color-primary);
--error_border_width: None;
--error-text-color: #ef4444;
--input-background-fill: var(--neutral-800);
--input-background-fill-focus: var(--secondary-600);
--input-background-fill-hover: var(--input-background-fill);
--input-border-color: var(--border-color-primary);
--input-border-color-focus: var(--neutral-700);
--input-border-color-hover: var(--color-border-primary);
--input-border-color-hover: var(--input-border-color);
--input_border_width: None;
--input-placeholder-color: var(--neutral-500);
--input-shadow: var(--input-shadow);
--input_shadow: None;
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
--loader-color: var(--color-accent);
--stat-color-background: linear-gradient(to right, var(--primary-400), var(--primary-600));
--loader_color: None;
--slider_color: None;
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
--table-border-color: var(--neutral-700);
--table-even-background: var(--neutral-950);
--table-odd-background: var(--neutral-900);
--table-even-background-fill: var(--neutral-950);
--table-odd-background-fill: var(--neutral-900);
--table-row-focus: var(--color-accent-soft);
--button-cancel-background: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-border-width: var(--input-border-width);
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-cancel-border-color: #dc2626;
--button-cancel-border-color-hover: var(--button-cancel-border-color);
--button-cancel-text-color: white;
--button-cancel-text-color-hover: var(--button-cancel-text-color);
--button-primary-background: linear-gradient(to bottom right, var(--primary-600), var(--primary-700));
--button-primary-background-hover: linear-gradient(to bottom right, var(--primary-600), var(--primary-600));
--button-primary-border-color: var(--primary-600);
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
--button-primary-border-color: var(--primary-500);
--button-primary-border-color-hover: var(--button-primary-border-color);
--button-primary-text-color: white;
--button-primary-text-color-hover: var(--button-primary-text-color);
--button-secondary-background: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-border-color: var(--neutral-600);
--button-secondary-border-color-hover: var(--button-secondary-border-color);
--button-secondary-text-color: white;
--button-secondary-text-color-hover: var(--button-secondary-text-color);
--block-border-width: 1px;
--block-label-border-width: 1px;
--form-gap-width: 1px;
--error-border-width: 1px;
--input-border-width: 1px;
}
/* SHARK theme */
body {
background-color: var(--color-background-primary);
background-color: var(--background-fill-primary);
}
/* display in full width for desktop devices */
@@ -131,7 +141,7 @@ body {
}
#prompt_box textarea, #negative_prompt_box textarea {
background-color: var(--color-background-primary) !important;
background-color: var(--background-fill-primary) !important;
}
#prompt_examples {
@@ -143,7 +153,6 @@ body {
}
#ui_body {
background-color: var(--color-background-secondary) !important;
padding: var(--size-2) !important;
border-radius: 0.5em !important;
}
@@ -160,18 +169,49 @@ footer {
border-radius: 0 !important;
}
/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */
#gallery .thumbnail-item.thumbnail-lg {
aspect-ratio: unset;
max-height: calc(55vh - (2 * var(--spacing-lg)));
}
@media (min-width: 1921px) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
#gallery .grid-wrap, #gallery .preview{
min-height: calc(768px + 4px + var(--size-14));
max-height: calc(768px + 4px + var(--size-14));
}
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
#gallery .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
}
/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
object-fit: scale-down;
}
/* Navbar images in cover mode*/
#gallery .preview .thumbnail-item img {
object-fit: cover;
}
/* Limit the stable diffusion text output height */
#std_output textarea {
max-height: 215px;
}
/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */
#gallery .wrap.default {
pointer-events: none;
}
/* Import Png info box */
#txt2img_prompt_image .fixed-height {
height: var(--size-32);
#txt2img_prompt_image {
height: var(--size-32) !important;
}
/* Hide "remove buttons" from ui dropdowns */
#custom_model .token-remove.remove-all,
#lora_weights .token-remove.remove-all,
#scheduler .token-remove.remove-all,
#device .token-remove.remove-all,
#stencil_model .token-remove.remove-all {
@@ -190,3 +230,44 @@ footer {
#top_logo .download {
display: none;
}
/* output gallery tab */
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs)
}
#output_refresh_button {
max-width: 30px;
align-self: end;
padding-bottom: 8px;
}
.outputgallery_sendto {
min-width: 7em !important;
}
/* output gallery should take up most of the viewport height regardless of image size/number */
#outputgallery_gallery .fixed-height {
min-height: 89vh !important;
}
/* don't stretch non-square images to be square, breaking their aspect ratio */
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
object-fit: contain !important;
}
/* centered logo for when there are no images */
#top_logo.logo_centered {
height: 100%;
width: 100%;
}
#top_logo.logo_centered img{
object-fit: scale-down;
position: absolute;
width: 80%;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}

View File

@@ -1,17 +1,350 @@
from pathlib import Path
import os
import torch
import time
import gradio as gr
import PIL
from PIL import Image
from apps.stable_diffusion.scripts import img2img_inf
from apps.stable_diffusion.src import args
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
StencilPipeline,
resize_stencil,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def img2img_inf(
prompt: str,
negative_prompt: str,
image_dict,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.strength = strength
args.scheduler = scheduler
args.img_path = "not none"
args.ondemand = ondemand
if image_dict is None:
return None, "An Initial Image is required"
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
elif isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
else:
image = image_dict["image"].convert("RGB")
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
args.precision = precision
dtype = torch.float32 if precision == "fp32" else torch.half
new_config_obj = Config(
"img2img",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=use_stencil,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
args.use_tuned = False
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
else:
global_obj.set_sd_obj(
Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(args.scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"STRENGTH": strength}
text_output = ""
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
batch_size,
height,
width,
steps,
strength,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(
out_imgs[0],
img_seed,
extra_info,
)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Image-to-Image", current_batch + 1, batch_count, batch_size
)
return generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Img2Img Rest API.
def img2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = img2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["denoising_strength"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
use_stencil=InputData["use_stencil"]
if "use_stencil" in InputData.keys()
else "None",
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Image-to-Image") as img2img_web:
@@ -29,23 +362,31 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
img2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
@@ -62,7 +403,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
img2img_init_image = gr.Image(
label="Input Image", type="pil"
label="Input Image",
source="upload",
tool="sketch",
type="pil",
).style(height=300)
with gr.Accordion(label="Stencil Options", open=False):
@@ -73,13 +417,64 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
value="None",
choices=["None", "canny", "openpose", "scribble"],
)
def show_canvas(choice):
if choice == "scribble":
return (
gr.Slider.update(visible=True),
gr.Slider.update(visible=True),
gr.Button.update(visible=True),
)
else:
return (
gr.Slider.update(visible=False),
gr.Slider.update(visible=False),
gr.Button.update(visible=False),
)
def create_canvas(w, h):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
create_button = gr.Button(
label="Start",
value="Open drawing canvas!",
visible=False,
)
create_button.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[img2img_init_image],
)
use_stencil.change(
fn=show_canvas,
inputs=use_stencil,
outputs=[canvas_width, canvas_height, create_button],
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -93,8 +488,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="PNDM",
choices=scheduler_list,
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -143,22 +538,29 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
step=0.01,
label="Denoising Strength",
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -168,6 +570,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -182,15 +585,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -198,19 +599,14 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at {get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
img2img_status = gr.Textbox(visible=False)
with gr.Row():
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
img2img_sendto_outpaint = gr.Button(
@@ -235,8 +631,9 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
img2img_custom_model,
img2img_hf_model_id,
custom_vae,
precision,
device,
max_length,
@@ -245,14 +642,24 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[img2img_gallery, std_output],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
status_kwargs = dict(
fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=img2img_status,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -1,17 +1,299 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import inpaint_inf
from apps.stable_diffusion.src import args
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def inpaint_inf(
prompt: str,
negative_prompt: str,
image_dict,
height: int,
width: int,
inpaint_full_res: bool,
inpaint_full_res_padding: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.mask_path = "not none"
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"inpaint",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Inpaint", i + 1, batch_count, batch_size
)
return generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Inpaint Rest API.
def inpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["image"])
mask = decode_base64_to_image(InputData["mask"])
res = inpaint_inf(
InputData["prompt"],
InputData["negative_prompt"],
{"image": init_image, "mask": mask},
InputData["height"],
InputData["width"],
InputData["is_full_res"],
InputData["full_res_padding"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Inpainting") as inpaint_web:
@@ -29,23 +311,33 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
inpaint_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files()
+ get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
hf_model_id = gr.Textbox(
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
@@ -71,10 +363,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -88,8 +380,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="PNDM",
choices=scheduler_list,
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -145,22 +437,29 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -170,6 +469,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -184,15 +484,13 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -200,19 +498,15 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at {get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
inpaint_status = gr.Textbox(visible=False)
with gr.Row():
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
inpaint_sendto_outpaint = gr.Button(
@@ -238,8 +532,9 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
inpaint_custom_model,
inpaint_hf_model_id,
custom_vae,
precision,
device,
max_length,
@@ -247,14 +542,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[inpaint_gallery, std_output],
outputs=[inpaint_gallery, std_output, inpaint_status],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
status_kwargs = dict(
fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=inpaint_status,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -9,7 +9,8 @@ from apps.stable_diffusion.web.ui.utils import (
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_txt2img,
get_custom_vae_or_lora_weights,
scheduler_list,
predefined_models,
)
@@ -48,6 +49,20 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
lines=3,
)
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights to initialize weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID to initialize weights",
lines=3,
)
with gr.Group(elem_id="image_dir_box_outer"):
training_images_dir = gr.Textbox(
label="ImageDirectory",
@@ -68,7 +83,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=scheduler_list_txt2img,
choices=scheduler_list,
)
with gr.Row():
height = gr.Slider(
@@ -111,22 +126,25 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
label="CFG Scale",
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -141,15 +159,13 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
train_lora = gr.Button("Train LoRA")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
@@ -194,6 +210,9 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
max_length,
training_images_dir,
output_loc,
get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
),
],
outputs=[std_output],
show_progress=args.progress_bar,
@@ -201,4 +220,4 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
prompt_submit = prompt.submit(**kwargs)
train_click = train_lora.click(**kwargs)
clear_queue.click(fn=None, cancels=[prompt_submit, train_click])
stop_batch.click(fn=None, cancels=[prompt_submit, train_click])

View File

@@ -0,0 +1,157 @@
import os
import gradio as gr
import requests
from io import BytesIO
from PIL import Image
def get_hf_list(num_of_models=20):
path = "https://huggingface.co/api/models"
params = {
"search": "stable-diffusion",
"sort": "downloads",
"direction": "-1",
"limit": {num_of_models},
"full": "true",
}
response = requests.get(path, params=params)
return response.json()
def get_civit_list(num_of_models=50):
path = f"https://civitai.com/api/v1/models?limit={num_of_models}&types=Checkpoint"
headers = {"Content-Type": "application/json"}
raw_json = requests.get(path, headers=headers).json()
models = list(raw_json.items())[0][1]
safe_models = [
safe_model for safe_model in models if not safe_model["nsfw"]
]
version_id = 0 # Currently just using the first version.
safe_models = [
safe_model
for safe_model in safe_models
if safe_model["modelVersions"][version_id]["files"][0]["metadata"][
"format"
]
== "SafeTensor"
]
first_version_models = []
for model_iter in safe_models:
# The modelVersion would only keep the version name.
if (
model_iter["modelVersions"][version_id]["images"][0]["nsfw"]
!= "None"
):
continue
model_iter["modelVersions"][version_id]["modelName"] = model_iter[
"name"
]
model_iter["modelVersions"][version_id]["rating"] = model_iter[
"stats"
]["rating"]
model_iter["modelVersions"][version_id]["favoriteCount"] = model_iter[
"stats"
]["favoriteCount"]
model_iter["modelVersions"][version_id]["downloadCount"] = model_iter[
"stats"
]["downloadCount"]
first_version_models.append(model_iter["modelVersions"][version_id])
return first_version_models
def get_image_from_model(model_json):
model_id = model_json["modelId"]
image = None
for img_info in model_json["images"]:
if img_info["nsfw"] == "None":
image_url = model_json["images"][0]["url"]
response = requests.get(image_url)
image = BytesIO(response.content)
break
return image
with gr.Blocks() as model_web:
with gr.Row():
model_source = gr.Radio(
value=None,
choices=["Hugging Face", "Civitai"],
type="value",
label="Model Source",
)
model_numebr = gr.Slider(
1,
100,
value=10,
step=1,
label="Number of models",
interactive=True,
)
# TODO: add more filters
get_model_btn = gr.Button(value="Get Models")
hf_models = gr.Dropdown(
label="Hugging Face Model List",
choices=None,
value=None,
visible=False,
)
# TODO: select and SendTo
civit_models = gr.Gallery(
label="Civitai Model Gallery",
value=None,
interactive=True,
visible=False,
)
with gr.Row(visible=False) as sendto_btns:
modelmanager_sendto_txt2img = gr.Button(value="SendTo Txt2Img")
modelmanager_sendto_img2img = gr.Button(value="SendTo Img2Img")
modelmanager_sendto_inpaint = gr.Button(value="SendTo Inpaint")
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
def get_model_list(model_source, model_numebr):
if model_source == "Hugging Face":
hf_model_list = get_hf_list(model_numebr)
models = []
for model in hf_model_list:
# TODO: add model info
models.append(f'{model["modelId"]}')
return (
gr.Dropdown.update(choices=models, visible=True),
gr.Gallery.update(value=None, visible=False),
gr.Row.update(visible=True),
)
elif model_source == "Civitai":
civit_model_list = get_civit_list(model_numebr)
models = []
for model in civit_model_list:
image = get_image_from_model(model)
if image is None:
continue
# TODO: add model info
models.append(
(Image.open(image), f'{model["files"][0]["downloadUrl"]}')
)
return (
gr.Dropdown.update(value=None, choices=None, visible=False),
gr.Gallery.update(value=models, visible=True),
gr.Row.update(visible=False),
)
else:
return (
gr.Dropdown.update(value=None, choices=None, visible=False),
gr.Gallery.update(value=None, visible=False),
gr.Row.update(visible=False),
)
get_model_btn.click(
fn=get_model_list,
inputs=[model_source, model_numebr],
outputs=[
hf_models,
civit_models,
sendto_btns,
],
)

View File

@@ -1,17 +1,308 @@
from pathlib import Path
import os
import torch
import time
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import outpaint_inf
from apps.stable_diffusion.src import args
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def outpaint_inf(
prompt: str,
negative_prompt: str,
init_image,
pixels: int,
mask_blur: int,
directions: list,
noise_q: float,
color_variation: float,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"outpaint",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
left = True if "left" in directions else False
right = True if "right" in directions else False
top = True if "up" in directions else False
bottom = True if "down" in directions else False
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
init_image,
pixels,
mask_blur,
left,
right,
top,
bottom,
noise_q,
color_variation,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Outpaint", i + 1, batch_count, batch_size
)
return generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Inpaint Rest API.
def outpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = outpaint_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["pixels"],
InputData["mask_blur"],
InputData["directions"],
InputData["noise_q"],
InputData["color_variation"],
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Outpainting") as outpaint_web:
@@ -29,23 +320,33 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
outpaint_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files()
+ get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
hf_model_id = gr.Textbox(
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
@@ -68,10 +369,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -85,8 +386,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="PNDM",
choices=scheduler_list,
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -164,22 +465,29 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
steps = gr.Slider(
1, 100, value=20, step=1, label="Steps"
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -189,6 +497,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -203,15 +512,13 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -219,19 +526,14 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at {get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
outpaint_status = gr.Textbox(visible=False)
with gr.Row():
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -258,8 +560,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
outpaint_custom_model,
outpaint_hf_model_id,
custom_vae,
precision,
device,
max_length,
@@ -267,14 +570,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[outpaint_gallery, std_output],
outputs=[outpaint_gallery, std_output, outpaint_status],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
status_kwargs = dict(
fn=lambda bc, bs: status_label("Outpaint", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=outpaint_status,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -0,0 +1,434 @@
import glob
import gradio as gr
import os
from PIL import Image
from apps.stable_diffusion.src import args
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
from apps.stable_diffusion.web.utils.gradio_configs import (
gradio_tmp_galleries_folder,
)
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
# -- Functions for file, directory and image info querying
output_dir = get_generated_imgs_path()
def outputgallery_filenames(subdir) -> list[str]:
new_dir_path = os.path.join(output_dir, subdir)
if os.path.exists(new_dir_path):
filenames = [
glob.glob(new_dir_path + "/" + ext)
for ext in ("*.png", "*.jpg", "*.jpeg")
]
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
else:
return []
def output_subdirs() -> list[str]:
# Gets a list of subdirectories of output_dir and below, as relative paths.
relative_paths = [
os.path.relpath(entry[0], output_dir)
for entry in os.walk(
output_dir, followlinks=args.output_gallery_followlinks
)
]
# It is less confusing to always including the subdir that will take any images generated
# today even if it doesn't exist yet
if get_generated_imgs_todays_subdir() not in relative_paths:
relative_paths.append(get_generated_imgs_todays_subdir())
# sort subdirectories so that that the date named ones we probably created in this or
# previous sessions come first, sorted with the most recent first. Other subdirs are listed
# after.
generated_paths = sorted(
[path for path in relative_paths if path.isnumeric()], reverse=True
)
result_paths = generated_paths + sorted(
[
path
for path in relative_paths
if (not path.isnumeric()) and path != "."
]
)
return result_paths
# clear zero length temporary files that gradio 3.22.0 buggily creates
# TODO: remove once gradio is upgraded to or past 3.32.0
def clear_zero_length_temps():
zero_length_temps = [
os.path.join(root, file)
for root, dirs, files in os.walk(gradio_tmp_galleries_folder)
for file in files
if os.path.getsize(os.path.join(root, file)) == 0
]
for file in zero_length_temps:
os.remove(file)
# --- Define UI layout for Gradio
with gr.Blocks() as outputgallery_web:
nod_logo = Image.open(nodlogo_loc)
with gr.Row(elem_id="outputgallery_gallery"):
# needed to workaround gradio issue: https://github.com/gradio-app/gradio/issues/2907
dev_null = gr.Textbox("", visible=False)
gallery_files = gr.State(value=[])
subdirectory_paths = gr.State(value=[])
with gr.Column(scale=6):
logo = gr.Image(
label="Getting subdirectories...",
value=nod_logo,
interactive=False,
visible=True,
show_label=True,
elem_id="top_logo",
elem_classes="logo_centered",
)
gallery = gr.Gallery(
label="",
value=gallery_files.value,
visible=False,
show_label=True,
).style(columns=4)
gallery.DEFAULT_TEMP_DIR = gradio_tmp_galleries_folder
with gr.Column(scale=4):
with gr.Box():
with gr.Row():
with gr.Column(scale=16, min_width=160):
subdirectories = gr.Dropdown(
label=f"Subdirectories of {output_dir}",
type="value",
choices=subdirectory_paths.value,
value="",
interactive=True,
).style(container=False)
with gr.Column(
scale=1, min_width=32, elem_id="output_refresh_button"
):
refresh = gr.Button(
variant="secondary",
value="\u21BB", # unicode clockwise arrow circle
).style(size="sm")
image_columns = gr.Slider(
label="Columns shown", value=4, minimum=1, maximum=16, step=1
)
outputgallery_filename = gr.Textbox(
label="Filename", value="None", interactive=False
).style(show_copy_button=True)
with gr.Accordion(
label="Parameter Information", open=False
) as parameters_accordian:
image_parameters = gr.DataFrame(
headers=["Parameter", "Value"],
col_count=2,
wrap=True,
elem_classes="output_parameters_dataframe",
value=[["Status", "No image selected"]],
)
with gr.Accordion(label="Send To", open=True):
with gr.Row():
outputgallery_sendto_txt2img = gr.Button(
value="Txt2Img",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
outputgallery_sendto_img2img = gr.Button(
value="Img2Img",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
outputgallery_sendto_inpaint = gr.Button(
value="Inpaint",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
outputgallery_sendto_outpaint = gr.Button(
value="Outpaint",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
outputgallery_sendto_upscaler = gr.Button(
value="Upscaler",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
# --- Event handlers
def on_clear_gallery():
clear_zero_length_temps()
return [
gr.Gallery.update(
value=[],
visible=False,
),
gr.Image.update(
visible=True,
),
]
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
new_label = (
f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
)
return [
new_images,
gr.Gallery.update(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image.update(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_refresh(current_subdir: str) -> list:
# get an up to date subdirectory list
refreshed_subdirs = output_subdirs()
# get the images using either the current subdirectory or the most recent valid one
new_subdir = (
current_subdir
if current_subdir in refreshed_subdirs
else refreshed_subdirs[0]
)
new_images = outputgallery_filenames(new_subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, new_subdir)}"
return [
gr.Dropdown.update(
choices=refreshed_subdirs,
value=new_subdir,
),
refreshed_subdirs,
new_images,
gr.Gallery.update(
value=new_images, label=new_label, visible=len(new_images) > 0
),
gr.Image.update(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_new_image(subdir, subdir_paths, status) -> list:
# prevent error triggered when an image generates before the tab has even been selected
subdir_paths = (
subdir_paths
if len(subdir_paths) > 0
else [get_generated_imgs_todays_subdir()]
)
# only update if the current subdir is the most recent one as new images only go there
if subdir_paths[0] == subdir:
clear_zero_length_temps()
new_images = outputgallery_filenames(subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
return [
new_images,
gr.Gallery.update(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image.update(
label=new_label,
visible=len(new_images) == 0,
),
]
else:
# otherwise change nothing, (only untyped gradio gr.update() does this)
return [gr.update(), gr.update(), gr.update()]
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
# evt.index is an index into the full list of filenames for the current subdirectory
filename = images[evt.index]
params = displayable_metadata(filename)
if params:
return [
filename,
list(map(list, params["parameters"].items())),
]
return [
filename,
[["Status", "No parameters found"]],
]
def on_outputgallery_filename_change(filename: str) -> list:
exists = filename != "None" and os.path.exists(filename)
return [
# disable or enable each of the sendto button based on whether an image is selected
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
]
# The time first our tab is selected we need to do an initial refresh to populate
# the subdirectory select box and the images from the most recent subdirectory.
#
# We do it at this point rather than setting this up in the controls' definitions
# as when you refresh the browser you always get what was *initially* set, which
# won't include any new subdirectories or images that might have created since
# the application was started. Doing it this way means a browser refresh/reload
# always gets the most up to date data.
def on_select_tab(subdir_paths):
if len(subdir_paths) == 0:
return on_refresh("")
else:
return (
# Change nothing, (only untyped gr.update() does this)
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
# Unfortunately as of gradio 3.22.0 gr.update against Galleries doesn't support
# things set with .style, nor the elem_classes kwarg so we have to directly set
# things up via JavaScript if we want the client to take notice of any of our
# changes to the number of columns after it decides to put them back to the
# original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
setTimeout(() => {{
required_style = "auto ".repeat(new_cols).trim();
gallery = document.querySelector('#outputgallery_gallery .grid-container');
if (gallery) {{
gallery.style.gridTemplateColumns = required_style
}}
}}, {timeout_length});
return []; // prevents console error from gradio
}}
"""
# --- Wire handlers up to the actions
# - Many actions reset the number of columns shown in the gallery on the browser end,
# so we have to set them back to what we think they should be after the initial
# action.
# - None of the actions on this tab trigger inference, and we want the user to be able
# to do them whilst other tabs have ongoing inference running. Waiting in the queue
# behind inference jobs would mean the UI can't fully respond until the inference tasks
# complete, hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real timeout length for the
# number of columns to actually be applied. Not really sure why, maybe something has
# to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the gallery avoids current
# images being shown replacing piecemeal and prevents weirdness and errors if the user
# selects an image during the replacement phase.
clear_gallery = dict(
fn=on_clear_gallery,
inputs=None,
outputs=[gallery, logo],
queue=False,
)
image_columns.change(**set_gallery_columns_immediate)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
).then(**set_gallery_columns_delayed)
outputgallery_filename.change(
on_outputgallery_filename_change,
[outputgallery_filename],
[
outputgallery_sendto_txt2img,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
outputgallery_sendto_upscaler,
],
queue=False,
)
# We should have been given the .select function for our tab, so set it up
def outputgallery_tab_select(select):
select(
fn=on_select_tab,
inputs=[subdirectory_paths],
outputs=[
subdirectories,
subdirectory_paths,
gallery_files,
gallery,
logo,
],
queue=False,
).then(**set_gallery_columns_immediate)
# We should have been passed a list of components on other tabs that update
# when a new image has generated on that tab, so set things up so the user
# will see that new image if they are looking at today's subdirectory
def outputgallery_watch(components: gr.Textbox):
for component in components:
component.change(
on_new_image,
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)

View File

@@ -0,0 +1,185 @@
import gradio as gr
import torch
import os
from pathlib import Path
from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
sharkModel = 0
sharded_model = 0
vicuna_model = 0
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
past_key_values = None
def chat(curr_system_message, history, model, device, precision):
print(f"In chat for {model}")
global sharded_model
global past_key_values
global vicuna_model
if "vicuna" in model:
from apps.language_models.src.pipelines.vicuna_pipeline import (
Vicuna,
)
curr_system_message = start_message_vicuna
if vicuna_model == 0:
first_vic_vmfb_path = Path("first_vicuna.vmfb")
second_vic_vmfb_path = Path("second_vicuna.vmfb")
if "cuda" in device:
device = "cuda"
vicuna_model = Vicuna(
"vicuna",
hf_model_path=model,
device=device,
precision=precision,
first_vicuna_vmfb_path=first_vic_vmfb_path,
second_vicuna_vmfb_path=second_vic_vmfb_path,
)
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
prompt = messages.strip()
print("prompt = ", prompt)
sentence = vicuna_model.generate(prompt)
partial_text = ""
for new_text in sentence.split(" "):
# print(new_text)
partial_text += new_text + " "
history[-1][1] = partial_text
# Yield an empty string to cleanup the message textbox and the updated conversation history
yield history
history[-1][1] = sentence
return history
# else Model is StableLM
global sharkModel
from apps.language_models.src.pipelines.stablelm_pipeline import (
SharkStableLM,
)
if sharkModel == 0:
# max_new_tokens=512
shark_slm = SharkStableLM(
"StableLM"
) # pass elements from UI as required
# Construct the input message string for the model by concatenating the current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
curr_system_message = start_message
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
generate_kwargs = dict(prompt=messages)
words_list = shark_slm.generate(**generate_kwargs)
partial_text = ""
for new_text in words_list:
# print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to cleanup the message textbox and the updated conversation history
yield history
return words_list
with gr.Blocks(title="Chatbot") as stablelm_chat:
with gr.Row():
model = gr.Dropdown(
label="Select Model",
value="TheBloke/vicuna-7B-1.1-HF",
choices=[
"stabilityai/stablelm-tuned-alpha-3b",
"TheBloke/vicuna-7B-1.1-HF",
],
)
supported_devices = [
device for device in available_devices if "cuda" in device
]
enabled = len(supported_devices) > 0
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
if enabled
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
)
precision = gr.Radio(
label="Precision",
value="fp32",
choices=[
"fp16",
"fp32",
],
visible=True,
)
chatbot = gr.Chatbot().style(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
).style(container=False)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
system_msg = gr.Textbox(
start_message, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)

View File

@@ -1,17 +1,273 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import txt2img_inf
from apps.stable_diffusion.src import prompt_examples, args
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_txt2img,
scheduler_list,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
prompt_examples,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Text-to-Image", i + 1, batch_count, batch_size
)
return generated_imgs, text_output, ""
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Text2Img Rest API.
def txt2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
res = txt2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row(elem_id="ui_title"):
@@ -30,25 +286,34 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
custom_model = gr.Dropdown(
txt2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"]
+ get_custom_model_files("vae"),
)
with gr.Column(scale=1, min_width=170):
png_info_img = gr.Image(
txt2img_png_info_img = gr.Image(
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
@@ -72,10 +337,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -90,7 +355,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=scheduler_list_txt2img,
choices=scheduler_list,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -105,10 +370,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
384,
768,
value=args.height,
step=8,
label="Height",
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
384,
768,
value=args.width,
step=8,
label="Width",
)
precision = gr.Radio(
label="Precision",
@@ -139,23 +412,31 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -170,15 +451,13 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
@@ -194,19 +473,14 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at {get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
txt2img_status = gr.Textbox(visible=False)
with gr.Row():
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -230,8 +504,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
txt2img_custom_model,
txt2img_hf_model_id,
custom_vae,
precision,
device,
max_length,
@@ -239,29 +514,32 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[txt2img_gallery, std_output],
outputs=[txt2img_gallery, std_output, txt2img_status],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
status_kwargs = dict(
fn=lambda bc, bs: status_label("Text-to-Image", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=txt2img_status,
)
from apps.stable_diffusion.web.utils.png_metadata import (
import_png_metadata,
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
png_info_img.change(
txt2img_png_info_img.change(
fn=import_png_metadata,
inputs=[
png_info_img,
],
outputs=[
png_info_img,
txt2img_png_info_img,
prompt,
negative_prompt,
steps,
@@ -270,7 +548,20 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
seed,
width,
height,
custom_model,
hf_model_id,
txt2img_custom_model,
txt2img_hf_model_id,
],
outputs=[
txt2img_png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
txt2img_custom_model,
txt2img_hf_model_id,
],
)

View File

@@ -1,17 +1,309 @@
from pathlib import Path
import os
import torch
import time
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import upscaler_inf
from apps.stable_diffusion.src import args
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generated_imgs_path
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def upscaler_inf(
prompt: str,
negative_prompt: str,
init_image,
height: int,
width: int,
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB").resize((height, width))
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
args.height = 128
args.width = 128
new_config_obj = Config(
"upscaler",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
args.height,
args.width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_size = batch_size
args.max_length = max_length
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
"DDPM"
)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"NOISE LEVEL": noise_level}
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
low_res_img = image
high_res_img = Image.new("RGB", (height * 4, width * 4))
for i in range(0, width, 128):
for j in range(0, height, 128):
box = (j, i, j + 128, i + 128)
upscaled_image = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
low_res_img.crop(box),
batch_size,
args.height,
args.width,
steps,
noise_level,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(high_res_img, img_seed, extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log, status_label(
"Upscaler", current_batch + 1, batch_count, batch_size
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Upscaler Rest API.
def upscaler_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = upscaler_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["noise_level"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Upscaler") as upscaler_web:
@@ -29,23 +321,33 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
upscaler_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-x4-upscaler",
choices=["None"]
+ get_custom_model_files()
+ get_custom_model_files(
custom_checkpoint_type="upscaler"
)
+ predefined_upscaler_models,
)
hf_model_id = gr.Textbox(
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
@@ -65,13 +367,28 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="DDIM",
choices=scheduler_list,
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -128,22 +445,29 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
step=1,
label="Noise Level",
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -153,6 +477,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -167,15 +492,13 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -183,19 +506,15 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at {get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
upscaler_status = gr.Textbox(visible=False)
with gr.Row():
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -218,21 +537,33 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
upscaler_custom_model,
upscaler_hf_model_id,
custom_vae,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[upscaler_gallery, std_output],
outputs=[upscaler_gallery, std_output, upscaler_status],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
status_kwargs = dict(
fn=lambda bc, bs: status_label("Upscaler", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=upscaler_status,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -5,6 +5,10 @@ import glob
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
@dataclass
@@ -12,6 +16,7 @@ class Config:
mode: str
model_id: str
ckpt_loc: str
custom_vae: str
precision: str
batch_size: int
max_length: int
@@ -20,6 +25,7 @@ class Config:
device: str
use_lora: str
use_stencil: str
ondemand: str
custom_model_filetypes = (
@@ -27,13 +33,7 @@ custom_model_filetypes = (
"*.safetensors",
) # the tuple of file types
scheduler_list = [
"DDIM",
"PNDM",
"DPMSolverMultistep",
"EulerAncestralDiscrete",
]
scheduler_list_txt2img = [
scheduler_list_cpu_only = [
"DDIM",
"PNDM",
"LMSDiscrete",
@@ -41,6 +41,8 @@ scheduler_list_txt2img = [
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
]
scheduler_list = scheduler_list_cpu_only + [
"SharkEulerDiscrete",
]
@@ -70,24 +72,91 @@ def resource_path(relative_path):
return os.path.join(base_path, relative_path)
def get_custom_model_path():
return Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd(), "models")
def create_custom_models_folders():
dir = ["vae", "lora"]
if not args.ckpt_dir:
dir.insert(0, "models")
else:
if not os.path.isdir(args.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exists."
)
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
def get_custom_model_pathfile(custom_model_name):
return os.path.join(get_custom_model_path(), custom_model_name)
def get_custom_model_path(model="models"):
# structure in WebUI :-
# models or args.ckpt_dir
# |___lora
# |___vae
sub_folder = "" if model == "models" else model
if args.ckpt_dir:
return Path(Path(args.ckpt_dir), sub_folder)
else:
return Path(Path.cwd(), "models/" + sub_folder)
def get_custom_model_files():
def get_custom_model_pathfile(custom_model_name, model="models"):
return os.path.join(get_custom_model_path(model), custom_model_name)
def get_custom_model_files(model="models", custom_checkpoint_type=""):
ckpt_files = []
for extn in custom_model_filetypes:
file_types = custom_model_filetypes
if model == "lora":
file_types = custom_model_filetypes + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_custom_model_path(), extn))
for x in glob.glob(
os.path.join(get_custom_model_path(model), extn)
)
]
match custom_checkpoint_type:
case "inpainting":
files = [
val
for val in files
if val.endswith("inpainting" + extn.removeprefix("*"))
]
case "upscaler":
files = [
val
for val in files
if val.endswith("upscaler" + extn.removeprefix("*"))
]
case _:
files = [
val
for val in files
if not (
val.endswith("inpainting" + extn.removeprefix("*"))
or val.endswith("upscaler" + extn.removeprefix("*"))
)
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_custom_vae_or_lora_weights(weights, hf_id, model):
use_weight = ""
if weights == "None" and not hf_id:
use_weight = ""
elif not hf_id:
use_weight = get_custom_model_pathfile(weights, model)
else:
use_weight = hf_id
return use_weight
def cancel_sd():
# Try catch it, as gc can delete global_obj.sd_obj while switching model
try:
global_obj.set_sd_status(SD_STATE_CANCEL)
except Exception:
pass
nodlogo_loc = resource_path("logos/nod-logo.png")
available_devices = get_available_devices()

View File

@@ -0,0 +1,9 @@
# functions for generating labels used in common by tabs across the UI
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
if batch_index < batch_count:
bs = f"x{batch_size}" if batch_size > 1 else ""
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
else:
return f"{tab_name} complete"

View File

@@ -8,39 +8,68 @@ Also we could avoid memory leak when switching models by clearing the cache.
"""
def init():
global sd_obj
global config_obj
sd_obj = None
config_obj = None
def _init():
global _sd_obj
global _config_obj
global _schedulers
_sd_obj = None
_config_obj = None
_schedulers = None
def set_sd_obj(value):
global sd_obj
sd_obj = value
global _sd_obj
_sd_obj = value
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
global _sd_obj
_sd_obj.status = value
def set_cfg_obj(value):
global config_obj
config_obj = value
global _config_obj
_config_obj = value
def set_schedulers(value):
global sd_obj
sd_obj.scheduler = value
global _schedulers
_schedulers = value
def get_sd_obj():
return sd_obj
global _sd_obj
return _sd_obj
def get_sd_status():
global _sd_obj
return _sd_obj.status
def get_cfg_obj():
return config_obj
global _config_obj
return _config_obj
def get_scheduler(key):
global _schedulers
return _schedulers[key]
def clear_cache():
global sd_obj
global config_obj
del sd_obj
del config_obj
global _sd_obj
global _config_obj
global _schedulers
del _sd_obj
del _config_obj
del _schedulers
gc.collect()
_sd_obj = None
_config_obj = None
_schedulers = None

View File

@@ -1,19 +1,48 @@
import os
import shutil
import tempfile
import gradio
from os import listdir
from time import time
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
gradio_tmp_galleries_folder = os.path.join(gradio_tmp_imgs_folder, "galleries")
# Clear all gradio tmp images
def clear_gradio_tmp_imgs_folder():
if not os.path.exists(gradio_tmp_imgs_folder):
return
for fileName in listdir(gradio_tmp_imgs_folder):
# Delete tmp png files
if fileName.startswith("tmp") and fileName.endswith(".png"):
os.remove(gradio_tmp_imgs_folder + fileName)
# clear all gradio tmp files created by generation galleries
print(
"Clearing gradio temporary image files from a prior run. This may take some time..."
)
image_files = [
filename
for filename in os.listdir(gradio_tmp_imgs_folder)
if os.path.isfile(os.path.join(gradio_tmp_imgs_folder, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
cleanup_start = time()
for filename in image_files:
os.remove(gradio_tmp_imgs_folder + filename)
print(
f"Clearing generation temporary image files took {time() - cleanup_start:4f} seconds"
)
else:
print("no generation temporary files to clear")
# Clear all gradio tmp files created by output galleries
if os.path.exists(gradio_tmp_galleries_folder):
cleanup_start = time()
shutil.rmtree(gradio_tmp_galleries_folder, ignore_errors=True)
print(
f"Clearing output gallery temporary image files took {time() - cleanup_start:4f} seconds"
)
else:
print("no output gallery temporary files to clear")
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder

View File

@@ -0,0 +1,6 @@
from .png_metadata import (
import_png_metadata,
)
from .display import (
displayable_metadata,
)

View File

@@ -0,0 +1,31 @@
import csv
import os
from .format import humanize, humanizable
def csv_path(image_filename: str):
return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
def has_csv(image_filename: str) -> bool:
return os.path.exists(csv_path(image_filename))
def parse_csv(image_filename: str):
# We use a reader instead of a DictReader here for images_details.csv files due to the lack of
# headers, and then match up the return list for each row with our guess at which column format
# the file is using.
# we assume the final column of the csv has the original filename with full path and match that
# against the image_filename. We then exclude the filename from the output, hence the -1's.
csv_filename = csv_path(image_filename)
matches = [
humanize(row)
for row in csv.reader(open(csv_filename, "r", newline=""))
if row
and humanizable(row)
and os.path.basename(image_filename) in row[-1]
]
return matches[0] if matches else {}

View File

@@ -0,0 +1,50 @@
import json
import os
from PIL import Image
from .png_metadata import parse_generation_parameters
from .exif_metadata import has_exif, parse_exif
from .csv_metadata import has_csv, parse_csv
from .format import compact, humanize
def displayable_metadata(image_filename: str) -> dict:
pil_image = Image.open(image_filename)
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
# and we go via that for SendTo, and is directly tied to the image)
if "parameters" in pil_image.info:
return {
"source": "png",
"parameters": compact(
parse_generation_parameters(pil_image.info["parameters"])
),
}
# we have a matching json file (next most likely to be accurate when it's there)
json_path = os.path.splitext(image_filename)[0] + ".json"
if os.path.isfile(json_path):
with open(json_path) as params_file:
return {
"source": "json",
"parameters": compact(
humanize(json.load(params_file), includes_filename=False)
),
}
# we have a CSV file so try that (can be different shapes, and it usually has no
# headers/param names so of the things we we *know* have parameters, it's the
# last resort)
if has_csv(image_filename):
params = parse_csv(image_filename)
if params: # we might not have found the filename in the csv
return {
"source": "csv",
"parameters": compact(params), # already humanized
}
# EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
if has_exif(image_filename):
return {"source": "exif", "parameters": parse_exif(pil_image)}
# we've got nothing
return None

View File

@@ -0,0 +1,52 @@
from PIL import Image
from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
def has_exif(image_filename: str) -> bool:
return True if Image.open(image_filename).getexif() else False
def parse_exif(pil_image: Image) -> dict:
img_exif = pil_image.getexif()
# See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594
# I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I
# I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a
# dependency
exif_tags = {
TAGS.get(key, key): str(val)
for (key, val) in img_exif.items()
if key in TAGS
and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
def try_get_ifd(ifd_id):
try:
return img_exif.get_ifd(ifd_id).items()
except KeyError:
return {}
ifd_tags = {
TAGS.get(key, key): str(val)
for ifd_id in IFD
for (key, val) in try_get_ifd(ifd_id)
if ifd_id != IFD.GPSInfo
and key in TAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
gps_tags = {
GPSTAGS.get(key, key): str(val)
for (key, val) in try_get_ifd(IFD.GPSInfo)
if key in GPSTAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
return {**exif_tags, **ifd_tags, **gps_tags}

View File

@@ -0,0 +1,115 @@
# As SHARK has evolved more columns have been added to images_details.csv. However, since
# no version of the CSV has any headers (yet) we don't actually have anything within the
# file that tells us which parameter each column is for. So this is a list of known patterns
# indexed by length which is what we're going to have to use to guess which columns are the
# right ones for the file we're looking at.
# The same ordering is used for JSON, but these do have key names, however they are not very
# human friendly, nor do they match up with the what is written to the .png headers
# So these are functions to try and get something consistent out the raw input from all
# these sources
PARAMS_FORMATS = {
9: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
10: {
"MODEL": "Model",
"VARIANT": "Variant",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
12: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
},
}
PARAMS_FORMAT_LONGEST = PARAMS_FORMATS[max(PARAMS_FORMATS.keys())]
def compact(metadata: dict) -> dict:
# we don't want to alter the original dictionary
result = dict(metadata)
# discard the filename because we should already have it
if result.keys() & {"Filename"}:
result.pop("Filename")
# make showing the sizes more compact by using only one line each
if result.keys() & {"Size-1", "Size-2"}:
result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
elif result.keys() & {"Height", "Width"}:
result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
if result.keys() & {"Hires resize-1", "Hires resize-1"}:
hires_y = result.pop("Hires resize-1")
hires_x = result.pop("Hires resize-2")
if hires_x == 0 and hires_y == 0:
result["Hires resize"] = "None"
else:
result["Hires resize"] = f"{hires_y}x{hires_x}"
return result
def humanizable(metadata: dict | list[str], includes_filename=True) -> dict:
lookup_key = len(metadata) + (0 if includes_filename else 1)
return lookup_key in PARAMS_FORMATS.keys()
def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
lookup_key = len(metadata) + (0 if includes_filename else 1)
# For lists we can only work based on the length, we have no other information
if isinstance(metadata, list):
if humanizable(metadata, includes_filename):
return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
else:
raise KeyError(
f"Humanize could not find the format for a parameter list of length {len(metadata)}"
)
# For dictionaries we try to use the matching length parameter format if
# available, otherwise we use the longest. Then we swap keys in the
# metadata that match keys in the format for the friendlier name that we
# have set in the format value
if isinstance(metadata, dict):
if humanizable(metadata, includes_filename):
format = PARAMS_FORMATS[lookup_key]
else:
format = PARAMS_FORMAT_LONGEST
return {
format[key]: value
for (key, value) in metadata.items()
if key in format.keys()
}
raise TypeError("Can only humanize parameter lists or dictionaries")

View File

@@ -1,21 +1,8 @@
import re
from pathlib import Path
from apps.stable_diffusion.web.ui.txt2img_ui import (
png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
custom_model,
hf_model_id,
)
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
scheduler_list_txt2img,
scheduler_list,
predefined_models,
)
@@ -75,7 +62,19 @@ def parse_generation_parameters(x: str):
return res
def import_png_metadata(pil_data):
def import_png_metadata(
pil_data,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
hf_model_id,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
@@ -110,39 +109,44 @@ def import_png_metadata(pil_data):
% metadata["Model"]
)
outputs = {
png_info_img: None,
negative_prompt: metadata["Negative prompt"],
steps: int(metadata["Steps"]),
guidance_scale: float(metadata["CFG scale"]),
seed: int(metadata["Seed"]),
width: float(metadata["Size-1"]),
height: float(metadata["Size-2"]),
}
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
cfg_scale = float(metadata["CFG scale"])
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
outputs[custom_model] = png_custom_model
outputs[hf_model_id] = ""
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
outputs[custom_model] = "None"
outputs[hf_model_id] = png_hf_model_id
custom_model = "None"
hf_model_id = png_hf_model_id
if "Prompt" in metadata:
outputs[prompt] = metadata["Prompt"]
prompt = metadata["Prompt"]
if "Sampler" in metadata:
if metadata["Sampler"] in scheduler_list_txt2img:
outputs[scheduler] = metadata["Sampler"]
if metadata["Sampler"] in scheduler_list:
sampler = metadata["Sampler"]
else:
print(
"Import PNG info: Unable to find a scheduler for %s"
% metadata["Sampler"]
)
return outputs
except Exception as ex:
if pil_data and pil_data.info.get("parameters"):
print("import_png_metadata failed with %s" % ex)
pass
return {
png_info_img: None,
}
return (
None,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
hf_model_id,
)

View File

@@ -2,4 +2,4 @@
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python generate_sharktank.py
python tank/generate_sharktank.py

View File

@@ -87,11 +87,22 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]
counter = 0
for import_opt in import_options:
for model_name in hf_model_names:
if model_name in to_skip:
continue
for use_tune in tuned_options:
if (
model_name == "stabilityai/stable-diffusion-2-1"
and use_tune == tuned_options[0]
):
continue
elif (
model_name == "stabilityai/stable-diffusion-2-1-base"
and use_tune == tuned_options[1]
):
continue
command = (
[
executable, # executable is the python from the venv used to run this
@@ -174,9 +185,21 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
else:
print(command)
print("failed to generate image for this configuration")
if "2_1_base" in model_name:
print("failed a known successful model.")
exit(1)
with open(dumpfile_name, "r+") as f:
output = f.readlines()
print("\n".join(output))
exit(1)
if os.name == "nt":
counter += 1
if counter % 2 == 0:
extra_flags.append(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
else:
if counter != 1:
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)

View File

@@ -2,9 +2,11 @@ def pytest_addoption(parser):
# Attaches SHARK command-line arguments to the pytest machinery.
parser.addoption(
"--benchmark",
action="store_true",
default="False",
help="Pass option to benchmark and write results.csv",
action="store",
type=str,
default=None,
choices=("baseline", "native", "all"),
help="Benchmarks specified engine(s) and writes bench_results.csv.",
)
parser.addoption(
"--onnx_bench",
@@ -40,7 +42,13 @@ def pytest_addoption(parser):
"--update_tank",
action="store_true",
default="False",
help="Update local shark tank with latest artifacts.",
help="Update local shark tank with latest artifacts if model artifact hash mismatched.",
)
parser.addoption(
"--force_update_tank",
action="store_true",
default="False",
help="Force-update local shark tank with artifacts from specified shark_tank URL (defaults to nightly).",
)
parser.addoption(
"--ci_sha",
@@ -51,15 +59,21 @@ def pytest_addoption(parser):
parser.addoption(
"--local_tank_cache",
action="store",
default="",
default=None,
help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
)
parser.addoption(
"--tank_url",
type=str,
default="gs://shark_tank/latest",
default="gs://shark_tank/nightly",
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
)
parser.addoption(
"--tank_prefix",
type=str,
default=None,
help="Prefix to gs://shark_tank/ model directories from which to download SHARK tank artifacts. Default is nightly.",
)
parser.addoption(
"--benchmark_dispatches",
default=None,
@@ -70,3 +84,14 @@ def pytest_addoption(parser):
default="./temp_dispatch_benchmarks",
help="Directory in which dispatch benchmarks are saved.",
)
parser.addoption(
"--batchsize",
default=1,
type=int,
help="Batch size for the tested model.",
)
parser.addoption(
"--custom_device",
default=None,
help="Custom device string to run tests with.",
)

View File

@@ -40,7 +40,7 @@ cmake --build build/
*Prepare the model*
```bash
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
```
*Prepare the input*
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
```bash
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
```
VAE and Autoencoder are also available
```bash
# VAE
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
# CLIP Autoencoder
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
```

View File

@@ -21,7 +21,7 @@ endif()
# Compile mnist.mlir to mnist.vmfb.
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
set(_COMPILE_ARGS)
list(APPEND _COMPILE_ARGS "--iree-input-type=mhlo")
list(APPEND _COMPILE_ARGS "--iree-input-type=auto")
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
list(APPEND _COMPILE_ARGS "-o")

View File

@@ -1,3 +1,3 @@
# SHARK Annotator
gradio==3.15.0
gradio==3.34.0
jsonlines

75
docs/shark_sd_blender.md Normal file
View File

@@ -0,0 +1,75 @@
# Overview
This document is intended to provide a starting point for using SHARK stable diffusion with Blender.
We currently make use of the [AI-Render Plugin](https://github.com/benrugg/AI-Render) to integrate with Blender.
## Setup SHARK and prerequisites:
* Download the latest SHARK SD webui .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow instructions on the [README](https://github.com/nod-ai/SHARK#readme)
* Once you have the .exe where you would like SHARK to install, run the .exe from terminal/PowerShell with the `--api` flag:
```
## Run the .exe in API mode:
.\shark_sd_<date>_<ver>.exe --api
## For example:
.\shark_sd_20230411_671.exe --api --server_port=8082
## From a the base directory of a source clone of SHARK:
./setup_venv.ps1
python apps\stable_diffusion\web\index.py --api
```
Your local SD server should start and look something like this:
![image](https://user-images.githubusercontent.com/87458719/231369758-e2c3c45a-eccc-4fe5-a788-4a3bf1ace1d1.png)
* Note: When running in api mode with `--api`, the .exe will not function as a webUI. Thus, the address in the terminal output will only be useful for API requests.
### Install AI Render
- Get AI Render on [Blender Market](https://blendermarket.com/products/ai-render) or [Gumroad](https://airender.gumroad.com/l/ai-render)
- Open Blender, then go to Edit > Preferences > Add-ons > Install and then find the zip file
- We will be using the Automatic1111 SD backend for the AI-Render plugin. Follow instructions [here](https://github.com/benrugg/AI-Render/wiki/Local-Installation) to setup local SD backend.
Your AI-Render preferences should be configured as shown; the highlighted part should match your terminal output:
![image](https://user-images.githubusercontent.com/87458719/231390322-59a54a09-520a-4a08-b658-6e37bd63e932.png)
The [AI-Render README](https://github.com/benrugg/AI-Render/blob/main/README.md) has more details on installation and usage, as well as video tutorials.
## Using AI-Render + SHARK in your Blender project
- In the Render Properties tab, in the AI-Render dropdown, enable AI-Render.
![image](https://user-images.githubusercontent.com/87458719/231392843-9bd51744-3ce2-464e-843a-0c4d4c96df0c.png)
- Select an image size (it's usually better to upscale later than go high on the img2img resolution here.)
![image](https://user-images.githubusercontent.com/87458719/231394288-0c4ab8c5-dc30-4dbe-8bc1-7520ded5efe8.png)
- From here, you can enter a prompt and configure img2img Stable Diffusion parameters, and AI-Render will run SHARK SD img2img on the rendered scene.
- AI-Render has useful presets for aesthetic styles, so you should be able to keep your subject prompt simple and focus on creating a decent Blender scene to start from.
![image](https://user-images.githubusercontent.com/87458719/231440729-2fe69586-41cb-4274-9ce7-f6c08def600b.png)
## Examples:
Scene (Input image):
![blender-sample-2](https://user-images.githubusercontent.com/87458719/231450408-0e680086-3e52-4962-a5c1-c703a94d1583.png)
Prompt:
"A bowl of tangerines in front of rocks, masterpiece, oil on canvas, by Georgia O'Keefe, trending on artstation, landscape painting by Caspar David Friedrich"
Negative Prompt (default):
"ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
Example output:
![blender-sample-2_out](https://user-images.githubusercontent.com/87458719/231451145-a0b56897-a7d0-4add-bbed-7e8af21a65df.png)

View File

@@ -1,278 +0,0 @@
# Lint as: python3
"""SHARK Tank"""
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
# will generate local shark tank folder like this:
# /SHARK
# /gen_shark_tank
# /albert_lite_base
# /...model_name...
#
import os
import csv
import argparse
from shark.shark_importer import SharkImporter
import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
from apps.stable_diffusion.src.models import (
model_wrappers as mw,
)
from apps.stable_diffusion.src.utils.stable_args import (
args,
)
def create_hash(file_name):
with open(file_name, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(2**20):
file_hash.update(chunk)
return file_hash.hexdigest()
def save_torch_model(torch_model_list):
from tank.model_utils import (
get_hf_model,
get_vision_model,
get_hf_img_cls_model,
get_fp16_model,
)
with open(torch_model_list) as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")
fields = next(torch_reader)
for row in torch_reader:
torch_model_name = row[0]
tracing_required = row[1]
model_type = row[2]
is_dynamic = row[3]
tracing_required = False if tracing_required == "False" else True
is_dynamic = False if is_dynamic == "False" else True
print("generating artifacts for: " + torch_model_name)
model = None
input = None
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.use_tuned = False
args.local_tank_cache = WORKDIR
precision_values = ["fp16"]
seq_lengths = [64, 77]
for precision_value in precision_values:
args.precision = precision_value
for length in seq_lengths:
model = mw.SharkifyStableDiffusionModel(
model_id=torch_model_name,
custom_weights="",
precision=precision_value,
max_len=length,
width=512,
height=512,
use_base_vae=False,
debug=True,
sharktank_dir=WORKDIR,
generate_vmfb=False,
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
(input,),
frontend="torch",
)
mlir_importer.import_debug(
is_dynamic=False,
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name,
)
# Generate torch dynamic models.
if is_dynamic:
mlir_importer.import_debug(
is_dynamic=True,
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name + "_dynamic",
)
def save_tf_model(tf_model_list):
from tank.model_utils_tf import (
get_causal_image_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
)
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
with open(tf_model_list) as csvfile:
tf_reader = csv.reader(csvfile, delimiter=",")
fields = next(tf_reader)
for row in tf_reader:
tf_model_name = row[0]
model_type = row[1]
model = None
input = None
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_causal_lm_model(tf_model_name)
if model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
if model_type == "keras":
model, input, _ = get_keras_model(tf_model_name)
if model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name)
tf_model_name = tf_model_name.replace("/", "_")
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
os.makedirs(tf_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
inputs=input,
frontend="tf",
)
mlir_importer.import_debug(
is_dynamic=False,
dir=tf_model_dir,
model_name=tf_model_name,
)
mlir_hash = create_hash(
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
)
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
def save_tflite_model(tflite_model_list):
from shark.tflite_utils import TFLitePreprocessor
with open(tflite_model_list) as csvfile:
tflite_reader = csv.reader(csvfile, delimiter=",")
for row in tflite_reader:
print("\n")
tflite_model_name = row[0]
tflite_model_link = row[1]
print("tflite_model_name", tflite_model_name)
print("tflite_model_link", tflite_model_link)
tflite_model_name_dir = os.path.join(
WORKDIR, str(tflite_model_name) + "_tflite"
)
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
# Preprocess to get SharkImporter input args
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()
# Use SharkImporter to get SharkInference input args
my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
frontend="tflite",
raw_model_file=raw_model_file_path,
)
my_shark_importer.import_debug(
dir=tflite_model_name_dir,
model_name=tflite_model_name,
func_name="main",
)
mlir_hash = create_hash(
os.path.join(
tflite_model_name_dir,
tflite_model_name + "_tflite" + ".mlir",
)
)
np.save(
os.path.join(tflite_model_name_dir, "hash"),
np.array(mlir_hash),
)
# Validates whether the file is present or not.
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
if __name__ == "__main__":
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--torch_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/torch_model_list.csv",
# help="""Contains the file with torch_model name and args.
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
# )
# parser.add_argument(
# "--tf_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tf_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--tflite_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tflite/tflite_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--ci_tank_dir",
# type=bool,
# default=False,
# )
# parser.add_argument("--upload", type=bool, default=False)
# old_args = parser.parse_args()
home = str(Path.home())
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
torch_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tf_model_list.csv"
)
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
)
save_torch_model(
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv")
)
save_torch_model(torch_model_csv)
save_tf_model(tf_model_csv)
save_tflite_model(tflite_model_csv)

View File

@@ -6,36 +6,16 @@ from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Diffusers 0.13.1 fails with transformers __init.py errros in BLIP. So remove it for now until we fork it
pix2pix_init = Path(get_python_lib() + "/diffusers/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(get_python_lib() + "/diffusers/pipelines/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(
get_python_lib() + "/diffusers/pipelines/stable_diffusion/__init__.py"
# Temorary workaround for transformers/__init__.py.
path_to_tranformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
for line in fileinput.input(pix2pix_init, inplace=True):
if "StableDiffusionPix2PixZeroPipeline" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
if path_to_tranformers_hook.is_file():
pass
else:
with open(path_to_tranformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")

View File

@@ -1,3 +1,3 @@
[pytest]
addopts = --verbose -p no:warnings
addopts = --verbose -s -p no:warnings
norecursedirs = inference tank/tflite examples benchmarks shark

View File

@@ -2,8 +2,8 @@
--pre
numpy>1.22.4
torchvision
pytorch-triton
torchvision==0.16.0.dev20230322
tabulate
tqdm
@@ -15,8 +15,8 @@ iree-tools-tf
# TensorFlow and JAX.
gin-config
tf-nightly
keras>=2.10
tensorflow>2.11
keras
#tf-models-nightly
#tensorflow-text-nightly
transformers
@@ -33,6 +33,7 @@ lit
pyyaml
python-dateutil
sacremoses
sentencepiece
# web dependecies.
gradio

View File

@@ -16,15 +16,19 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
scipy
ftfy
gradio
gradio==3.34.0
altair
omegaconf
safetensors
opencv-python
scikit-image
pytorch_lightning # for runwayml models
tk
pywebview
sentencepiece
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile

View File

@@ -45,7 +45,7 @@ if ($arguments -eq "--force"){
Remove-Item .\shark.venv -Force -Recurse
if (Test-Path .\shark.venv\) {
Write-Host 'could not remove .\shark-venv - please try running ".\setup_venv.ps1 --force" again!'
break
exit 1
}
}
}
@@ -78,12 +78,12 @@ if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is u
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
{
Write-Host "Please install Python 3.11 and try again"
break
exit 34
}
Write-Host "Installing Build Dependencies"
# make sure we really use 3.11 from list, even if it's not the default.
if (!($PyVer.length -ne 0)) {py -3.11 -m venv .\shark.venv\}
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
python -m pip install --upgrade pip

View File

@@ -2,9 +2,10 @@
# Sets up a venv suitable for running samples.
# e.g:
# ./setup_venv.sh #setup a default $PYTHON3 shark.venv
# Environment Variables by the script.
# Environment variables used by the script.
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
# VENV_DIR=myshark.venv #create a venv called myshark.venv
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
# IMPORTER=1 #Install importer deps
# BENCHMARK=1 #Install benchmark deps
@@ -26,15 +27,17 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
echo "Python: $PYTHON"
echo "Python version: $PYTHON_VERSION_X_Y"
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
fi
fi
Red=`tput setaf 1`
@@ -129,11 +132,11 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu117."
echo "Successfully Installed torch + cu118."
else
echo "Could not install torch + cu117." >&2
echo "Could not install torch + cu118." >&2
fi
fi
@@ -147,8 +150,7 @@ if [[ ! -z "${ONNX}" ]]; then
fi
fi
if [[ -z "${CONDA_PREFIX}" ]]; then
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
echo "${Green}Before running examples activate venv with:"
echo " ${Green}source $VENV_DIR/bin/activate"
fi

View File

@@ -0,0 +1,73 @@
from transformers import AutoTokenizer, FlaxAutoModel
import torch
import jax
from typing import Union, Dict, List, Any
import numpy as np
from shark.shark_inference import SharkInference
import io
NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
def convert_torch_tensor_tree_to_numpy(
tree: Union[torch.tensor, Dict[str, torch.tensor], List[torch.tensor]]
) -> NumpyTree:
return jax.tree_util.tree_map(
lambda torch_tensor: torch_tensor.cpu().detach().numpy(), tree
)
def convert_int64_to_int32(tree: NumpyTree) -> NumpyTree:
return jax.tree_util.tree_map(
lambda tensor: np.array(tensor, dtype=np.int32)
if tensor.dtype == np.int64
else tensor,
tree,
)
def get_sample_input():
tokenizer = AutoTokenizer.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased"
)
inputs_torch = tokenizer("Hello, World!", return_tensors="pt")
return convert_int64_to_int32(
convert_torch_tensor_tree_to_numpy(inputs_torch.data)
)
def get_jax_model():
return FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree):
model_mlir = jax.jit(jax_model).lower(**sample_input).compiler_ir()
byte_stream = io.BytesIO()
model_mlir.operation.write_bytecode(file=byte_stream)
return byte_stream.getvalue()
def assert_array_list_allclose(x, y, *args, **kwargs):
assert len(x) == len(y)
for a, b in zip(x, y):
np.testing.assert_allclose(
np.asarray(a), np.asarray(b), *args, **kwargs
)
sample_input = get_sample_input()
jax_model = get_jax_model()
mlir = export_jax_to_mlir(jax_model, sample_input)
# Compile and load module.
shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
shark_inference.compile()
# Run main function.
result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
# Run JAX model.
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]
# Verify result.
assert_array_list_allclose(result, reference_result, atol=1e-5)

View File

@@ -0,0 +1,6 @@
flax
jax[cpu]
nodai-SHARK
orbax
transformers
torch

View File

@@ -70,11 +70,11 @@ mlir_model, func_name, inputs, golden_out = download_model(
"resnet50", frontend="torch"
)
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
shark_module = SharkInference(mlir_model, mlir_dialect="linalg")
shark_module.compile()
path = shark_module.save_module()
shark_module.load_module(path)
result = shark_module.forward((img.detach().numpy(),))
result = shark_module("forward", (img.detach().numpy(),))
print("The top 3 results obtained via shark_runner is:")
print(top3_possibilities(torch.from_numpy(result)))

View File

@@ -35,8 +35,9 @@ def run_cmd(cmd, debug=False):
stderr=subprocess.PIPE,
check=True,
)
result_str = result.stdout.decode()
return result_str
stdout = result.stdout.decode()
stderr = result.stderr.decode()
return stdout, stderr
except subprocess.CalledProcessError as e:
print(e.output)
sys.exit(f"Exiting program due to error running {cmd}")
@@ -44,10 +45,15 @@ def run_cmd(cmd, debug=False):
def iree_device_map(device):
uri_parts = device.split("://", 2)
iree_driver = (
_IREE_DEVICE_MAP[uri_parts[0]]
if uri_parts[0] in _IREE_DEVICE_MAP
else uri_parts[0]
)
if len(uri_parts) == 1:
return _IREE_DEVICE_MAP[uri_parts[0]]
return iree_driver
else:
return f"{_IREE_DEVICE_MAP[uri_parts[0]]}://{uri_parts[1]}"
return f"{iree_driver}://{uri_parts[1]}"
def get_supported_device_list():
@@ -56,6 +62,8 @@ def get_supported_device_list():
_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
@@ -67,11 +75,13 @@ _IREE_DEVICE_MAP = {
def iree_target_map(device):
if "://" in device:
device = device.split("://")[0]
return _IREE_TARGET_MAP[device]
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
@@ -109,10 +119,8 @@ def check_device_drivers(device):
subprocess.check_output("rocminfo")
except Exception:
return True
# Unknown device.
else:
return True
# Unknown device. We assume drivers are installed.
return False

View File

@@ -90,6 +90,7 @@ def build_benchmark_args(
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
# if time_extractor:
# benchmark_cl.append(time_extractor)
benchmark_cl.append(f"--print_statistics=true")
return benchmark_cl
@@ -129,7 +130,8 @@ def build_benchmark_args_non_tensor_input(
def run_benchmark_module(benchmark_cl):
"""
Run benchmark command, extract result and return iteration/seconds.
Run benchmark command, extract result and return iteration/seconds, host
peak memory, and device peak memory.
# TODO: Add an example of the benchmark command.
Input: benchmark command.
@@ -138,15 +140,22 @@ def run_benchmark_module(benchmark_cl):
assert os.path.exists(
benchmark_path
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
bench_result = run_cmd(" ".join(benchmark_cl))
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
try:
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
match = regex_split.search(bench_stdout)
time_ms = float(match.group(1))
unit = match.group(3)
except AttributeError:
regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
match = regex_split.search(bench_stdout)
time_ms = float(match.group(1))
unit = match.group(2)
return 1.0 / (time * 0.001)
iter_per_second = 1.0 / (time_ms * 0.001)
# Extract peak memory.
host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
host_peak_b = int(host_regex.search(bench_stderr).group(1))
device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")
device_peak_b = int(device_regex.search(bench_stderr).group(1))
return iter_per_second, host_peak_b, device_peak_b

View File

@@ -23,6 +23,7 @@ import re
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
@@ -30,6 +31,9 @@ def get_iree_device_args(device, extra_args=[]):
f"Specific device selection only supported for vulkan now."
f"Proceeding with {device} as device."
)
device_num = device_uri[1]
else:
device_num = 0
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
@@ -42,7 +46,9 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] in ["metal", "vulkan"]:
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(extra_args=extra_args)
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
@@ -52,12 +58,11 @@ def get_iree_device_args(device, extra_args=[]):
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg"]:
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
return ["--iree-llvmcpu-target-cpu-features=host"]
elif frontend in ["tensorflow", "tf", "mhlo"]:
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
return [
"--iree-llvmcpu-target-cpu-features=host",
"--iree-mhlo-demote-i64-to-i32=false",
"--iree-flow-demote-i64-to-i32",
]
else:
@@ -70,6 +75,7 @@ def get_iree_common_args():
return [
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
]
@@ -188,21 +194,23 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
benchmark_bash.write(" ".join(benchmark_cl))
benchmark_bash.close()
benchmark_data = run_benchmark_module(benchmark_cl)
iter_per_second, _, _ = run_benchmark_module(
benchmark_cl
)
benchmark_file = open(
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
)
benchmark_file.write(f"DISPATCH: {d_}\n")
benchmark_file.write(str(benchmark_data) + "\n")
benchmark_file.write(str(iter_per_second) + "\n")
benchmark_file.write(
"SHARK BENCHMARK RESULT: "
+ str(1 / (benchmark_data * 0.001))
+ str(1 / (iter_per_second * 0.001))
+ "\n"
)
benchmark_file.close()
benchmark_runtimes[d_] = 1 / (benchmark_data * 0.001)
benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001)
elif ".mlir" in f_ and "benchmark" not in f_:
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
@@ -256,8 +264,8 @@ def compile_module_to_flatbuffer(
args += extra_args
if frontend in ["tensorflow", "tf"]:
input_type = "mhlo"
elif frontend in ["mhlo", "tosa"]:
input_type = "auto"
elif frontend in ["stablehlo", "tosa"]:
input_type = frontend
elif frontend in ["tflite", "tflite-tosa"]:
input_type = "tosa"
@@ -304,7 +312,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
ModuleCompiled = ctx.modules.module
ModuleCompiled = getattr(ctx.modules, vm_module.name)
return ModuleCompiled, config
@@ -358,7 +366,7 @@ def export_iree_module_to_vmfb(
def export_module_to_mlir_file(module, frontend, directory: str):
# TODO: write proper documentation.
mlir_str = module
if frontend in ["tensorflow", "tf", "mhlo", "tflite"]:
if frontend in ["tensorflow", "tf", "mhlo", "stablehlo", "tflite"]:
mlir_str = module.decode("utf-8")
elif frontend in ["pytorch", "torch"]:
mlir_str = module.operation.get_asm()

View File

@@ -30,11 +30,10 @@ def get_iree_gpu_args():
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
) and (shark_args.enable_tf32 == True):
return [
"--iree-hal-cuda-disable-loop-nounroll-wa",
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
]
else:
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
return []
# Get the default gpu args given the architecture.

View File

@@ -117,7 +117,8 @@ def get_extensions(triple):
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
ext.append("VK_NV_cooperative_matrix")
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
ext.append("VK_KHR_shader_integer_dot_product")
return make_ext_list(ext_list=ext)
@@ -131,7 +132,9 @@ def get_vendor(triple):
return "ARM"
if arch == "m1":
return "Apple"
if arch in ["turing", "ampere"]:
if arch in ["arc", "UHD"]:
return "Intel"
if arch in ["turing", "ampere", "pascal"]:
return "NVIDIA"
if arch == "ardeno":
return "Qualcomm"
@@ -149,7 +152,7 @@ def get_device_type(triple):
return "Unknown"
if arch == "cpu":
return "CPU"
if arch in ["turing", "ampere"]:
if arch in ["turing", "ampere", "arc", "pascal"]:
return "DiscreteGPU"
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]:
if product == "ivega10":
@@ -226,6 +229,7 @@ def get_vulkan_target_capabilities(triple):
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["shaderIntegerDotProduct"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
@@ -234,12 +238,12 @@ def get_vulkan_target_capabilities(triple):
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
if arch == "rdna3":
# TODO: Get scope value
cap["coopmatCases"] = [
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
]
if product == "rx5700xt":
cap["storagePushConstant16"] = False
cap["storagePushConstant8"] = False
@@ -272,7 +276,7 @@ def get_vulkan_target_capabilities(triple):
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["shaderIntegerDotProduct"] = True
cap["storagePushConstant16"] = False
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
@@ -303,6 +307,7 @@ def get_vulkan_target_capabilities(triple):
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["shaderIntegerDotProduct"] = False
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
@@ -343,6 +348,38 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "arc":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
cap["subgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = False
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = False
cap["shaderIntegerDotProduct"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "cpu":
if product == "swiftshader":
cap["maxComputeSharedMemorySize"] = 16384
@@ -356,6 +393,40 @@ def get_vulkan_target_capabilities(triple):
"ShuffleRelative",
]
elif arch in ["pascal"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1536
cap["maxComputeWorkGroupSize"] = [1536, 1024, 64]
cap["subgroupSize"] = 32
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = False
cap["shaderFloat64"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["shaderIntegerDotProduct"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch in ["ampere", "turing"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1024
@@ -380,6 +451,7 @@ def get_vulkan_target_capabilities(triple):
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["shaderIntegerDotProduct"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True

View File

@@ -21,8 +21,9 @@ from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_vulkan_device_name():
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
def get_vulkan_device_name(device_num=0):
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
@@ -30,8 +31,8 @@ def get_vulkan_device_name():
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing first one: {vulkaninfo_list[0]}")
return vulkaninfo_list[0]
print(f"Choosing device: {vulkaninfo_list[device_num]}")
return vulkaninfo_list[device_num]
def get_os_name():
@@ -106,21 +107,26 @@ def get_vulkan_target_triple(device_name):
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
triple = f"rdna3-w7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna2-unknown-{system_os}"
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"
else:
triple = None
return triple
def get_vulkan_triple_flag(device_name="", extra_args=[]):
def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
for flag in extra_args:
if "-iree-vulkan-target-triple=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
if device_name == "" or device_name == [] or device_name is None:
vulkan_device = get_vulkan_device_name()
vulkan_device = get_vulkan_device_name(device_num=device_num)
else:
vulkan_device = device_name
triple = get_vulkan_target_triple(vulkan_device)
@@ -138,8 +144,8 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
return None
def get_iree_vulkan_args(extra_args=[]):
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
def get_iree_vulkan_args(device_num=0, extra_args=[]):
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
vulkan_triple_flag = None
@@ -150,7 +156,9 @@ def get_iree_vulkan_args(extra_args=[]):
break
if vulkan_triple_flag is None:
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
vulkan_triple_flag = get_vulkan_triple_flag(
device_num=device_num, extra_args=extra_args
)
if vulkan_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)

View File

@@ -30,8 +30,8 @@ import os
import sys
from typing import Dict, List
import iree.compiler._mlir_libs
from iree.compiler import ir
from iree.compiler.transforms import ireec as ireec_trans
def model_annotation(
@@ -311,11 +311,18 @@ def add_attributes(op: ir.Operation, config: List[Dict]):
split_k = config["split_k"]
elif "SPIRV" in config["pipeline"]:
pipeline = config["pipeline"]
tile_sizes = [
config["work_group_tile_sizes"],
config["parallel_tile_sizes"],
config["reduction_tile_sizes"],
]
if pipeline == "SPIRVMatmulPromoteVectorize":
tile_sizes = [
config["work_group_tile_sizes"]
+ [config["reduction_tile_sizes"][-1]],
]
else:
tile_sizes = [
config["work_group_tile_sizes"],
config["parallel_tile_sizes"],
config["reduction_tile_sizes"],
]
workgroup_size = config["work_group_sizes"]
if "vector_tile_sizes" in config.keys():
tile_sizes += [config["vector_tile_sizes"]]
@@ -409,7 +416,6 @@ def shape_list_to_string(input):
def create_context() -> ir.Context:
context = ir.Context()
ireec_trans.register_all_dialects(context)
context.allow_unregistered_dialects = True
return context

View File

@@ -14,8 +14,10 @@
import argparse
import os
import subprocess
parser = argparse.ArgumentParser(description="SHARK runner.")
parser.add_argument(
"--device",
type=str,
@@ -54,7 +56,7 @@ parser.add_argument(
)
parser.add_argument(
"--shark_prefix",
default="latest",
default=None,
help="gs://shark_tank/<this_flag>/model_directories",
)
parser.add_argument(

View File

@@ -21,9 +21,17 @@ from shark.iree_utils.benchmark_utils import (
from shark.parser import shark_args
from datetime import datetime
import time
from typing import Optional
import csv
import os
TF_CPU_DEVICE = "/CPU:0"
TF_GPU_DEVICE = "/GPU:0"
def _bytes_to_mb_str(bytes_: Optional[int]) -> str:
return "" if bytes_ is None else f"{bytes_ / 1e6:.6f}"
class OnnxFusionOptions(object):
def __init__(self):
@@ -70,6 +78,7 @@ class SharkBenchmarkRunner(SharkRunner):
self.vmfb_file = None
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.import_args = {}
SharkRunner.__init__(
self,
mlir_module,
@@ -104,7 +113,6 @@ class SharkBenchmarkRunner(SharkRunner):
def benchmark_torch(self, modelname):
import torch
import torch._dynamo as dynamo
from tank.model_utils import get_torch_model
if self.device == "cuda":
@@ -116,31 +124,54 @@ class SharkBenchmarkRunner(SharkRunner):
torch_device = torch.device(
"cuda:0" if self.device == "cuda" else "cpu"
)
HFmodel, input = get_torch_model(modelname)[:2]
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
frontend_model = HFmodel.model
frontend_model.to(torch_device)
input.to(torch_device)
# frontend_model = torch.compile(frontend_model, mode="max-autotune", backend="inductor")
# TODO: re-enable as soon as pytorch CUDA context issues are resolved
try:
frontend_model = torch.compile(
frontend_model, mode="max-autotune", backend="inductor"
)
except RuntimeError:
frontend_model = HFmodel.model
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(input)
if self.device == "cuda":
torch.cuda.reset_peak_memory_stats()
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(input)
if i == shark_args.num_iterations - 1:
end = time.time()
break
end = time.time()
if self.device == "cuda":
stats = torch.cuda.memory_stats()
device_peak_b = stats["allocated_bytes.all.peak"]
frontend_model.to(torch.device("cpu"))
input.to(torch.device("cpu"))
torch.cuda.empty_cache()
else:
device_peak_b = None
print(
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
if self.device == "cuda":
# Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
torch_device = torch.device("cpu")
return [
f"{shark_args.num_iterations/(end-begin)}",
f"{((end-begin)/shark_args.num_iterations)*1000}",
"", # host_peak_b (CPU usage) is not reported by PyTorch.
_bytes_to_mb_str(device_peak_b),
]
def benchmark_tf(self, modelname):
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
@@ -155,38 +186,55 @@ class SharkBenchmarkRunner(SharkRunner):
from tank.model_utils_tf import get_tf_model
# tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
tf_device = "/CPU:0"
# tf_device = TF_GPU_DEVICE if self.device == "cuda" else TF_CPU_DEVICE
tf_device = TF_CPU_DEVICE
with tf.device(tf_device):
(
model,
input,
) = get_tf_model(
modelname
modelname, self.import_args
)[:2]
frontend_model = model
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(*input)
if tf_device == TF_GPU_DEVICE:
tf.config.experimental.reset_memory_stats(tf_device)
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(*input)
if i == shark_args.num_iterations - 1:
end = time.time()
break
end = time.time()
if tf_device == TF_GPU_DEVICE:
memory_info = tf.config.experimental.get_memory_info(tf_device)
device_peak_b = memory_info["peak"]
else:
# tf.config.experimental does not currently support measuring
# CPU memory usage.
device_peak_b = None
print(
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
return [
f"{shark_args.num_iterations/(end-begin)}",
f"{((end-begin)/shark_args.num_iterations)*1000}",
"", # host_peak_b (CPU usage) is not reported by TensorFlow.
_bytes_to_mb_str(device_peak_b),
]
def benchmark_c(self):
result = run_benchmark_module(self.benchmark_cl)
print(f"Shark-IREE-C benchmark:{result} iter/second")
return [f"{result}", f"{1000/result}"]
iter_per_second, host_peak_b, device_peak_b = run_benchmark_module(
self.benchmark_cl
)
print(f"Shark-IREE-C benchmark:{iter_per_second} iter/second")
return [
f"{iter_per_second}",
f"{1000/iter_per_second}",
_bytes_to_mb_str(host_peak_b),
_bytes_to_mb_str(device_peak_b),
]
def benchmark_python(self, inputs):
input_list = [x for x in inputs]
@@ -196,8 +244,7 @@ class SharkBenchmarkRunner(SharkRunner):
begin = time.time()
for i in range(shark_args.num_iterations):
out = self.run("forward", input_list)
if i == shark_args.num_iterations - 1:
end = time.time()
end = time.time()
print(
f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
@@ -306,11 +353,21 @@ for currently supported models. Exiting benchmark ONNX."
return comp_str
def benchmark_all_csv(
self, inputs: tuple, modelname, dynamic, device_str, frontend
self,
inputs: tuple,
modelname,
dynamic,
device_str,
frontend,
import_args,
mode="native",
):
self.setup_cl(inputs)
self.import_args = import_args
self.mode = mode
field_names = [
"model",
"batch_size",
"engine",
"dialect",
"device",
@@ -324,8 +381,19 @@ for currently supported models. Exiting benchmark ONNX."
"tags",
"notes",
"datetime",
"host_memory_mb",
"device_memory_mb",
"measured_host_memory_mb",
"measured_device_memory_mb",
]
engines = ["frontend", "shark_python", "shark_iree_c"]
# "frontend" must be the first element.
if self.mode == "native":
engines = ["shark_python", "shark_iree_c"]
if self.mode == "baseline":
engines = ["frontend"]
if self.mode == "all":
engines = ["frontend", "shark_python", "shark_iree_c"]
if shark_args.onnx_bench == True:
engines.append("onnxruntime")
@@ -336,75 +404,78 @@ for currently supported models. Exiting benchmark ONNX."
with open("bench_results.csv", mode="a", newline="") as f:
writer = csv.DictWriter(f, fieldnames=field_names)
bench_result = {}
bench_result["model"] = modelname
bench_info = {}
bench_info["model"] = modelname
bench_info["batch_size"] = str(import_args["batch_size"])
bench_info["dialect"] = self.mlir_dialect
bench_info["iterations"] = shark_args.num_iterations
if dynamic == True:
bench_result["shape_type"] = "dynamic"
bench_info["shape_type"] = "dynamic"
else:
bench_result["shape_type"] = "static"
bench_result["device"] = device_str
bench_info["shape_type"] = "static"
bench_info["device"] = device_str
if "fp16" in modelname:
bench_result["data_type"] = "float16"
bench_info["data_type"] = "float16"
else:
bench_result["data_type"] = inputs[0].dtype
bench_info["data_type"] = inputs[0].dtype
for e in engines:
(
bench_result["param_count"],
bench_result["tags"],
bench_result["notes"],
) = ["", "", ""]
engine_result = {}
self.frontend_result = None
if e == "frontend":
bench_result["engine"] = frontend
engine_result["engine"] = frontend
if check_requirements(frontend):
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
engine_result["host_memory_mb"],
engine_result["device_memory_mb"],
) = self.benchmark_frontend(modelname)
self.frontend_result = bench_result["ms/iter"]
bench_result["vs. PyTorch/TF"] = "baseline"
self.frontend_result = engine_result["ms/iter"]
engine_result["vs. PyTorch/TF"] = "baseline"
(
bench_result["param_count"],
bench_result["tags"],
bench_result["notes"],
engine_result["param_count"],
engine_result["tags"],
engine_result["notes"],
) = self.get_metadata(modelname)
else:
self.frontend_result = None
continue
elif e == "shark_python":
bench_result["engine"] = "shark_python"
engine_result["engine"] = "shark_python"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
) = self.benchmark_python(inputs)
bench_result[
engine_result[
"vs. PyTorch/TF"
] = self.compare_bench_results(
self.frontend_result, bench_result["ms/iter"]
self.frontend_result, engine_result["ms/iter"]
)
elif e == "shark_iree_c":
bench_result["engine"] = "shark_iree_c"
engine_result["engine"] = "shark_iree_c"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
engine_result["host_memory_mb"],
engine_result["device_memory_mb"],
) = self.benchmark_c()
bench_result[
engine_result[
"vs. PyTorch/TF"
] = self.compare_bench_results(
self.frontend_result, bench_result["ms/iter"]
self.frontend_result, engine_result["ms/iter"]
)
elif e == "onnxruntime":
bench_result["engine"] = "onnxruntime"
engine_result["engine"] = "onnxruntime"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
) = self.benchmark_onnx(modelname, inputs)
bench_result["dialect"] = self.mlir_dialect
bench_result["iterations"] = shark_args.num_iterations
bench_result["datetime"] = str(datetime.now())
writer.writerow(bench_result)
engine_result["datetime"] = str(datetime.now())
writer.writerow(bench_info | engine_result)

View File

@@ -61,6 +61,8 @@ def download_public_file(
continue
destination_filename = os.path.join(destination_folder_name, blob_name)
if os.path.isdir(destination_filename):
continue
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
@@ -127,33 +129,108 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
and os.path.isfile(os.path.join(model_dir, "hash.npy"))
):
print(f"""Using cached models from {WORKDIR}...""")
print(
f"""Model artifacts for {model_name} found at {WORKDIR}..."""
)
return True
return False
def _internet_connected():
import requests as req
try:
req.get("http://1.1.1.1")
return True
except:
return False
def get_git_revision_short_hash() -> str:
import subprocess
if shark_args.shark_prefix is not None:
prefix_kw = shark_args.shark_prefix
else:
import json
dir_path = os.path.dirname(os.path.realpath(__file__))
src = os.path.join(dir_path, "..", "tank_version.json")
with open(src, "r") as f:
data = json.loads(f.read())
prefix_kw = data["version"]
print(f"Checking for updates from gs://shark_tank/{prefix_kw}")
return prefix_kw
def get_sharktank_prefix():
tank_prefix = ""
if not _internet_connected():
print(
"No internet connection. Using the model already present in the tank."
)
tank_prefix = "none"
else:
desired_prefix = get_git_revision_short_hash()
storage_client_a = storage.Client.create_anonymous_client()
base_bucket_name = "shark_tank"
base_bucket = storage_client_a.bucket(base_bucket_name)
dir_blobs = base_bucket.list_blobs(prefix=f"{desired_prefix}")
for blob in dir_blobs:
dir_blob_name = blob.name.split("/")
if desired_prefix in dir_blob_name[0]:
tank_prefix = dir_blob_name[0]
break
else:
continue
if tank_prefix == "":
print(
f"shark_tank bucket not found matching ({desired_prefix}). Defaulting to nightly."
)
tank_prefix = "nightly"
return tank_prefix
# Downloads the torch model from gs://shark_tank dir.
def download_model(
model_name,
dynamic=False,
tank_url="gs://shark_tank/latest",
tank_url=None,
frontend=None,
tuned=None,
import_args={"batch_size": 1},
):
model_name = model_name.replace("/", "_")
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_" + frontend
shark_args.shark_prefix = get_sharktank_prefix()
if import_args["batch_size"] and import_args["batch_size"] != 1:
model_dir_name = (
model_name
+ "_"
+ frontend
+ "_BS"
+ str(import_args["batch_size"])
)
elif any(model in model_name for model in ["clip", "unet", "vae"]):
# TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation.
model_dir_name = model_name
else:
model_dir_name = model_name + "_" + frontend
model_dir = os.path.join(WORKDIR, model_dir_name)
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
if not tank_url:
tank_url = "gs://shark_tank/" + shark_args.shark_prefix
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
if not check_dir_exists(
model_dir_name, frontend=frontend, dynamic=dyn_str
):
print(
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
f"Downloading artifacts for model {model_name} from: {full_gs_url}"
)
download_public_file(full_gs_url, model_dir)
elif shark_args.force_update_tank == True:
print(
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
@@ -179,6 +256,7 @@ def download_model(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
except FileNotFoundError:
print(f"Model artifact hash not found at {model_dir}.")
upstream_hash = None
if local_hash != upstream_hash and shark_args.update_tank == True:
print(f"Updating artifacts for model {model_name}...")
@@ -186,17 +264,31 @@ def download_model(
elif local_hash != upstream_hash:
print(
"Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
"Hash does not match upstream in gs://shark_tank/. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
)
else:
print(
"Local and upstream hashes match. Using cached model artifacts."
)
model_dir = os.path.join(WORKDIR, model_dir_name)
tuned_str = "" if tuned is None else "_" + tuned
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
filename = os.path.join(model_dir, model_name + suffix)
print(
f"Verifying that model artifacts were downloaded successfully to {filename}..."
)
if not os.path.exists(filename):
from tank.generate_sharktank import gen_shark_files
print(
"The model data was not found. Trying to generate artifacts locally."
)
gen_shark_files(model_name, frontend, WORKDIR, import_args)
assert os.path.exists(filename), f"MLIR not found at {filename}"
with open(filename, mode="rb") as f:
mlir_file = f.read()
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
@@ -204,13 +296,3 @@ def download_model(
inputs_tuple = tuple([inputs[key] for key in inputs])
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
return mlir_file, function_name, inputs_tuple, golden_out_tuple
def _internet_connected():
import requests as req
try:
req.get("http://1.1.1.1")
return True
except:
return False

View File

@@ -0,0 +1,38 @@
import json
from collections import OrderedDict
class GenerateConfigFile:
def __init__(
self,
model,
num_sharding_stages: int,
sharding_stages_id: list[str] = None,
):
self.model = model
self.num_sharding_stages = num_sharding_stages
self.sharding_stages_id = sharding_stages_id
assert self.num_sharding_stages == len(
self.sharding_stages_id
), "Number of sharding stages should be equal to the list of their ID"
def generate_json(self):
model_dictionary = dict()
for name, m in self.model.named_modules():
if name == "":
continue
# Remove non-leaf nodes from the config as they aren't an operation
substring_before_final_period = name.split(".")[:-1]
substring_before_final_period = ".".join(
substring_before_final_period
)
if substring_before_final_period in model_dictionary:
del model_dictionary[substring_before_final_period]
layer_dict = {n: "None" for n in self.sharding_stages_id}
model_dictionary[name] = layer_dict
with open("model_config.json", "w") as outfile:
json.dump(model_dictionary, outfile)

View File

@@ -9,8 +9,8 @@ import hashlib
def create_hash(file_name):
with open(file_name, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(2**20):
file_hash = hashlib.blake2b(digest_size=64)
while chunk := f.read(2**10):
file_hash.update(chunk)
return file_hash.hexdigest()
@@ -81,7 +81,7 @@ class SharkImporter:
# NOTE: The default function for torch is "forward" and tf-lite is "main".
def _torch_mlir(self, is_dynamic, tracing_required):
def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
from shark.torch_mlir_utils import get_torch_mlir_module
return get_torch_mlir_module(
@@ -90,6 +90,7 @@ class SharkImporter:
is_dynamic,
tracing_required,
self.return_str,
mlir_type,
)
def _tf_mlir(self, func_name, save_dir="."):
@@ -120,6 +121,7 @@ class SharkImporter:
tracing_required=False,
func_name="forward",
save_dir="./shark_tmp/",
mlir_type="linalg",
):
if self.frontend in ["torch", "pytorch"]:
if self.inputs == None:
@@ -127,7 +129,10 @@ class SharkImporter:
"Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
)
sys.exit(1)
return self._torch_mlir(is_dynamic, tracing_required), func_name
return (
self._torch_mlir(is_dynamic, tracing_required, mlir_type),
func_name,
)
if self.frontend in ["tf", "tensorflow"]:
return self._tf_mlir(func_name, save_dir), func_name
if self.frontend in ["tflite", "tf-lite"]:
@@ -143,14 +148,23 @@ class SharkImporter:
# Saves `function_name.npy`, `inputs.npz`, `golden_out.npz` and `model_name.mlir` in the directory `dir`.
def save_data(
self, dir, model_name, mlir_data, func_name, inputs, outputs
self,
dir,
model_name,
mlir_data,
func_name,
inputs,
outputs,
mlir_type="linalg",
):
import numpy as np
inputs_name = "inputs.npz"
outputs_name = "golden_out.npz"
func_file_name = "function_name"
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
model_name_mlir = (
model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
)
print(f"saving {model_name_mlir} to {dir}")
try:
inputs = [x.cpu().detach() for x in inputs]
@@ -165,8 +179,17 @@ class SharkImporter:
if self.frontend == "torch":
with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
mlir_file.write(mlir_data)
mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
hash_gen_attempts = 2
for i in range(hash_gen_attempts):
try:
mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
except FileNotFoundError as err:
if i < hash_gen_attempts:
continue
else:
raise err
np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
return
def import_debug(
@@ -177,19 +200,23 @@ class SharkImporter:
dir=tempfile.gettempdir(),
model_name="model",
golden_values=None,
mlir_type="linalg",
):
if self.inputs == None:
print(
f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
)
sys.exit(1)
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
model_name_mlir = (
model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
)
artifact_path = os.path.join(dir, model_name_mlir)
imported_mlir = self.import_mlir(
is_dynamic,
tracing_required,
func_name,
save_dir=artifact_path,
mlir_type=mlir_type,
)
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
# TODO: Check for multiple outputs.
@@ -215,6 +242,7 @@ class SharkImporter:
imported_mlir[1],
self.inputs,
golden_out,
mlir_type,
)
return (
imported_mlir,
@@ -284,6 +312,47 @@ def get_f16_inputs(inputs, is_f16, f16_input_mask):
return tuple(f16_masked_inputs)
# Upcasts the block/list of ops.
def add_upcast(fx_g):
import torch
for node in fx_g.graph.nodes:
if node.target in [torch.ops.aten.mul]:
# This is a very strict check.
if hasattr(node.args[1], "target"):
if (
node.args[1].target in [torch.ops.aten.rsqrt]
and node.args[1].args[0].target in [torch.ops.aten.add]
and node.args[1].args[0].args[0].target
in [torch.ops.aten.mean]
and node.args[1].args[0].args[0].args[0].target
in [torch.ops.aten.pow]
):
print("found an upcasting block let's upcast it.")
pow_node = node.args[1].args[0].args[0].args[0]
mul_node = node
with fx_g.graph.inserting_before(pow_node):
lhs = pow_node.args[0]
upcast_lhs = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(lhs,),
kwargs={"dtype": torch.float32},
)
pow_node.args = (upcast_lhs, pow_node.args[1])
with fx_g.graph.inserting_before(mul_node):
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(mul_node,),
kwargs={"dtype": torch.float16},
)
mul_node.append(new_node)
mul_node.replace_all_uses_with(new_node)
new_node.args = (mul_node,)
new_node.kwargs = {"dtype": torch.float16}
fx_g.graph.lint()
def transform_fx(fx_g):
import torch
@@ -292,13 +361,48 @@ def transform_fx(fx_g):
"device": torch.device(type="cpu"),
"pin_memory": False,
}
kwargs_dict1 = {
"dtype": torch.float16,
}
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.zeros,
]:
node.kwargs = kwargs_dict
if node.kwargs.get("dtype") == torch.float32:
node.kwargs = kwargs_dict
# Vicuna
if node.target in [
torch.ops.aten._to_copy,
]:
if node.kwargs.get("dtype") == torch.float32:
node.kwargs = kwargs_dict1
if node.target in [
torch.ops.aten.masked_fill,
]:
if node.args[2] > torch.finfo(torch.half).max:
max_val = torch.finfo(torch.half).max
node.args = (node.args[0], node.args[1], max_val)
elif node.args[2] < torch.finfo(torch.half).min:
min_val = torch.finfo(torch.half).min
node.args = (node.args[0], node.args[1], min_val)
if node.target in [
torch.ops.aten.full,
]:
if node.args[1] > torch.finfo(torch.half).max:
max_val = torch.finfo(torch.half).max
node.args = (node.args[0], max_val)
node.kwargs = kwargs_dict
elif node.args[1] < torch.finfo(torch.half).min:
min_val = torch.finfo(torch.half).min
node.args = (node.args[0], min_val)
node.kwargs = kwargs_dict
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
if node.target in [torch.ops.aten.var_mean]:
with fx_g.graph.inserting_before(node):
@@ -308,6 +412,7 @@ def transform_fx(fx_g):
kwargs={},
)
node.args = (new_node, node.args[1])
if node.name.startswith("getitem"):
with fx_g.graph.inserting_before(node):
if node.args[0].target in [torch.ops.aten.var_mean]:
@@ -320,6 +425,7 @@ def transform_fx(fx_g):
node.replace_all_uses_with(new_node)
new_node.args = (node,)
new_node.kwargs = {"dtype": torch.float16}
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
@@ -331,6 +437,14 @@ def transform_fx(fx_g):
node.replace_all_uses_with(new_node)
new_node.args = (node,)
# Required for cuda debugging.
# for node in fx_g.graph.nodes:
# if node.op == "call_function":
# if node.kwargs.get("device") == torch.device(type="cpu"):
# new_kwargs = node.kwargs.copy()
# new_kwargs["device"] = torch.device(type="cuda")
# node.kwargs = new_kwargs
fx_g.graph.lint()
@@ -382,6 +496,9 @@ def import_with_fx(
return_str=False,
save_dir=tempfile.gettempdir(),
model_name="model",
mlir_type="linalg",
is_dynamic=False,
tracing_required=False,
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
@@ -431,6 +548,8 @@ def import_with_fx(
if is_f16:
fx_g = fx_g.half()
transform_fx(fx_g)
# TODO: Have to make it more generic.
add_upcast(fx_g)
fx_g.recompile()
if training:
@@ -438,6 +557,9 @@ def import_with_fx(
inputs = flatten_training_input(inputs)
ts_graph = torch.jit.script(fx_g)
if mlir_type == "torchscript":
return ts_graph
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
mlir_importer = SharkImporter(
ts_graph,
@@ -448,7 +570,12 @@ def import_with_fx(
if debug: # and not is_f16:
(mlir_module, func_name), _, _ = mlir_importer.import_debug(
dir=save_dir, model_name=model_name, golden_values=golden_values
dir=save_dir,
model_name=model_name,
golden_values=golden_values,
mlir_type=mlir_type,
is_dynamic=is_dynamic,
tracing_required=tracing_required,
)
return mlir_module, func_name

View File

@@ -25,7 +25,14 @@ import sys
# supported dialects by the shark-runtime.
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite", "tm_tensor"}
supported_dialects = {
"linalg",
"auto",
"stablehlo",
"tosa",
"tf-lite",
"tm_tensor",
}
class SharkRunner:

View File

@@ -59,6 +59,7 @@ class SharkTrainer:
"torch",
"tensorflow",
"tf",
"stablehlo",
"mhlo",
"linalg",
"tosa",
@@ -84,7 +85,7 @@ class SharkTrainer:
"tm_tensor",
extra_args=extra_args,
)
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
elif self.frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
self.shark_runner = SharkRunner(
self.model,
self.input,

View File

@@ -0,0 +1,133 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch
from typing import List
def get_sorted_params(named_params):
return [i[1] for i in sorted(named_params.items())]
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
torch._dynamo.config.verbose = True
var_id = 0
@make_simple_dynamo_backend
def shark_torchdynamo_backend(
fx_graph: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
):
if _returns_nothing(fx_graph):
return fx_graph
removed_none_indexes = _remove_nones(fx_graph)
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
mlir_module = torch_mlir.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors"
)
from contextlib import redirect_stdout
global var_id
with open(f"linalg_gen_{var_id}.mlir", "w") as f:
with redirect_stdout(f):
print(mlir_module)
print("saving!")
var_id += 1
from shark.shark_inference import SharkInference
import io
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
)
shark_module.compile()
def compiled_callable(*inputs):
inputs = [x.numpy() for x in inputs]
result = shark_module("forward", inputs)
if was_unwrapped:
result = [
result,
]
if not isinstance(result, list):
result = torch.tensor(x)
else:
result = tuple(torch.tensor(x) for x in result)
result = list(result)
for removed_index in removed_none_indexes:
result.insert(removed_index, None)
result = tuple(result)
return result
return compiled_callable

Some files were not shown because too many files have changed in this diff Show More