Compare commits

...

203 Commits

Author SHA1 Message Date
Stefan Kapusniak
788d469c5b UI/Web Refix remaining gradio deprecation warning (#1638) 2023-07-08 13:48:36 -07:00
Stefan Kapusniak
8a59f7cc27 UI/Web add 'open folder' button to output gallery (#1634)
* Adds a button that opens the currently selected subdirectory using
the default OS file manager
* Improve output gallery handling of having images deleted out from
under it.
* Don't show VAE or LoRA lines in parameter info panel when their
value is 'None'
* Use a css class for small icon buttons on the output gallery
tab instead using the same id for multiple buttons
2023-07-08 12:44:59 -07:00
Stefan Kapusniak
1c2ec3c7a2 Some Fixes for Gradio 3.36.1 (#1637)
* Clear .style deprecation warnings.
* Re-remove download button from Nod logos.
* Add work around for `container=false` not doing what it did before on
dropdowns to the output gallery CSS
2023-07-08 11:20:34 -07:00
powderluv
af0f715e20 Unpin gradio 2023-07-08 09:41:14 -07:00
jinchen62
47ec7275e6 Fix brevitas quantize argument (#1633) 2023-07-07 11:30:31 -07:00
powderluv
3a24cff901 change binary names 2023-07-06 23:59:14 -07:00
powderluv
1f72907886 Fix the pyinstaller for chatbots (#1631) 2023-07-06 23:30:01 -07:00
Daniel Garvey
06c8aabd01 remove local-sync from webui (#1629) 2023-07-06 13:58:59 -07:00
Phaneesh Barwaria
55a12cc0c4 cpu name in device (#1628)
* show cpu name in devices

* change device order for chatbot
2023-07-06 12:00:09 -07:00
Ean Garvey
7dcbbde523 Xfail models for data tiling flag changes (#1624) 2023-07-06 06:57:17 -07:00
Abhishek Varma
1b62dc4529 [Vicuna] Revert the formatting for Brevitas op (#1626)
-- This commit reverts the formatting for Brevitas op.
-- It also excludes vicuna.py script from `black` formatter.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-07-06 06:56:17 -07:00
Daniel Garvey
c5a47887f4 Revert revert negative prompt change (#1625)
* revert default flag changes

* revert revert negative prompt change

* revert revert negative prompt change
2023-07-05 22:09:06 -07:00
Daniel Garvey
c72d0eaf87 revert default flag changes (#1622) 2023-07-05 15:43:26 -05:00
powderluv
c41f58042a Update compile_utils.py (#1617)
* Update compile_utils.py

* Update compile_utils.py

* Update compile_utils.py
2023-07-05 10:06:48 -07:00
xzuyn
043e5a5c7a fix a mistake I made, and more formatting changes, and add ++/Karras (#1619)
* fixed missing line break in `stablelm_ui.py` `start_message`
- also more formatting changes

* fix variable spelling mistake

* revert some formatting cause black wants it different

* one less line, still less than 79

* add ++, karras, and karras++ types of dpmsolver.

* black line length 79

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-05 09:00:16 -07:00
Abhishek Varma
a1b1ce935c int8 e2e for WebUI (#1620) 2023-07-05 07:08:36 -07:00
jinchen62
bc6fee1a0c Add int4/int8 vicuna (#1598) 2023-07-05 07:01:51 -07:00
xzuyn
91ab594744 minor fix, some changes, some additions, and cleaning up (#1618)
* - fix overflowing text (a janky fix)
- add DEISMultistep scheduler as an option
- set default scheduler to DEISMultistep
- set default CFG to 3.5
- set default steps to 16
- add `xzuyn/PhotoMerge` as a model option
- add 3 new example prompts (which work nicely with PhotoMerge)
- formatting

* Set DEISMultistep in the cpu_only list instead

* formatting

* formatting

* modify prompts

* resize window to 81% & 85% monitor resolution instead of (WxH / 1.0625).

* increase steps to 32 after some testing. somewhere in between 16 and 32 is best compromise on speed/quality for DEIS, so 32 steps to play it safe.

* black line length 79

* revert settings DEIS as default scheduler.

* add more schedulers & revert accidental DDIM change
- add DPMSolverSingleStep, KDPM2AncestralDiscrete, & HeunDiscrete.
- did not add `DPMSolverMultistepInverse` or `DDIMInverse` as they only output latent noise, there are a few I did not try adding yet.
- accidentally set `upscaler_ui.py` to EulerDiscrete by default last commit while reverting DEIS changes.
- add `xzuyn/PhotoMerge-inpainting` as an in or out painting model.

* black line length 79

* add help section stuff and some other changes
- list the rest of the schedulers in argument help section.
- replace mutable default arguments.
- increased default window height to 91% to remove any scrolling for the main txt2img page (tested on a 1920x1080 monitor). width is the same as its just enough to have the image output on the side instead of the bottom.
- cleanup
2023-07-04 18:51:23 -07:00
Eliasj42
4015793f84 changed method of compiling vicuna to remove first and second vicuna (#1611)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-03 12:12:43 -07:00
Ean Garvey
d63ce76dd8 Use sortable image filenames for SD outputs. (#1528) 2023-07-03 10:30:47 -07:00
Prashant Kumar
1c32915570 Add the shark compile downstream due to https://github.com/pytorch/pytorch/pull/104185#issuecomment-1615110613 (#1615) 2023-07-01 08:30:58 -07:00
Ean Garvey
6d286c0609 Enable tuning for rectangle sizes on rdna2. (#1608) 2023-06-30 22:28:24 -07:00
Stefan Kapusniak
7392b22731 UI/Web Reduce animation of default --progress_bars (#1613) 2023-06-30 21:12:10 -07:00
jinchen62
534de05791 Update precision check for vicuna (#1610) 2023-06-29 16:16:33 -05:00
Daniel Garvey
5779e8c039 int4/int8 vicuna download support (#1609)
* set task_topology_max_group to cpu_count

by default. Can be overriden with a flag of the same str

* add download for int4/int8 mlir
2023-06-29 13:35:51 -07:00
Abhishek Varma
d496053590 [SHARK] Add a compile API to use for quick testing of inference (#1606) 2023-06-28 08:40:28 -07:00
gpetters94
6274a813c9 Add unet512 support for the other StableDiffusion pipelines (#1602) 2023-06-27 12:28:57 -07:00
Gaurav Shukla
1d6a1f9f8a [vicuna] Add tokens streaming(step=3) (#1600)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-27 08:59:27 -07:00
Daniel Garvey
75672c0e28 set task_topology_max_group to cpu_count (#1594)
by default. Can be overriden with a flag of the same str
2023-06-26 14:54:06 -07:00
Prashant Kumar
74a7202173 Make the tensors contiguous. 2023-06-26 17:29:54 +05:30
Prashant Kumar
27a08735db Add the shark backend for torch.compile API. (#1596) 2023-06-26 03:53:32 -07:00
Stefan Kapusniak
eaa49cce17 UI/App - Allow text selection (#1593)
* When run in app mode on windows, allows selection of text from
non-input controls, which is the same behaviour as web mode.
2023-06-26 02:16:53 -07:00
powderluv
10657d6fb1 Disable upx 2023-06-25 07:28:52 -07:00
Stefan Kapusniak
e3ab844cd1 Fix output gallery for csv format inc. VAE & LoRA (#1591) 2023-06-24 06:20:53 -07:00
powderluv
5ce6001b41 Update stablelm_ui.py to default to fp16 2023-06-23 22:55:47 -07:00
powderluv
501d0ca52e Add sentencepiece to webui for pyinstaller 2023-06-23 22:52:06 -07:00
powderluv
b444528715 Pin torch-mlir for windows too 2023-06-23 19:19:28 -07:00
Ean Garvey
6e6c90f62b Pin torch-mlir and use local-task in OPT. (#1592) 2023-06-23 19:17:05 -07:00
AyaanShah2204
8cdb38496e Final REST API Fixes (#1590)
* fixed outpaint api and added tests

* fixed text2img api

* more elegant generator to subscriptable conversion

* final fixes
2023-06-23 16:46:47 -07:00
powderluv
726d73d6ba Revert "[vicuna] Add streaming of tokens (#1587)" (#1588)
This reverts commit 4d55e51d46.
2023-06-23 10:29:00 -07:00
Gaurav Shukla
4d55e51d46 [vicuna] Add streaming of tokens (#1587)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-23 08:20:46 -07:00
Prashant Kumar
6ef78ee7ba Add cpu compile time flags. (#1585) 2023-06-23 07:23:26 -07:00
jinchen62
4002da7161 Add int4/int8 options to chatbot webui (#1586) 2023-06-23 07:18:34 -07:00
powderluv
ecb5e8e5d8 Update txt2img_ui.py 2023-06-23 06:42:12 -07:00
PhaneeshB
28e0919321 Add AMD cpu device 2023-06-23 18:47:04 +05:30
Daniel Garvey
28f4d44a6b downloader was double downloading (#1580) 2023-06-22 18:30:27 -07:00
AyaanShah2204
97f7e79391 [Blender Integration] Fixed Inpainting REST API (#1577)
* fixed inpaint api

* added inpainting test

* fixed linter errors

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-22 16:08:26 -07:00
Nelson Sharpe
44a8f2f8db Include VAE & LoRA data into PNG metadata (#1573)
* include custom lora and vae data in png metadata

* include pycharm settings

* lint with black
2023-06-22 16:05:54 -07:00
Eliasj42
8822b9acd7 added ability to use config file to shard vicuna (#1565)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-06-22 17:40:35 -05:00
Daniel Garvey
0ca3b9fce3 fix some mmap and vicuna bugs (#1576) 2023-06-22 17:39:55 -05:00
Nithin Meganathan
045f2bb147 Add dispatch-level config file generator for manual annotation (#1566) 2023-06-22 15:11:41 -07:00
Prashant Kumar
a811b867b9 Add shark_eager mode.
-- Eager mode with step by step op compilation and execution.
2023-06-22 22:59:14 +05:30
Abhishek Varma
cdd505e2dd [SharkInference-SharkRuntime] Adds capability to mmap vmfbs
-- This commit is based on [VmModule.mmap() API](https://github.com/openxla/iree/pull/14124).
-- It thereby adds capability to mmap vmfbs in SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-22 20:43:40 +05:30
powderluv
1b0f39107c Move torch_mlir import to the top (#1574) 2023-06-21 22:31:35 -07:00
powderluv
b9b8955f74 exclude vulkan on macos 2023-06-21 22:22:27 -07:00
powderluv
6f7a85eee3 switch to metal backend for CI 2023-06-21 22:17:11 -07:00
Ranvir Singh Virk
18c8e9e51e Metal typo fix (#1572)
* fixing typos for metal changes

* black formating
2023-06-21 21:56:11 -07:00
Daniel Garvey
a202bb466a fp16 fixes for webui (#1571) 2023-06-21 20:24:02 -07:00
Ranvir Singh Virk
07c1e1d712 Adding metal_utils for iree_utils (#1561)
* Adding metal_utils for iree_utils

* Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)

-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)

* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.

* Fixing iree-metal-target-platform

* adding metal to txt2img pipeline

* Fixing Copyright date

* removing debug prints

* black lint formating

* fixing device dump

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <avarma094@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-21 19:09:03 -07:00
Ranvir Singh Virk
18daec78c8 Added check for python version (#1570)
* Added check for python version

* Update for PYTHON_VERSION_X_Y
2023-06-21 18:56:47 -07:00
Ean Garvey
1a8e2024d6 Exclude non-square sizes from use_tuned on rdna2 (#1568) 2023-06-21 11:36:55 -05:00
AyaanShah2204
d61b6641fb Rest API: Resolved Generator Object not Subscripatable error (#1556) 2023-06-20 19:27:41 -07:00
Phaneesh Barwaria
88cc2423cc Enable Vicuna fp16 cpu (#1562)
* fix second vic mlir gen

* fp16 mlir/vmfb download from shark_tank
2023-06-20 13:43:21 -05:00
Ean Garvey
ccf944c1bd Enable tuner for upscaler unet. (#1563) 2023-06-20 13:40:13 -05:00
Ean Garvey
0def74f520 [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)
* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.
2023-06-20 10:26:36 -07:00
Abhishek Varma
3fb72e192e Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)
-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-20 10:04:17 -07:00
Vivek Khandelwal
855435ee24 Fix for the user input for Falcon pipeline 2023-06-20 18:09:32 +05:30
Elias Joseph
6f9f868fc0 fixed a bug where designating device for vicuna didn't work 2023-06-20 17:09:32 +05:30
powderluv
fb865f1b99 Move to checkout@v3
This will break Windows again but we have to fix it up since the old node.js is now deprecated.
2023-06-19 18:44:36 -07:00
rprasad2
3e5c50f07b changes for tuning (#1542)
* Add tuning sizes for rdna3
2023-06-19 15:29:08 -05:00
powderluv
a544f30a8f Move mega to the shark examples (#1555) 2023-06-19 11:10:51 -07:00
Abhishek Varma
1fe56d460a [MEGABYTE] Add script to compile MEGABYTE through SHARK (#1553)
-- Usage: `python mega_test.py`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-19 11:00:35 -07:00
Vivek Khandelwal
fafd713141 Minor change to falcon pipeline 2023-06-19 22:36:32 +05:30
Vivek Khandelwal
015d0132c3 Modify falcon pipeline to add fp16 support (#1551) 2023-06-19 09:57:13 -07:00
powderluv
20ddd96ef7 unpin diffusers (#1550) 2023-06-18 13:45:55 -07:00
powderluv
ee33cfd2d1 Add PIL in main index.py (#1549)
* Add PIL in main index.py

This is to ensure pyinstaller picks it up

* Update index.py
2023-06-18 11:51:44 -07:00
Stefan Kapusniak
a3cba21d5b Fix load of unet512 vmfb fail on get of iree opts (#1546)
* Change retrieval of Iree options used when loading an existing
unet512 vmfb to look up the "unet" options rather than attempt to
find a non-existent set of options for "unet512"

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:42:20 -07:00
Stefan Kapusniak
a7b6ec4095 Fix unet512 always being used when --max_length=77 (#1547)
* Switches a few places in the SD pipeline where an assumption of
max_length=64 was being made, to using the actual max_length
as passed into the pipeline. This prevents unet512 always being
used and producing different images than previously when
--max_length=77
2023-06-18 06:41:25 -07:00
Ean Garvey
d80b087d95 Add PIL hidden imports to sd spec. (#1544)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:39:08 -07:00
Stefan Kapusniak
297a209608 Remove workarounds for gradio tempfile bugs (#1548) 2023-06-17 19:50:36 -07:00
gpetters94
b204113563 Add UNet512 (#1504)
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-06-17 03:46:25 -04:00
Chi_Liu
f60ab1f4fa Add Deberta to stablehlo in shark tank (#1545) 2023-06-16 13:24:44 -07:00
Surya Jasper
b203779462 Added Adreno target triples to vulkan_utils (#1543) 2023-06-15 16:42:59 -07:00
Stefan Kapusniak
38570a9bbb Some Fixes for update to gradio 3.34.0 (#1538)
* Fixes randomize seed buttons that stopped working.
* Update now deprecated method to set initial colums for output
gallery to the newer undeprecated one.
2023-06-15 01:10:36 -07:00
dependabot[bot]
a5c882f296 Bump gradio from 3.15.0 to 3.34.0 (#1518)
Bumps [gradio](https://github.com/gradio-app/gradio) from 3.15.0 to 3.34.0.
- [Release notes](https://github.com/gradio-app/gradio/releases)
- [Changelog](https://github.com/gradio-app/gradio/blob/main/CHANGELOG.md)
- [Commits](https://github.com/gradio-app/gradio/compare/v3.15.0...v3.34.0)

---
updated-dependencies:
- dependency-name: gradio
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-14 18:13:48 -07:00
Ean Garvey
eb6d11cfed Change mlir dialects for tf tests to stablehlo. (#1535)
* Change mlir dialects for tf tests to stablehlo

* Update shark_runner.py
2023-06-14 10:43:49 -07:00
Vivek Khandelwal
46184a81ac Add Falcon pipeline (#1534) 2023-06-14 09:39:16 -07:00
PhaneeshB
149165a2f0 add multi-device mutli-precision vmfb names 2023-06-14 22:08:24 +05:30
dan
bec82a665f mega vicuna merge
single endpoint in apps/language/models/scripts/vicuna.py
removed main functions from pipelines
replaced divergent utils compile with shark_importer
adds support for different precisions
2023-06-14 19:06:29 +05:30
Ean Garvey
9551490341 Remove deprecared --iree-mhlo-demote-164-to-132 flag usage. (#1533) 2023-06-13 22:40:47 -05:00
Ean Garvey
49b3ecdbca (pytest) don't run redundant tests in cpu suite (#1532) 2023-06-13 22:40:33 -05:00
Ean Garvey
f53e3594c3 OPT Refactor (#1516)
* Change script to 1.3b model and add pytorch comparison

* fix CLI command

* Match OPT transformers model updates + numerics against latest version

* Cleanup OPT sentence completion script.

* Fix formatting and add standalone validation scripts.

* Add minimal OPT wrapper and example with import_with_fx

* Rename OPT full model wrapper.

* Cleanup test scripts for OPT.
2023-06-13 22:40:07 -05:00
Ean Garvey
5562d1dfda Fix xfails for cpu pytest cases (#1527)
Adding cpu-sync and cpu-task device configs was allowing respective tests to bypass the xfail conditional for cpu pytests marked in tank/all_models.csv. This commit updates the conditional to xfail those cases for cpu-sync and cpu-task as well.
2023-06-13 17:01:51 -07:00
Stefan Kapusniak
c7b0c2961e UI/Web Improve output gallery temp file handling (#1531)
* On startup report that cleaning up of temp files is taking place, in
case it takes a long time.
* Have the output gallery tab delete any zero length temporary files
generated by gradio < 3.32.0 for its gallery control whenever it
needs to update that control with images. This prevents such
files multiplying out of control.
2023-06-13 16:25:37 -05:00
Ean Garvey
44273b0791 Fix conditional in transform_fx() (#1530) 2023-06-13 16:24:53 -05:00
Prashant Kumar
0a4c8fcb3e Minor changes in the fx transforms. 2023-06-13 21:23:35 +05:30
Stefan Kapusniak
2fec3c8169 re-indents add_upcast in shark importer (#1523)
* The two with blocks in add_upcast appear to be underindented making
SD 1.4 break on rdna3, I've pushed them out one more tab, and then
everything appears to work again.
2023-06-12 14:41:10 -05:00
Gaurav Shukla
5e7d5930dd [vicuna] Add device and precision propagation in vicuna (#1520)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-12 12:14:43 -05:00
Prashant Kumar
b6dbd20250 Modify the fx transforms. (#1521)
- The bounds are set properly.
- The upcasting and downcasting is done for vicuna.
2023-06-12 09:40:14 -07:00
Nithin Meganathan
34f1295349 Add a model config generator (#1511)
Model config generator takes a PyTorch model as input and generates a JSON file with model layers and other propperties that define sharding on a particular hardware.
2023-06-09 15:32:00 -07:00
Phaneesh Barwaria
1980d7b2c3 Cpu device map (#1515)
* update cpu iree device

* fix vmfb paths vic unsharded
2023-06-09 11:27:02 -05:00
powderluv
2cfacc5051 fix osx torch_mlir (#1513)
* fix osx torch_mlir

* Update index.py

* Update index.py
2023-06-09 00:57:26 -07:00
Phaneesh Barwaria
436f58ddc4 cli using generate and mem fixes (#1509) 2023-06-08 13:13:32 -05:00
Phaneesh Barwaria
6b29bd17c8 Enable compilation vicuna (#1507)
* add cli for unsharded vic

* enable mlir download and compile
2023-06-07 13:08:22 -07:00
Ean Garvey
2c3485ca3e Add standalone OPT sentence completion script. (#1506) 2023-06-07 10:58:03 -07:00
Daniel Garvey
f206ecc635 reenable compilation in vicuna pipeline, add flags (#1505)
* replace vicuna.py backend with pipeline

* add some memory management to fist vicuna compile

reenable compilation
2023-06-07 09:49:27 -07:00
Stefan Kapusniak
a187e05ae6 Prevent having no cuda devices breaking the UI (#1503)
Don't break the UI when the LLM tab only wants cuda devices but there
aren't any.
2023-06-06 11:41:16 -07:00
Gaurav Shukla
8c21960486 [vicuna] Set only cuda devices in vicuna UI for now
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 22:15:20 +05:30
Gaurav Shukla
be62fce676 [vicuna] Fix vicuna chatbot (#1499)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 09:23:32 -07:00
PhaneeshB
f23b778a6c remove old vicuna scripts 2023-06-06 21:35:58 +05:30
PhaneeshB
436edf900d add vic sharded pipeline 2023-06-06 21:35:58 +05:30
Gaurav Shukla
ed58c2553f [vicuna] Integrate vicuna in shark studio
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 20:57:48 +05:30
Stefan Kapusniak
f2ca58e844 Add .csv and .json param info to output gallery (#1495) 2023-06-06 07:08:34 -07:00
Ean Garvey
1dbcc736eb [SD] (RDNA2) Enable new tuning for sd1.4 (#1498) 2023-06-06 06:48:58 -07:00
Phaneesh Barwaria
a83808ddc5 Vicuna cuda on A100 40G (#1496)
* vic chat with memory management (precompiled vmfb)

* fix vmfb path and download
2023-06-06 15:10:33 +05:30
Ean Garvey
a07fe80530 Update OPT, ResNet example scripts. (#1492)
* Update API in OPT example.

* fix resnet50 script

* Add OPT1.3b test script.
2023-06-05 20:19:35 -07:00
Ean Garvey
d0ba3ef8fa disable use_tuned on SD1.4 for rdna2 (#1490)
this is a temporary measure while we retune SD1.4 for rdna2. The current config fails during iree-compile.
2023-06-05 19:46:16 -05:00
Stefan Kapusniak
8400529c2c Fix output gallery not using shark_tmp (#1493)
This fix the gallery component of the  output gallery dumping temporary
files into the standard folders rather than shark_tmp so those files never
got cleared out on restart and would build up.
2023-06-05 16:23:49 -05:00
powderluv
7eaee9c242 update SHARK to nodai SHARK 2023-06-05 00:44:49 -07:00
powderluv
8230eebce5 Switch to CPU torch builds for shark.whl 2023-06-05 00:36:03 -07:00
Ean Garvey
6296ea4be9 fix config handling for sd1.4 on rdna2 (#1489) 2023-06-05 00:02:30 -07:00
Ean Garvey
4151ec3a8f (pytest) tag efficientnet, mobilenet as xfails on vulkan (#1488) 2023-06-04 23:22:32 -07:00
powderluv
a2467e8d43 Enable SHARK whl packages 2023-06-04 23:21:22 -07:00
Ean Garvey
e677178bcc Replace RDNA2 SD lowering configs. (#1486) 2023-06-05 00:57:43 -05:00
Anush Elangovan
7ef1bea953 XFAIL some macos tests 2023-06-04 15:27:03 -07:00
Chi_Liu
ad89bb1413 Add distilgpt2 to stablehlo in shark tank (#1481) 2023-06-02 16:44:46 -05:00
Ean Garvey
218ed78c40 Change instances of input_type='mhlo' to 'auto' (#1482) 2023-06-02 16:43:47 -05:00
Stefan Kapusniak
6046f36ab6 UI/Web: Fix upscaler stop button (mostly) (#1479)
* UI/Web: Fix upscaler stop button

* Hook the cancel_sd function up to the Stop button.
* Adds checks for SD_STATE_CANCEL in the upscaler ui inference function.
* Set and check for SD_STATE_IDLE, SD_STATE_CANCEL in the upscaler
pipeline.

* UI/Web: lint fixes for upscaler stop button fix

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-01 22:26:55 -07:00
Foxlum
5915bf7de3 Add to and tweak vulkan configuration environments. (#1475)
* Update vulkan_target_env_utils.py

* Update vulkan_target_env_utils.py

Adjust target environment capabilities.

* Update vulkan_target_env_utils.py

black linted?
2023-06-01 22:25:20 -07:00
Phaneesh Barwaria
f0a4e59758 LLM Pipeline Wrapper (#1477)
* [LLM] Add LLM pipeline

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

* add base pipeline and stableLM

* StableLM on UI - full block

* add SLM default model name

* add vicuna with pipeline

* add one token gen api for vic

* Fix stableLM bugs

* debug vic memory

* lint fix

---------

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-31 10:17:20 -07:00
Stefan Kapusniak
1ddef26af5 Web/UI: Add an Output Gallery tab for SD (#1470)
* WebUI: Adds an Output Gallery tab

Adds an new Output Gallery tab to the ui/webui with these features:

* Subdirectory select dropdown listing subdirectories at any depth below
the <output_dir>/generated_imgs directory,
* Large, full height, gallery area displaying the images in the selected
subdirectory. Shows nod logo when no images are in the selected
subdirectory.
* Slider that changes the number of columns of images that the gallery
displays from between 1 to 16 columns (defaults to 4).
* Expandable parameter info panel showing any generation parameters
saved in the file of the selected image for PNGs, alternatively the
image's EXIF data for JPEGs
* Send to buttons for txt2img, img2img, inpaint, outpaint and upscaler.
* Auto update of gallery and gallery label (to show generation status),
when a new image is generated by any of the stable diffusion tabs, and
is outputted to the currently selected subdirectory.
* Command line option for enabling and disabling the output gallery
(defaults to enabled)
* Command line option for following symlinks when getting entries
for the subdirectory list (defaults to off, as Python os.walk doesn't
check for circular references if following symlinks)

* Reformat with black

Reformat changes with black and then adjust some places where black's
formatting then needed some rephrasing of the code to make things
clearer.

* Add back transformers and sd_cancel imports

Adds back the transformers import in index.py needed for .exe
generation. Add comment so it doesn't get mistakenly removed
next time.
Adds back sd_cancel import in upscaler.py that is currently unused
but should be being used for the 'Stop' button.
2023-05-30 13:47:48 -07:00
Chi_Liu
ba8eddb12f Add GPT3/OPT to Stablehlo in shark tank (#1468)
Co-authored-by: AmosLewis <Amos_Lewsi@foxmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-05-29 21:58:39 -07:00
yzhang93
47b346d428 Modify the lowering config format for SPIRVMatmulPromoteVectorize pipeline (#1471) 2023-05-29 21:53:48 -07:00
Ean Garvey
1b4f4f5f4d Fix download path for SD1.4 Unet. (#1469) 2023-05-26 11:59:51 -07:00
Elias Joseph
73cd7e8320 added full vicuna to vicuna.py 2023-05-26 22:06:40 +05:30
Ean Garvey
19c0ae3702 Cleanup SD pipeline utils (#1466) 2023-05-25 12:50:11 -05:00
Ean Garvey
54e57f7771 Revive SD downloads from shark_tank. (#1465) 2023-05-25 12:03:21 -05:00
PhaneeshB
6d64b8e273 vic and slm common generation base 2023-05-25 20:29:41 +05:30
PhaneeshB
a8ea0326f5 correct SLM saved vmfb naming 2023-05-25 20:29:41 +05:30
PhaneeshB
58e9194553 add Lists import 2023-05-25 20:29:41 +05:30
PhaneeshB
eb360e255d remove unused imports 2023-05-25 20:29:41 +05:30
PhaneeshB
a6f88d7f72 refactor mlir compile 2023-05-25 20:29:41 +05:30
Prashant Kumar
8e571d165f Enable cpu f16 dtype tracing for the vicuna model. (#1461) 2023-05-24 09:37:57 -07:00
Ean Garvey
3cddd01b10 Update OPT tokenizer and xfail a few more large tests on macos CI (#1459)
* Update opt_torch_test.py

* Update all_models.csv
2023-05-23 14:36:57 -07:00
Chi_Liu
64c2b2d96b Add gpt2 to stablehlo support in shark tank (#1447)
- Add torch decomposition support when generating shark tank
- Add gpt2 stablehlo
2023-05-22 10:45:51 -07:00
Phaneesh Barwaria
f5ce121988 SLM on Sharkstudio (#1454)
* localize import, fix file reading, device cpu

* extract out model args
2023-05-19 11:21:08 -07:00
Ean Garvey
991f144598 Add iree hidden imports to SD spec (#1456)
* Add iree hidden imports to SD spec

* Update shark_sd_cli.spec
2023-05-19 11:19:16 -07:00
PhaneeshB
09bea17e59 fix #2 SLM in SharkStudio 2023-05-18 00:56:22 +05:30
Daniel Garvey
aefcf80b48 swap to cpu an remove hardcoded paths (#1448)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-05-17 10:53:34 -07:00
PhaneeshB
512235892e fix SLM for SharkStudio 2023-05-17 22:34:30 +05:30
PhaneeshB
6602a2f5ba add continuous output for CLI 2023-05-17 18:33:46 +05:30
Boian Petkantchin
20114deea0 In MiniLM JAX example verify MLIR result against JAX 2023-05-16 09:54:07 -07:00
Boian Petkantchin
9acf519078 Add option to skip venv creation in setup script 2023-05-16 09:54:07 -07:00
Boian Petkantchin
bdf37b5311 If device/backend is unknown pass it to IREE verbatim 2023-05-16 09:54:07 -07:00
powderluv
8ee2ac89f8 Rename sharded_vicuna_fp32_web.py to vicuna_web.py 2023-05-16 09:41:35 -07:00
powderluv
60cb48be2e Rename sharded_vicuna_fp32.py to vicuna.py 2023-05-16 09:40:51 -07:00
powderluv
86a215b063 Delete sharded_vicunia.py 2023-05-16 09:37:39 -07:00
powderluv
d6e3a9a236 Delete standalone_vicuna.py 2023-05-16 09:37:26 -07:00
Chi_Liu
a0097a1ead Add mlir_type for torch_model_list.csv (#1428)
- Enable stablehlo/tosa mlir output for torch model
- Add BERT stablehlo support
2023-05-15 10:23:54 -07:00
Ean Garvey
a9bae00606 Fix vulkan device selection at compile time and adapt to IREE python changes. (#1407)
* Add support for vulkan device selection at compile time.

* Don't convert device ID to int and fix .exe imports
2023-05-12 23:31:50 -07:00
Daniel Garvey
4731c1a835 prevent loading tokenizer on import (#1432)
also adds sentencepiece dep for exe
moved vicuna imports to after an if statement
in general we should avoid importing files that load whole models as
global variables
2023-05-12 19:11:45 -07:00
Ean Garvey
4c07e47e8c Specify a few models for expected failure on CUDA CI. (#1430) 2023-05-12 17:03:37 -05:00
Gaurav Shukla
e0cc2871bb [SD] Yield 2 tokens at a time in vicuna
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 23:49:01 +05:30
Gaurav Shukla
649f39408b [SD] Fix vicuna response
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 18:06:21 +05:30
Gaurav Shukla
c142297d73 [SD] Fix gradio to 3.22.0 version
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com
2023-05-11 18:05:55 +05:30
Gaurav Shukla
9e07360b00 [SD] Standalone vicuna with web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 17:23:44 +05:30
Gaurav Shukla
7b74c86e42 [SD] Fix SAMPLE_INPUT_LEN import issue
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 15:41:43 +05:30
Eliasj42
fa833f8366 fixed spacing issue with chat-bot (#1417)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-10 16:07:50 -07:00
Gaurav Shukla
fcb059aa38 [SD] Integrate vicuna in the web (#1410) 2023-05-10 11:30:22 -07:00
PhaneeshB
517c670f82 vicuna chat cli 2023-05-10 22:55:06 +05:30
Eliasj42
59df14f18b added vicuna demo (#1408)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-09 21:18:20 -07:00
Ean Garvey
6c95ac0f37 Revert dialect registration in model annotator (#1406)
Matches https://github.com/nod-ai/SHARK-Runtime/pull/58
2023-05-09 11:50:19 -07:00
Daniel Garvey
7a4a51ae73 vulkan vic f16 (#1404)
Co-authored-by: dan <dan@nod-labs.com>
2023-05-08 16:46:53 -07:00
powderluv
d816cc015e Revert "added standalone vicuna script (#1399)" (#1402)
This reverts commit 0e4a8ca240.
2023-05-05 16:08:05 -07:00
Eliasj42
54ce3d48ca added standalone vicuna script (#1401)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-05 18:05:52 -05:00
Eliasj42
0e4a8ca240 added standalone vicuna script (#1399)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-05 15:46:05 -07:00
Daniel Garvey
6ca1298675 maximizes window size for webview launch (#1394) 2023-05-04 20:43:06 -07:00
jinchen62
bbef7a6464 Redesign model manager webui (#1391) 2023-05-04 20:41:29 -07:00
Ean Garvey
cdf2d61d53 Remove imports from iree.compiler.transforms from model annotator. (#1392) 2023-05-04 20:40:19 -07:00
Ean Garvey
6c14847d1f xfail some large tests on macOS builder and switch to hash updates. (#1341)
* Update test-models.yml

* Disable large tests on macOS builder
2023-05-04 19:47:03 -05:00
Gaurav Shukla
68ecdd2a73 [SD] Add LoRA as experimental tab
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-04 22:30:25 +05:30
Gaurav Shukla
3f4d444d18 [SD] Fix stable LM chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-04 22:30:25 +05:30
m68k-fr
e473d0375b [Web] Models folders cleanup (#1365) 2023-05-03 16:13:20 -05:00
Ean Garvey
e38d96850f Fix input image loading in img2img rest API (#1388) 2023-05-03 15:51:00 -05:00
Gaurav Shukla
fed63dfd4b [SD] Add stableLM chatbot (#1383)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-05-03 15:37:20 -05:00
Boian Petkantchin
eba4d06405 In MiniLM JAX example do not hardcode device (#1385)
* In MiniLM JAX example do not hardcode device

* In MiniLM JAX example don't use bytecode MLIR

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-05-03 10:34:42 -07:00
Boian Petkantchin
4cfba153d2 Add example JAX MiniLM inference (#1380)
* Do not hardcode the name of the VM module in get_iree_module

* Add example JAX MiniLM inference

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-05-02 15:03:54 -07:00
jinchen62
307c05f38d Convert original vae to diffusers (#1382) 2023-05-02 01:27:28 -07:00
jinchen62
696df349cb Fix curl issue (#1369) 2023-04-28 09:31:14 -07:00
jinchen62
cb54cb1348 Add model manager tab for SD webui (#1368) 2023-04-28 02:43:40 -07:00
Daniel Garvey
9bdb86637d add tkinter launch for webui (#1364) 2023-04-27 19:17:55 -05:00
jinchen62
fb6f26517f Fix webui note (#1367) 2023-04-27 16:14:43 -07:00
Chi_Liu
aa8ada9da9 Add support for torch to stablehlo and tosa in shark_importer (#1360) 2023-04-27 08:09:45 -07:00
powderluv
1db906a373 Revert "Add model manager tab for webui (#1359)" (#1362)
This reverts commit 9d1d1617d8.
2023-04-26 22:25:26 -07:00
jinchen62
9d1d1617d8 Add model manager tab for webui (#1359) 2023-04-26 13:38:18 -07:00
jinchen62
7112789cb8 Add support of using civitai model download url (#1357) 2023-04-25 23:39:52 -07:00
jinchen62
d6b8be2849 Add drawing canvas for img2img stencil scribble (#1355) 2023-04-25 14:41:01 -07:00
powderluv
822171277c Revert "[SD] Add FastChat as part of SD WebUI (#1349)" (#1350)
This reverts commit a5ae9d9f02.
2023-04-24 15:22:25 -07:00
Abhishek Varma
a5ae9d9f02 [SD] Add FastChat as part of SD WebUI (#1349)
-- This commit includes FastChat as part of SD WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-24 11:12:58 -07:00
powderluv
09e3f63d5b Fix pascal (#1346)
* Add fp32 for upscaler VAE

* Plumb Pascal vulkan support
2023-04-23 20:28:25 -07:00
powderluv
d60a5a9396 Add fp32 for upscaler VAE (#1345) 2023-04-23 15:27:55 -07:00
m68k-fr
90df0ee365 [Web] Gallery set to a 768px reference for high-end desktop users (#1344) 2023-04-23 11:48:06 -07:00
nirvedhmeshram
133c1bcadd add device to scheduler model names (#1338) 2023-04-22 20:13:56 -05:00
116 changed files with 8641 additions and 1220 deletions

View File

@@ -2,4 +2,4 @@
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py

View File

@@ -50,27 +50,13 @@ jobs:
shell: powershell
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
python process_skipfiles.py
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
# GHA windows VM OOMs so disable for now
#- name: Build and validate the SHARK Runtime package
# shell: powershell
# run: |
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
#- uses: actions/upload-artifact@v2
# with:
# path: dist/*
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1
@@ -78,7 +64,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./dist/*
assets_path: ./dist/nodai*
#asset_content_type: application/vnd.microsoft.portable-executable
- name: Publish Release

View File

@@ -35,6 +35,8 @@ jobs:
include:
- os: ubuntu-latest
suite: lint
- os: MacStudio
suite: metal
exclude:
- os: ubuntu-latest
suite: vulkan
@@ -46,6 +48,8 @@ jobs:
suite: cuda
- os: MacStudio
suite: cpu
- os: MacStudio
suite: vulkan
- os: icelake
suite: vulkan
- os: icelake
@@ -61,7 +65,6 @@ jobs:
steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'
- name: Set Environment Variables
if: matrix.os != '7950x'
@@ -84,9 +87,6 @@ jobs:
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- uses: actions/checkout@v2
if: matrix.os == '7950x'
- name: Install dependencies
if: matrix.suite == 'lint'
@@ -129,15 +129,14 @@ jobs:
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan --update_tank
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'

4
.gitignore vendored
View File

@@ -2,6 +2,8 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb
# C extensions
*.so
@@ -157,7 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# vscode related
.vscode

View File

@@ -1,25 +1,14 @@
import torch
import shark
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
import torch_mlir
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
import torch_mlir
from apps.stable_diffusion.src.utils import (
base_models,
get_opt_flags,
get_vmfb_path_name,
)
from apps.stable_diffusion.src.models.model_wrappers import replace_shape_str
import os
from io import BytesIO
tokenizer = AutoTokenizer.from_pretrained(
"stabilityai/stablelm-tuned-alpha-7b"
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
@@ -34,6 +23,97 @@ class StopOnTokens(StoppingCriteria):
return False
def shouldStop(tokens):
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if tokens[0][-1] == stop_id:
return True
return False
MAX_SEQUENCE_LENGTH = 256
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
def compile_stableLM(
model,
model_inputs,
model_name,
model_vmfb_name,
device="cuda",
precision="fp32",
):
from shark.shark_inference import SharkInference
# device = "cuda" # "cpu"
# TODO: vmfb and mlir name should include precision and device
vmfb_path = (
Path(model_name + f"_{device}.vmfb")
if model_vmfb_name is None
else Path(model_vmfb_name)
)
shark_module = get_vmfb_from_path(
vmfb_path, device, mlir_dialect="tm_tensor"
)
if shark_module is not None:
return shark_module
mlir_path = Path(model_name + ".mlir")
print(
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
)
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
)
print("Saved vmfb at ", str(path))
return shark_module
class StableLMModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
combine_input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
output = self.model(**combine_input_dict)
return output.logits
# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
@@ -41,167 +121,90 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM will refuse to participate in anything that could harm a human.
"""
prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
inputs = tokenizer(prompt, return_tensors="pt")
def get_tokenizer():
model_path = "stabilityai/stablelm-tuned-alpha-3b"
tok = AutoTokenizer.from_pretrained(model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
print("Sucessfully loaded the tokenizer to the memory")
return tok
class SLM(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = AutoModelForCausalLM.from_pretrained(
"stabilityai/stablelm-tuned-alpha-7b"
# sharkStableLM = compile_stableLM
# (
# None,
# tuple([input_ids, attention_mask]),
# "stableLM_linalg_f32_seqLen256",
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
# )
def generate(
new_text,
max_new_tokens,
sharkStableLM,
tokenizer=None,
):
if tokenizer is None:
tokenizer = get_tokenizer()
# Construct the input message string for the model by
# concatenating the current system message and conversation history
# Tokenize the messages string
# sharkStableLM = compile_stableLM
# (
# None,
# tuple([input_ids, attention_mask]),
# "stableLM_linalg_f32_seqLen256",
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
# )
words_list = []
for i in range(max_new_tokens):
# numWords = len(new_text.split())
# if(numWords>220):
# break
params = {
"new_text": new_text,
}
generated_token_op = generate_new_token(
sharkStableLM, tokenizer, params
)
def forward(self, input_ids, attention_mask):
return self.model(input_ids, attention_mask)[0]
detok = generated_token_op["detok"]
stop_generation = generated_token_op["stop_generation"]
if stop_generation:
break
print(detok, end="", flush=True)
words_list.append(detok)
if detok == "":
break
new_text = new_text + detok
return words_list
slm_model = SLM()
res_pytorch = slm_model(inputs["input_ids"], inputs["attention_mask"])
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
fx_g = make_fx(
slm_model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(inputs["input_ids"], inputs["attention_mask"])
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
transform_fx(fx_g)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
module = torch_mlir.compile(
ts_g,
[inputs["input_ids"], inputs["attention_mask"]],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.compile()
result_shark = shark_module(
"forward", [inputs["input_ids"], inputs["attention_mask"]]
)
print("Result PyTorch")
print(res_pytorch)
print("Result SHARK")
print(result_shark)
def generate_new_token(shark_model, tokenizer, params):
new_text = params["new_text"]
model_inputs = tokenizer(
[new_text],
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
return_tensors="pt",
)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
output = shark_model(
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
)
output = torch.from_numpy(output)
next_toks = torch.topk(output, 1)
stop_generation = False
if shouldStop(next_toks.indices):
stop_generation = True
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
detok = tokenizer.decode(
new_token,
skip_special_tokens=True,
)
ret_dict = {
"new_token": new_token,
"detok": detok,
"stop_generation": stop_generation,
}
return ret_dict

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,22 @@
import torch
class FalconModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": None,
"use_cache": True,
}
output = self.model(
**input_dict,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)[0]
return output[:, -1, :]

View File

@@ -0,0 +1,15 @@
import torch
class StableLMModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
combine_input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
output = self.model(**combine_input_dict)
return output.logits

View File

@@ -0,0 +1,308 @@
import torch
from transformers import AutoModelForCausalLM
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class FirstVicuna(torch.nn.Module):
def __init__(self, model_path, precision="fp32", weight_group_size=128):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("First Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048,
)
print("Weight quantization applied.")
def forward(self, input_ids):
op = self.model(input_ids=input_ids, use_cache=True)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class SecondVicuna(torch.nn.Module):
def __init__(self, model_path, precision="fp32", weight_group_size=128):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048,
)
print("Weight quantization applied.")
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
):
# input_ids = input_tuple[0]
# input_tuple = torch.unbind(pkv, dim=0)
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class CombinedModel(torch.nn.Module):
def __init__(
self,
first_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
second_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
):
super().__init__()
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
def forward(self, input_ids):
first_output = self.first_vicuna(input_ids=input_ids, use_cache=True)
logits = first_output[0]
pkv = first_output[1:]
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
second_output = self.second_vicuna(secondVicunaInput)
return second_output

View File

@@ -0,0 +1,228 @@
import torch
class FirstVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states, attention_mask, position_ids):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class SecondVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
past_key_value0,
past_key_value1,
),
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
self.norm = norm
self.embedding = embedding
self.lmhead = lmhead
self.model.model.norm = self.norm
self.model.model.embed_tokens = self.embedding
self.model.lm_head = self.lmhead
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers
)
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
class LMHead(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states):
output = self.model(hidden_states)
return output
class LMHeadCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, hidden_states):
hidden_states = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
class VicunaNorm(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states):
output = self.model(hidden_states)
return output
class VicunaNormCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, hidden_states):
hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
class VicunaEmbedding(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids):
output = self.model(input_ids)
return output
class VicunaEmbeddingCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = torch.tensor(output)
return output
class CompiledVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
if past_key_value is None:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"first_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
else:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"second_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)

View File

@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod
class SharkLLMBase(ABC):
def __init__(
self, model_name, hf_model_path=None, max_num_tokens=512
) -> None:
self.model_name = model_name
self.hf_model_path = hf_model_path
self.max_num_tokens = max_num_tokens
self.shark_model = None
self.device = "cpu"
self.precision = "fp32"
@classmethod
@abstractmethod
def compile(self):
pass
@classmethod
@abstractmethod
def generate(self, prompt):
pass
@classmethod
@abstractmethod
def generate_new_token(self, params):
pass
@classmethod
@abstractmethod
def get_tokenizer(self):
pass
@classmethod
@abstractmethod
def get_src_model(self):
pass
def load_init_from_config(self):
pass

View File

@@ -0,0 +1,512 @@
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
)
from io import BytesIO
from pathlib import Path
from contextlib import redirect_stdout
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
StoppingCriteriaList,
)
import copy
import re
import torch
import torch_mlir
import os
import argparse
parser = argparse.ArgumentParser(
prog="falcon runner",
description="runs a falcon model",
)
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
"--falcon_vmfb_path", default=None, help="path to falcon's vmfb"
)
parser.add_argument(
"--falcon_mlir_path",
default=None,
help="path to falcon's mlir file",
)
parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
parser.add_argument(
"--cli",
default=True,
action=argparse.BooleanOptionalAction,
help="Run model in cli mode",
)
class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.src_model = self.get_src_model()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, trust_remote_code=True
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
return tokenizer
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return falcon_model
def compile_falcon(self):
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ "_"
+ self.device
+ ".vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path, self.device, "linalg"
)
if vmfb is not None:
return vmfb
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
if not mlir_generated:
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 100)
)
compilation_attention_mask = torch.ones(
1, 100, dtype=torch.int64
)
falconCompileInput = (
compilation_input_ids,
compilation_attention_mask,
)
model = FalconModel(self.src_model)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
falconCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del model
print(f"[DEBUG] generating torch mlir")
module = torch_mlir.compile(
ts_graph,
[*falconCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
print(f"[DEBUG] writing mlir to file")
with open(f"{self.model_name}.mlir", "wb") as f_:
with redirect_stdout(f_):
print(module.operation.get_asm())
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile(self):
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.max_padding_length,
add_special_tokens=False,
return_tensors="pt",
)
model_inputs["prompt_text"] = prompt
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs.get("attention_mask", None)
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
attention_mask = None
in_b = 1
else:
in_b = input_ids.shape[0]
generate_kwargs = {
"max_length": self.max_num_tokens,
"do_sample": True,
"top_k": 10,
"num_return_sequences": 1,
"eos_token_id": 11,
}
generate_kwargs["input_ids"] = input_ids
generate_kwargs["attention_mask"] = attention_mask
generation_config_ = GenerationConfig.from_model_config(
self.src_model.config
)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = StoppingCriteriaList()
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = self.src_model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
self.logits_processor = self.src_model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids.shape[-1],
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.src_model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
self.logits_warper = self.src_model._get_logits_warper(
generation_config
)
(
self.input_ids,
self.model_kwargs,
) = self.src_model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id) if eos_token_id is not None else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
input_ids.shape[0], dtype=torch.long, device=input_ids.device
)
all_text = prompt
for i in range(self.max_num_tokens - 1):
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
all_text = all_text + new_word
print(f"{new_word}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(input_ids, self.scores)
):
break
torch.cuda.empty_cache()
gc.collect()
return all_text
def generate_new_token(self):
model_inputs = self.src_model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = torch.from_numpy(
self.shark_model(
"forward",
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
if __name__ == "__main__":
args = parser.parse_args()
falcon_mlir_path = (
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ ".mlir"
)
if args.falcon_mlir_path is None
else Path(args.falcon_mlir_path)
)
falcon_vmfb_path = (
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ "_"
+ args.device
+ ".vmfb"
)
if args.falcon_vmfb_path is None
else Path(args.falcon_vmfb_path)
)
falcon = Falcon(
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
falcon_vmfb_path=falcon_vmfb_path,
)
import gc
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
continue_execution = True
print("\n-----\nScript executing for the following config: \n")
print("Falcon Model: ", falcon.model_name)
print("Precision: ", args.precision)
print("Device: ", args.device)
while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? Y/N ?: "
)
if use_default_prompt in ["Y", "y"]:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)
res_str = falcon.generate(prompt)
torch.cuda.empty_cache()
gc.collect()
print(
"\n\n-----\nHere's the complete formatted result: \n\n",
res_str,
)
continue_execution = input(
"\nDo you wish to run script one more time? Y/N ?: "
)
continue_execution = (
True if continue_execution in ["Y", "y"] else False
)

View File

@@ -0,0 +1,185 @@
import torch
import torch_mlir
from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.stablelm_model import (
StableLMModel,
)
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
class SharkStableLM(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
max_num_tokens=512,
device="cuda",
precision="fp32",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
self.precision = precision
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
def shouldStop(self, tokens):
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if tokens[0][-1] == stop_id:
return True
return False
def get_src_model(self):
model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, torch_dtype=torch.float32
)
return model
def get_model_inputs(self):
input_ids = torch.randint(3, (1, self.max_sequence_len))
attention_mask = torch.randint(3, (1, self.max_sequence_len))
return input_ids, attention_mask
def compile(self):
tmp_model_name = (
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
)
# device = "cuda" # "cpu"
# TODO: vmfb and mlir name should include precision and device
model_vmfb_name = None
vmfb_path = (
Path(tmp_model_name + f"_{self.device}.vmfb")
if model_vmfb_name is None
else Path(model_vmfb_name)
)
shark_module = get_vmfb_from_path(
vmfb_path, self.device, mlir_dialect="tm_tensor"
)
if shark_module is not None:
return shark_module
mlir_path = Path(tmp_model_name + ".mlir")
print(
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
model = StableLMModel(self.get_src_model())
model_inputs = self.get_model_inputs()
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(tmp_model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
)
print("Saved vmfb at ", str(path))
return shark_module
def get_tokenizer(self):
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
return tok
def generate(self, prompt):
words_list = []
for i in range(self.max_num_tokens):
params = {
"new_text": prompt,
}
generated_token_op = self.generate_new_token(params)
detok = generated_token_op["detok"]
stop_generation = generated_token_op["stop_generation"]
if stop_generation:
break
print(detok, end="", flush=True) # this is for CLI and DEBUG
words_list.append(detok)
if detok == "":
break
prompt = prompt + detok
return words_list
def generate_new_token(self, params):
new_text = params["new_text"]
model_inputs = self.tokenizer(
[new_text],
padding="max_length",
max_length=self.max_sequence_len,
truncation=True,
return_tensors="pt",
)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
output = self.shark_model(
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
)
output = torch.from_numpy(output)
next_toks = torch.topk(output, 1)
stop_generation = False
if self.shouldStop(next_toks.indices):
stop_generation = True
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
detok = self.tokenizer.decode(
new_token,
skip_special_tokens=True,
)
ret_dict = {
"new_token": new_token,
"detok": detok,
"stop_generation": stop_generation,
}
return ret_dict
# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""

View File

@@ -0,0 +1,25 @@
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from pathlib import Path
# expects a Path / str as arg
# returns None if path not found or SharkInference module
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
if not isinstance(vmfb_path, Path):
vmfb_path = Path(vmfb_path)
from shark.shark_inference import SharkInference
if not vmfb_path.exists():
return None
print("Loading vmfb from: ", vmfb_path)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
return shark_module

View File

@@ -10,7 +10,7 @@ Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=mhlo for tf models
# use iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb

View File

@@ -103,6 +103,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
total_time = time.time() - start_time

View File

@@ -81,6 +81,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -79,6 +79,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -223,7 +223,8 @@ def lora_train(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, both must not be "
"empty.",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:

View File

@@ -17,6 +17,10 @@ from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
if "upscaler" in args.hf_model_id:
is_upscaler = True
else:
is_upscaler = False
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
@@ -27,6 +31,7 @@ def load_mlir_module():
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
is_upscaler=is_upscaler,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,

View File

@@ -61,6 +61,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -21,7 +21,7 @@ if __name__ == "__main__":
print("Flag --img_path is required.")
exit()
# When the models get uploaded, it should be default to False.
# When the models get uploaded, it should be defaulted to False.
args.import_mlir = True
cpu_scheduling = not args.scheduler.startswith("Shark")
@@ -73,6 +73,7 @@ if __name__ == "__main__":
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -14,21 +14,29 @@ datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('Pillow')
datas += copy_metadata('sentencepiece')
datas += collect_data_files('tokenizers')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('opencv_python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('google_cloud_storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += collect_data_files('sentencepiece')
datas += collect_data_files('jsonschema')
datas += collect_data_files('jsonschema_specifications')
datas += collect_data_files('cpuinfo')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
@@ -44,6 +52,8 @@ block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("transformers") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['web/index.py'],
@@ -69,11 +79,11 @@ exe = EXE(
a.zipfiles,
a.datas,
[],
name='shark_sd',
name='nodai_shark_studio',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,

View File

@@ -29,6 +29,7 @@ datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('py-cpuinfo')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
@@ -42,6 +43,7 @@ block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['scripts/main.py'],

View File

@@ -1,9 +1,11 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from collections import defaultdict
from pathlib import Path
import torch
import safetensors.torch
import traceback
import subprocess
import sys
import os
from apps.stable_diffusion.src.utils import (
@@ -12,6 +14,7 @@ from apps.stable_diffusion.src.utils import (
base_models,
args,
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
@@ -42,6 +45,7 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(width * mul_val)
elif "/" in shape[i]:
import math
div_val = int(shape[i].split("/")[1])
if "batch_size" in shape[i]:
new_shape.append(math.ceil(batch_size / div_val))
@@ -56,7 +60,9 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
def check_compilation(model, model_name):
if not model:
raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
raise Exception(
f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
)
class SharkifyStableDiffusionModel:
@@ -91,10 +97,25 @@ class SharkifyStableDiffusionModel:
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = (
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
)
if not os.path.isfile(weights_path):
subprocess.run(
["wget", custom_weights, "-O", weights_path]
)
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(
custom_weights
)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
@@ -114,7 +135,7 @@ class SharkifyStableDiffusionModel:
+ "_"
+ precision
)
print(f'use_tuned? sharkify: {use_tuned}')
print(f"use_tuned? sharkify: {use_tuned}")
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
@@ -151,14 +172,24 @@ class SharkifyStableDiffusionModel:
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
sub_model_list = [
"clip",
"unet",
"unet512",
"stencil_unet",
"vae",
"vae_encode",
"stencil_adaptor",
]
index = 0
for model in sub_model_list:
sub_model = model
model_config = self.model_name
if "vae" == model:
if self.custom_vae != "":
model_config = model_config + get_path_stem(self.custom_vae)
model_config = model_config + get_path_stem(
self.custom_vae
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
@@ -185,7 +216,11 @@ class SharkifyStableDiffusionModel:
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, self.max_len, self.width, self.height, self.batch_size
shape,
self.max_len,
self.width,
self.height,
self.batch_size,
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
@@ -197,10 +232,12 @@ class SharkifyStableDiffusionModel:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
@@ -214,7 +251,11 @@ class SharkifyStableDiffusionModel:
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
is_f16 = (
True
if not self.is_upscaler and self.precision == "fp16"
else False
)
shark_vae_encode, vae_encode_mlir = compile_through_fx(
vae_encode,
inputs,
@@ -231,7 +272,13 @@ class SharkifyStableDiffusionModel:
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
@@ -267,7 +314,11 @@ class SharkifyStableDiffusionModel:
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
is_f16 = (
True
if not self.is_upscaler and self.precision == "fp16"
else False
)
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
@@ -291,7 +342,10 @@ class SharkifyStableDiffusionModel:
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -304,12 +358,43 @@ class SharkifyStableDiffusionModel:
self.in_channels = self.unet.in_channels
self.train(False)
def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
control2, control3, control4, control5, control6, control7,
control8, control9, control10, control11, control12, control13,
def forward(
self,
latent,
timestep,
text_embedding,
guidance_scale,
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
control13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
db_res_samples = tuple(
[
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
)
mb_res_samples = control13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
@@ -330,7 +415,25 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
input_mask = [
True,
True,
True,
False,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
@@ -374,16 +477,23 @@ class SharkifyStableDiffusionModel:
stencil_image = torch.cat(
[stencil_image_input] * 2
) # needs to be same as controlledUNET latents
down_block_res_samples, mid_block_res_sample = self.cnet.forward(
(
down_block_res_samples,
mid_block_res_sample,
) = self.cnet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
controlnet_cond=stencil_image,
return_dict=False,
)
return tuple(list(down_block_res_samples) + [mid_block_res_sample])
return tuple(
list(down_block_res_samples) + [mid_block_res_sample]
)
scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
scnet = StencilControlNetModel(
low_cpu_mem_usage=self.low_cpu_mem_usage
)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
@@ -403,9 +513,14 @@ class SharkifyStableDiffusionModel:
)
return shark_cnet, cnet_mlir
def get_unet(self):
def get_unet(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
@@ -414,17 +529,26 @@ class SharkifyStableDiffusionModel:
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.in_channels = self.unet.config.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
if(args.attention_slicing.isdigit()):
self.unet.set_attention_slice(int(args.attention_slicing))
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
# TODO: Instead of flattening the `control` try to use the list.
def forward(
self, latent, timestep, text_embedding, guidance_scale,
self,
latent,
timestep,
text_embedding,
guidance_scale,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
@@ -440,17 +564,33 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["unet512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["unet"]
)
input_mask = [True, True, True, False]
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -459,15 +599,17 @@ class SharkifyStableDiffusionModel:
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_unet_upscaler(self):
def get_unet_upscaler(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
@@ -490,17 +632,27 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
input_mask = [True, True, True, False]
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
@@ -508,7 +660,12 @@ class SharkifyStableDiffusionModel:
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
@@ -516,7 +673,9 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
update_lora_weight(
self.text_encoder, use_lora, "text_encoder"
)
def forward(self, input):
return self.text_encoder(input)[0]
@@ -555,30 +714,47 @@ class SharkifyStableDiffusionModel:
vae_checkpoint = None
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
if custom_vae.endswith(".ckpt"):
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
vae_checkpoint = torch.load(
self.custom_vae, map_location="cpu"
)
else:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
vae_checkpoint = safetensors.torch.load_file(
self.custom_vae, device="cpu"
)
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
def compile_unet_variants(self, model):
try:
vae_checkpoint = convert_original_vae(vae_checkpoint)
finally:
vae_dict = {
k: v
for k, v in vae_checkpoint.items()
if k[0:4] != "loss" and k not in vae_ignore_keys
}
return vae_dict
def compile_unet_variants(self, model, use_large=False):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
return self.get_unet_upscaler(use_large=use_large)
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
from apps.stable_diffusion.src.models.opt_params import (
get_unet,
)
return get_unet()
else:
return self.get_unet()
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet()
def vae_encode(self):
try:
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
self.inputs["vae_encode"] = self.get_input_info_for(
base_models["vae_encode"]
)
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
@@ -600,25 +776,35 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def unet(self):
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model)
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[model_id]
)
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
print(
"Retrying with a different base model configuration"
)
continue
# -- Once a successful compilation has taken place we'd want to store
@@ -641,7 +827,11 @@ class SharkifyStableDiffusionModel:
def vae(self):
try:
vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
vae_input = (
base_models["vae"]["vae_upscaler"]
if self.is_upscaler
else base_models["vae"]["vae"]
)
self.inputs["vae"] = self.get_input_info_for(vae_input)
is_base_vae = self.base_vae
@@ -659,7 +849,9 @@ class SharkifyStableDiffusionModel:
def controlnet(self):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
check_compilation(compiled_stencil_adaptor, "Stencil")

View File

@@ -17,9 +17,13 @@ hf_model_variant_map = {
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
"stabilityai/stable-diffusion-2-inpainting": [
"stablediffusion",
"inpaint_v2",
],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
@@ -27,9 +31,12 @@ def get_quantize_model():
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 and args.width != 512 and args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
sys.exit(
"The int8 quantized model currently requires the height and width to be 512, and max_length to be 77"
)
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]

View File

@@ -15,6 +15,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +43,11 @@ class Image2ImagePipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -135,6 +145,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
):
# prompts and negative prompts must be a list.
@@ -156,7 +167,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -14,6 +14,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -37,6 +42,11 @@ class InpaintPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -378,6 +388,7 @@ class InpaintPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -408,7 +419,10 @@ class InpaintPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -14,6 +14,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +43,11 @@ class OutpaintPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -379,6 +389,7 @@ class OutpaintPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -409,7 +420,10 @@ class OutpaintPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -14,6 +14,12 @@ from diffusers import (
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +44,12 @@ class StencilPipeline(StableDiffusionPipeline):
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -204,6 +216,7 @@ class StencilPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
):
# Control Embedding check & conversion
@@ -230,7 +243,10 @@ class StencilPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -13,6 +13,10 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -34,6 +38,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -81,6 +89,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -112,7 +121,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -17,9 +17,14 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_IDLE,
SD_STATE_CANCEL,
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import (
@@ -65,6 +70,11 @@ class UpscalerPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
low_res_scheduler: Union[
DDIMScheduler,
@@ -76,6 +86,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -84,6 +98,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.low_res_scheduler = low_res_scheduler
self.status = SD_STATE_IDLE
def prepare_extra_step_kwargs(self, generator, eta):
accepts_eta = "eta" in set(
@@ -164,7 +179,11 @@ class UpscalerPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
self.status = SD_STATE_IDLE
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
@@ -178,15 +197,26 @@ class UpscalerPipeline(StableDiffusionPipeline):
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -210,8 +240,12 @@ class UpscalerPipeline(StableDiffusionPipeline):
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -236,6 +270,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -257,7 +292,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# 4. Preprocess image

View File

@@ -15,6 +15,9 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
@@ -48,6 +51,10 @@ class StableDiffusionPipeline:
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -57,6 +64,7 @@ class StableDiffusionPipeline:
self.vae = None
self.text_encoder = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
@@ -66,7 +74,8 @@ class StableDiffusionPipeline:
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
self.tokenizer = get_tokenizer()
except:
@@ -81,13 +90,15 @@ class StableDiffusionPipeline:
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
try:
self.text_encoder = get_clip()
except:
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.text_encoder = self.sd_model.clip()
@@ -104,7 +115,8 @@ class StableDiffusionPipeline:
else:
try:
self.unet = get_unet()
except:
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.unet = self.sd_model.unet()
@@ -112,6 +124,24 @@ class StableDiffusionPipeline:
del self.unet
self.unet = None
def load_unet_512(self):
if self.unet_512 is not None:
return
if self.import_mlir or self.use_lora:
self.unet_512 = self.sd_model.unet(use_large=True)
else:
try:
self.unet_512 = get_unet(use_large=True)
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.unet_512 = self.sd_model.unet(use_large=True)
def unload_unet_512(self):
del self.unet_512
self.unet_512 = None
def load_vae(self):
if self.vae is not None:
return
@@ -121,7 +151,8 @@ class StableDiffusionPipeline:
else:
try:
self.vae = get_vae()
except:
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.vae = self.sd_model.vae()
@@ -200,7 +231,10 @@ class StableDiffusionPipeline:
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
@@ -219,16 +253,28 @@ class StableDiffusionPipeline:
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
@@ -251,6 +297,7 @@ class StableDiffusionPipeline:
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -272,6 +319,10 @@ class StableDiffusionPipeline:
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
import_mlir: bool,
model_id: str,
@@ -356,16 +407,21 @@ class StableDiffusionPipeline:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation.
Ignored when not using guidance
(i.e., ignored if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
SHARK: pass the max length instead of relying on
pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
SHARK: must be set to True as we always expect neg embeddings
(defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
The max multiple length of prompt embeddings compared to the
max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error
(defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
@@ -384,9 +440,11 @@ class StableDiffusionPipeline:
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
f"`negative_prompt`: "
f"{negative_prompt} has batch size {len(negative_prompt)}, "
f"but `prompt`: {prompt} has batch size {batch_size}. "
f"Please make sure that passed `negative_prompt` matches "
"the batch size of `prompt`."
)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
@@ -399,16 +457,43 @@ class StableDiffusionPipeline:
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# text_embeddings = text_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# text_embeddings = (
# text_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# uncond_embeddings = (
# uncond_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# )
# uncond_embeddings = (
# uncond_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if text_embeddings.shape[1] > model_max_length:
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (0, 512 - text_embeddings.shape[1])
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
@@ -443,7 +528,8 @@ re_attention = re.compile(
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Parses a string with attention tokens and returns a list of pairs:
text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12

View File

@@ -8,6 +8,9 @@ from diffusers import (
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
@@ -38,9 +41,28 @@ def get_schedulers(model_id):
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
)
schedulers[
"DPMSolverMultistep++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
)
schedulers[
"DPMSolverMultistepKarras"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
use_karras_sigmas=True,
)
schedulers[
"DPMSolverMultistepKarras++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
@@ -62,5 +84,21 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverSinglestep"
] = DPMSolverSinglestepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"KDPM2AncestralDiscrete"
] = KDPM2AncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers

View File

@@ -40,6 +40,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
device = args.device.split(":", 1)[0].strip()
model_input = {
"euler": {
@@ -92,7 +93,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
@@ -101,7 +102,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.step_model, _ = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)

View File

@@ -24,14 +24,18 @@ from apps.stable_diffusion.src.utils.utils import (
get_available_devices,
get_opt_flags,
preprocessCKPT,
convert_original_vae,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
get_path_stem,
get_extended_name,
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
clear_all,
save_output_img,
get_generation_text_info,
update_lora_weight,
resize_stencil,
_compile_module,
)

View File

@@ -3,7 +3,7 @@
"stablediffusion/untuned":"gs://shark_tank/nightly"
},
{
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",

View File

@@ -5,4 +5,7 @@
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"],
["A photo of a beach, sunset, calm, beautiful landscape, waves, water"],
["(a large body of water with snowy mountains in the background), (fog, foggy, rolling fog), (clouds, cloudy, rolling clouds), dramatic sky and landscape, extraordinary landscape, (beautiful snow capped mountain background), (forest, dirt path)"],
["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smokes coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"]]

View File

@@ -116,7 +116,7 @@ def load_lower_configs(base_model_id=None):
else:
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
if not spec or spec in ["sm_80"]:
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
@@ -125,8 +125,38 @@ def load_lower_configs(base_model_id=None):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
elif spec in ["rdna3"] and version in [
"v2_1",
"v2_1base",
"v1_4",
"v1_5",
]:
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.max_length}_"
f"{args.precision}_"
f"{device}_"
f"{spec}_"
f"{args.width}x{args.height}.json"
)
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.precision}_"
f"{device}_"
f"{spec}_"
f"{args.width}x{args.height}.json"
)
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.precision}_"
f"{device}_"
f"{spec}.json"
)
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
@@ -171,9 +201,22 @@ def dump_after_mlir(input_mlir, use_winograd):
device, device_spec_args = get_device_args()
if use_winograd:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))"
)
else:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
)
dump_module = ireec.compile_str(
input_mlir,

View File

@@ -19,48 +19,56 @@ p = argparse.ArgumentParser(
)
##############################################################################
### Stable Diffusion Params
# Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
)
p.add_argument(
"-p",
"--prompts",
nargs="+",
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smokes coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
],
help="Text of which images to be generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=["trees, green"],
help="text you don't want to see in the generated image.",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
"blurry, ugly, blur, oversaturated, cropped"
],
help="Text you don't want to see in the generated image.",
)
p.add_argument(
"--img_path",
type=str,
help="Path to the image input for img2img/inpainting",
help="Path to the image input for img2img/inpainting.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
help="The number of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
default=-1,
help="the seed to use. -1 for a random one.",
help="The seed to use. -1 for a random one.",
)
p.add_argument(
@@ -68,7 +76,7 @@ p.add_argument(
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `batch_count`.",
help="The number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
@@ -76,7 +84,7 @@ p.add_argument(
type=int,
default=512,
choices=range(128, 769, 8),
help="the height of the output image.",
help="The height of the output image.",
)
p.add_argument(
@@ -84,77 +92,86 @@ p.add_argument(
type=int,
default=512,
choices=range(128, 769, 8),
help="the width of the output image.",
help="The width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
help="The value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="the value to be used for noise level of upscaler.",
help="The value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
help="Max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max "
"output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="the strength of change applied on the given input image for img2img",
help="The strength of change applied on the given input image for "
"img2img.",
)
##############################################################################
### Stable Diffusion Training Params
# Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model",
help="Directory to save the lora fine tuned model.",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt",
help="Directory containing images that are an example of the prompt.",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The no. of steps to train",
help="The number of steps to train.",
)
##############################################################################
### Inpainting and Outpainting Params
# Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting",
help="Path to the mask image input for inpainting.",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture",
help="If inpaint only masked area or whole picture.",
)
p.add_argument(
@@ -162,7 +179,7 @@ p.add_argument(
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding",
help="Number of pixels for only masked padding.",
)
p.add_argument(
@@ -170,7 +187,7 @@ p.add_argument(
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting",
help="Number of expended pixels for one direction for outpainting.",
)
p.add_argument(
@@ -178,89 +195,92 @@ p.add_argument(
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting",
help="Number of blur pixels for outpainting.",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend left for outpainting",
help="If expend left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend right for outpainting",
help="If expend right for outpainting.",
)
p.add_argument(
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend top for outpainting",
help="If expend top for outpainting.",
)
p.add_argument(
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend bottom for outpainting",
help="If expend bottom for outpainting.",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
help="Fall-off exponent for outpainting (lower=higher detail) "
"(min=0.0, max=4.0).",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0)",
help="Color variation for outpainting (min=0.0, max=1.0).",
)
##############################################################################
### Model Config and Usage Params
# Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
"--device", type=str, default="vulkan", help="Device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
"--precision", type=str, default="fp16", help="Precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
help="Attempts to load the model from a precompiled flat-buffer "
"and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
help="Saves the compiled flat-buffer to the local directory.",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
help="Download and use the tuned version of the model if available.",
)
p.add_argument(
@@ -274,28 +294,34 @@ p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
"HeunDiscrete].",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
help="Specify the format in which output image is save. "
"Supported options: jpg / png.",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
help="Directory path to save the output images and json.",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="number of batch to be generated with random seeds in single execution",
help="Number of batch to be generated with random seeds in "
"single execution.",
)
p.add_argument(
@@ -309,7 +335,8 @@ p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
"needs to be plugged in.",
)
p.add_argument(
@@ -323,14 +350,15 @@ p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption",
help="Use the accelerate package to reduce cpu memory consumption.",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
"or an integer).",
)
p.add_argument(
@@ -343,187 +371,233 @@ p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
help="Use standalone LoRA weight using a HF ID or a checkpoint "
"file (~3 MB).",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
help="Runs the quantized version of stable diffusion model. "
"This is currently in experimental phase. "
"Currently, only runs the stable-diffusion-2-1-base model in "
"int8 quantization.",
)
p.add_argument(
"--ondemand",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM",
help="Load and unload models for low VRAM.",
)
##############################################################################
### IREE - Vulkan supported flags
# IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan",
help="Specify target triple for vulkan.",
)
p.add_argument(
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal.",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
help="Profiles vulkan device and collects the .rdc info.",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
help="Flag for setting VMA preferredLargeHeapBlockSize for "
"vulkan device, default is 4G.",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
help="Flag for disabling vulkan validation layers when benchmarking.",
)
##############################################################################
### Misc. Debug and Optimization flags
# Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
help="Use the default scheduler precompiled into the model if available.",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
help="Specify where to save downloaded shark_tank artifacts. "
"If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
help="When enabled call amdllpc to get ISA dumps. "
"Use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
help="Dispatches to return benchmark data on. "
'Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
help="Directory where you want to store dispatch data "
'generated with "--dispatch_benchmarks".',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
help="Flag for inserting debug frames between iterations "
"for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
help="Flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
help="Flag setting warmup count for CLIP and VAE [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
help="Flag to clear all mlir and vmfb from common locations. "
"Recompiling will take several minutes.",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save a generation information json file with the image.",
help="Flag for whether or not to save a generation information "
"json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
help="Flag for whether or not to save generation information in "
"PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
help="If import_mlir is True, saves mlir via the debug option "
"in shark importer. Does nothing if import_mlir is false (the default).",
)
##############################################################################
### Web UI flags
# Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the progress bar animation during image generation",
help="Flag for removing the progress bar animation during "
"image generation.",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="One of: [api, app, web].",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
help="Flag for generating a public URL.",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
help="Flag for setting server port.",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for enabling rest API",
help="Flag for enabling rest API.",
)
p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the output gallery tab, and avoid exposing "
"images under --output_dir in the UI.",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether the output gallery tab in the UI should "
"follow symlinks when listing subdirectories under --output_dir.",
)
##############################################################################
### SD model auto-annotation flags
# SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file",
help="Directory to save the annotated mlir file.",
)
p.add_argument(
@@ -537,31 +611,31 @@ p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file",
help="Save annotated mlir file.",
)
##############################################################################
### SD model auto-tuner flags
# SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file",
help="Directory to save the tuned config file.",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning",
help="Number of iterations for tuning.",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all",
help="Op to be optimized, options are matmul, bmm, conv and all.",
)

View File

@@ -18,6 +18,7 @@ from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
@@ -25,7 +26,13 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
create_vae_diffusers_config,
convert_ldm_vae_checkpoint,
)
import requests
from io import BytesIO
from omegaconf import OmegaConf
from cpuinfo import get_cpu_info
def get_extended_name(model_name):
@@ -42,6 +49,7 @@ def get_vmfb_path_name(model_name):
def _load_vmfb(shark_module, vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
model = "unet" if "unet512" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module.load_module(vmfb_path, extra_args=extra_args)
@@ -73,12 +81,13 @@ def _compile_module(shark_module, model_name, extra_args=[]):
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
def get_shark_model(tank_url, model_name, extra_args=None):
if extra_args is None:
extra_args = []
from shark.parser import shark_args
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
from shark.shark_downloader import download_model
if "cuda" in args.device:
@@ -106,12 +115,15 @@ def compile_through_fx(
save_dir=tempfile.gettempdir(),
debug=False,
generate_vmfb=True,
extra_args=[],
extra_args=None,
base_model_id=None,
model_name=None,
precision=None,
return_mlir=False,
device=None,
):
if extra_args is None:
extra_args = []
if not return_mlir and model_name is not None:
vmfb_path = get_vmfb_path_name(extended_model_name)
if os.path.isfile(vmfb_path):
@@ -141,7 +153,10 @@ def compile_through_fx(
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
if "unet" in model_name.split("_")[0]:
if (
"unet" in model_name.split("_")[0]
or "unet_512" in model_name.split("_")[0]
):
args.annotation_model = "unet"
mlir_module = sd_model_annotation(
mlir_module, extended_model_name, base_model_id
@@ -149,7 +164,7 @@ def compile_through_fx(
shark_module = SharkInference(
mlir_module,
device=args.device,
device=args.device if device is None else device,
mlir_dialect="tm_tensor",
)
if generate_vmfb:
@@ -194,13 +209,15 @@ def get_device_mapping(driver, key_combination=3):
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
dict: map to possible device names user can input mapped to desired
combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
@@ -214,7 +231,7 @@ def get_device_mapping(driver, key_combination=3):
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
return dev_dict["name"], f"{driver}://{dev_dict['path']}"
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
@@ -227,10 +244,12 @@ def get_device_mapping(driver, key_combination=3):
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
"""Gives the appropriate device data (supported name/path) for user
selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
@@ -238,7 +257,8 @@ def map_device_to_name_path(device, key_combination=3):
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
str / tuple: returns the mapping str or tuple of mapping str for
the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
@@ -261,10 +281,21 @@ def set_init_device_flags():
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
f"Found device {device_name}. Using target triple "
f"{args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "metal" in args.device:
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_metal_target_platform:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platform = triple
print(
f"Found device {device_name}. Using target triple "
f"{args.iree_metal_target_platform}."
)
elif "cpu" in args.device:
args.device = "cpu"
@@ -289,13 +320,24 @@ def set_init_device_flags():
if (
args.precision != "fp16"
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
or (args.height == 512 and args.width not in [512, 768])
or (args.height == 768 and args.width not in [512, 768])
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif (
args.height != args.width
and "rdna2" in args.iree_vulkan_target_triple
and base_model_id
not in [
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
]
):
args.use_tuned = False
elif base_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
@@ -333,13 +375,26 @@ def set_init_device_flags():
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
]
or "rdna3" not in args.iree_vulkan_target_triple
or "rdna" not in args.iree_vulkan_target_triple
)
):
args.use_tuned = False
elif "rdna2" in args.iree_vulkan_target_triple and (
base_model_id
not in [
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]
):
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
print(
f"Using tuned models for {base_model_id}(fp16) on "
f"device {args.device}."
)
else:
print("Tuned models are currently not supported for this setting.")
@@ -396,8 +451,12 @@ def get_available_devices():
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_list.append(f"{device['name']} => {driver_name}://{i}")
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
@@ -405,9 +464,14 @@ def get_available_devices():
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("device => cpu")
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
return available_devices
@@ -464,7 +528,7 @@ def get_path_stem(path):
def get_path_to_diffusers_checkpoint(custom_weights):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
diffusers_directory_name = os.path.join("diffusers", path.stem)
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
@@ -484,10 +548,10 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
# EMA weights usually yield higher quality images for inference but
# non-EMA weights have been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
# they want to go for EMA weight extraction or not.
extract_ema = False
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
@@ -503,6 +567,25 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
print("Loading complete")
def convert_original_vae(vae_checkpoint):
vae_state_dict = {}
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
"main/configs/stable-diffusion/v1-inference.yaml"
)
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
vae_state_dict, vae_config
)
return converted_vae_checkpoint
def processLoRA(model, use_lora, splitting_prefix):
state_dict = ""
if ".safetensors" in use_lora:
@@ -613,7 +696,7 @@ def update_lora_weight(model, use_lora, model_name):
# `fetch_and_update_base_model_id` is a resource utility function which
# helps maintaining mapping of the model to run with its base model.
# helps to maintain mapping of the model to run with its base model.
# If `base_model` is "", then this function tries to fetch the base model
# info for the `model_to_run`.
def fetch_and_update_base_model_id(model_to_run, base_model=""):
@@ -630,13 +713,15 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""):
return base_model
elif base_model == "":
return base_model
# Update JSON data to contain an entry mapping model_to_run with base_model.
# Update JSON data to contain an entry mapping model_to_run with
# base_model.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
json.dump(json_data, jsonFile)
# Generate and return a new seed if the provided one is not in the supported range (including -1)
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed):
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
@@ -655,7 +740,8 @@ def clear_all():
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# Temporary workaround of deleting yaml files to incorporate
# diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
@@ -673,24 +759,41 @@ def clear_all():
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
def get_generated_imgs_path() -> Path:
return Path(
args.output_dir if args.output_dir else Path.cwd(), "generated_imgs"
)
def get_generated_imgs_todays_subdir() -> str:
return dt.now().strftime("%Y%m%d")
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info={}):
output_path = args.output_dir if args.output_dir else Path.cwd()
def save_output_img(output_img, img_seed, extra_info=None):
if extra_info is None:
extra_info = {}
generated_imgs_path = Path(
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = Path(os.path.basename(args.ckpt_loc)).stem
img_vae = None
if args.custom_vae:
img_vae = Path(os.path.basename(args.custom_vae)).stem
img_lora = None
if args.use_lora:
img_lora = Path(os.path.basename(args.use_lora)).stem
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
@@ -701,17 +804,30 @@ def save_output_img(output_img, img_seed, extra_info={}):
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
f"{args.prompts[0]}"
f"\nNegative prompt: {args.negative_prompts[0]}"
f"\nSteps: {args.steps},"
f"Sampler: {args.scheduler}, "
f"CFG scale: {args.guidance_scale}, "
f"Seed: {img_seed},"
f"Size: {args.width}x{args.height}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_lora}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
f"[ERROR] Format {args.output_img_format} is not "
f"supported yet. Image saved as png instead."
f"Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
# importance for each data point. Something to consider.
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
@@ -725,12 +841,17 @@ def save_output_img(output_img, img_seed, extra_info={}):
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
"LORA": img_lora,
}
new_entry.update(extra_info)
with open(csv_path, "a", encoding="utf-8") as csv_obj:
csv_mode = "a" if os.path.isfile(csv_path) else "w"
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
if csv_mode == "w":
dictwriter_obj.writeheader()
dictwriter_obj.writerow(new_entry)
csv_obj.close()
@@ -744,16 +865,27 @@ def save_output_img(output_img, img_seed, extra_info={}):
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
text_output += (
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
text_output += (
f"\nsteps={args.steps}, "
f"guidance_scale={args.guidance_scale}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={args.height}x{args.width}, "
f"batch_count={args.batch_count}, "
f"batch_size={args.batch_size}, "
f"max_length={args.max_length}"
)
return text_output
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms with our model constraints :-
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.

View File

@@ -1,22 +1,55 @@
from multiprocessing import Process, freeze_support
import os
import sys
import transformers
if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
import torch_mlir
import shutil
import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
from apps.stable_diffusion.src import args, clear_all
import apps.stable_diffusion.web.utils.global_obj as global_obj
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
if args.clear_all:
clear_all()
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False)
if __name__ == "__main__":
if args.api:
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.ui.split(","):
from apps.stable_diffusion.web.ui import (
txt2img_api,
img2img_api,
upscaler_api,
inpaint_api,
outpaint_api,
)
from fastapi import FastAPI, APIRouter
import uvicorn
@@ -28,26 +61,26 @@ if __name__ == "__main__":
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route(
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
# )
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
app.include_router(APIRouter())
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
sys.exit(0)
import gradio as gr
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
config_gradio_tmp_imgs_folder,
)
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create the custom model folder if it doesn't already exist
dir = ["models", "vae", "lora"]
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
create_custom_models_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
@@ -60,36 +93,69 @@ if __name__ == "__main__":
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_png_info_img,
txt2img_status,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_status,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_status,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_status,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_status,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
model_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
modelmanager_sendto_inpaint,
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
stablelm_chat,
outputgallery_web,
outputgallery_tab_select,
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
outputgallery_sendto_upscaler,
)
# init global sd pipeline and config
@@ -105,6 +171,27 @@ if __name__ == "__main__":
outputs,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
@@ -119,11 +206,29 @@ if __name__ == "__main__":
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.Tabs(visible=False) as experimental_tabs:
with gr.TabItem(label="LoRA Training", id=5):
with gr.TabItem(label="Model Manager", id=5):
model_web.render()
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
stablelm_chat.render()
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
lora_train_web.render()
if args.output_gallery:
with gr.TabItem(label="Output Gallery", id=8) as og_tab:
outputgallery_web.render()
# extra output gallery configuration
outputgallery_tab_select(og_tab.select)
outputgallery_watch(
[
txt2img_status,
img2img_status,
inpaint_status,
outpaint_status,
upscaler_status,
]
)
# send to buttons
register_button_click(
txt2img_sendto_img2img,
1,
@@ -220,10 +325,77 @@ if __name__ == "__main__":
[upscaler_gallery],
[outpaint_init_image, tabs],
)
if args.output_gallery:
register_outputgallery_button(
outputgallery_sendto_txt2img,
0,
[outputgallery_filename],
[txt2img_png_info_img, tabs],
)
register_outputgallery_button(
outputgallery_sendto_img2img,
1,
[outputgallery_filename],
[img2img_init_image, tabs],
)
register_outputgallery_button(
outputgallery_sendto_inpaint,
2,
[outputgallery_filename],
[inpaint_init_image, tabs],
)
register_outputgallery_button(
outputgallery_sendto_outpaint,
3,
[outputgallery_filename],
[outpaint_init_image, tabs],
)
register_outputgallery_button(
outputgallery_sendto_upscaler,
4,
[outputgallery_filename],
[upscaler_init_image, tabs],
)
register_modelmanager_button(
modelmanager_sendto_txt2img,
0,
[hf_models],
[txt2img_custom_model, txt2img_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_img2img,
1,
[hf_models],
[img2img_custom_model, img2img_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_inpaint,
2,
[hf_models],
[inpaint_custom_model, inpaint_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_outpaint,
3,
[hf_models],
[outpaint_custom_model, outpaint_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_upscaler,
4,
[hf_models],
[upscaler_custom_model, upscaler_hf_model_id, tabs],
)
sd_web.queue()
if args.ui == "app":
t = Process(
target=launch_app, args=[f"http://localhost:{args.server_port}"]
)
t.start()
sd_web.launch(
share=args.share,
inbrowser=True,
inbrowser=args.ui == "web",
server_name="0.0.0.0",
server_port=args.server_port,
)

View File

@@ -2,7 +2,11 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_inf,
txt2img_api,
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_png_info_img,
txt2img_status,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
@@ -12,8 +16,11 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_api,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_status,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
@@ -22,8 +29,11 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_inf,
inpaint_api,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_status,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
@@ -32,8 +42,11 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_inf,
outpaint_api,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_status,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
@@ -42,10 +55,34 @@ from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_inf,
upscaler_api,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_status,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
)
from apps.stable_diffusion.web.ui.model_manager import (
model_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
modelmanager_sendto_inpaint,
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat
from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_web,
outputgallery_tab_select,
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
outputgallery_sendto_upscaler,
)

View File

@@ -173,7 +173,30 @@ footer {
#gallery .thumbnail-item.thumbnail-lg {
aspect-ratio: unset;
max-height: calc(55vh - (2 * var(--spacing-lg)));
min-height: 390px
}
@media (min-width: 1921px) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
#gallery .grid-wrap, #gallery .preview{
min-height: calc(768px + 4px + var(--size-14));
max-height: calc(768px + 4px + var(--size-14));
}
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
#gallery .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
}
/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
object-fit: scale-down;
}
/* Navbar images in cover mode*/
#gallery .preview .thumbnail-item img {
object-fit: cover;
}
/* Limit the stable diffusion text output height */
#std_output textarea {
max-height: 215px;
}
/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */
@@ -204,6 +227,66 @@ footer {
}
/* Hide the download icon from the nod logo */
#top_logo .download {
#top_logo button {
display: none;
}
/* workarounds for container=false not currently working for dropdowns */
.dropdown_no_container {
padding: 0 !important;
}
#output_subdir_container :first-child {
border: none;
}
/* reduced animation load when generating */
.generating {
animation-play-state: paused !important;
}
/* better clarity when progress bars are minimal */
.meta-text {
background-color: var(--block-label-background-fill);
}
/* output gallery tab */
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs)
}
.output_icon_button {
max-width: 30px;
align-self: end;
padding-bottom: 8px;
}
.outputgallery_sendto {
min-width: 7em !important;
}
/* output gallery should take up most of the viewport height regardless of image size/number */
#outputgallery_gallery .fixed-height {
min-height: 89vh !important;
}
/* don't stretch non-square images to be square, breaking their aspect ratio */
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
object-fit: contain !important;
}
/* centered logo for when there are no images */
#top_logo.logo_centered {
height: 100%;
width: 100%;
}
#top_logo.logo_centered img{
object-fit: scale-down;
position: absolute;
width: 80%;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}

View File

@@ -1,9 +1,8 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
import PIL
from PIL import Image
import base64
from io import BytesIO
@@ -25,10 +24,14 @@ from apps.stable_diffusion.src import (
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
@@ -41,7 +44,7 @@ init_import_mlir = args.import_mlir
def img2img_inf(
prompt: str,
negative_prompt: str,
init_image,
image_dict,
height: int,
width: int,
steps: int,
@@ -84,9 +87,14 @@ def img2img_inf(
args.img_path = "not none"
args.ondemand = ondemand
if init_image is None:
if image_dict is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB")
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
elif isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
else:
image = image_dict["image"].convert("RGB")
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
@@ -96,9 +104,13 @@ def img2img_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
args.hf_model_id = hf_model_id
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
@@ -121,7 +133,8 @@ def img2img_inf(
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
f"Shark schedulers are not supported. Switching to EulerDiscrete "
f"scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
@@ -238,6 +251,7 @@ def img2img_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
seeds.append(img_seed)
@@ -249,11 +263,17 @@ def img2img_inf(
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed, extra_info)
save_output_img(
out_imgs[0],
img_seed,
extra_info,
)
generated_imgs.extend(out_imgs)
# yield generated_imgs, text_output
yield generated_imgs, text_output, status_label(
"Image-to-Image", current_batch + 1, batch_count, batch_size
)
return generated_imgs, text_output
return generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
@@ -290,7 +310,9 @@ def img2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = img2img_inf(
@@ -323,6 +345,10 @@ def img2img_api(
lora_hf_id="",
ondemand=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -340,30 +366,48 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
# janky fix for overflowing text
i2i_model_info = (str(get_custom_model_path())).replace(
"\\", "\n\\"
)
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
img2img_custom_model = gr.Dropdown(
label=f"Models",
info=i2i_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
"\\", "\n\\"
)
i2i_vae_info = f"VAE Path: {i2i_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=i2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -375,19 +419,23 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
img2img_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
label="Input Image",
source="upload",
tool="sketch",
type="pil",
height=300,
)
with gr.Accordion(label="Stencil Options", open=False):
with gr.Row():
@@ -397,17 +445,77 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
value="None",
choices=["None", "canny", "openpose", "scribble"],
)
def show_canvas(choice):
if choice == "scribble":
return (
gr.Slider.update(visible=True),
gr.Slider.update(visible=True),
gr.Button.update(visible=True),
)
else:
return (
gr.Slider.update(visible=False),
gr.Slider.update(visible=False),
gr.Button.update(visible=False),
)
def create_canvas(w, h):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
create_button = gr.Button(
label="Start",
value="Open drawing canvas!",
visible=False,
)
create_button.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[img2img_init_image],
)
use_stencil.change(
fn=show_canvas,
inputs=use_stencil,
outputs=[canvas_width, canvas_height, create_button],
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
i2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -514,10 +622,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
@@ -528,19 +636,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=2,
object_fit="contain",
)
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
img2img_status = gr.Textbox(visible=False)
with gr.Row():
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
img2img_sendto_outpaint = gr.Button(
@@ -565,8 +671,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
img2img_custom_model,
img2img_hf_model_id,
custom_vae,
precision,
device,
@@ -578,13 +684,21 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lora_hf_id,
ondemand,
],
outputs=[img2img_gallery, std_output],
show_progress=args.progress_bar,
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=img2img_status,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],

View File

@@ -1,4 +1,3 @@
from pathlib import Path
import os
import torch
import time
@@ -26,7 +25,11 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
@@ -89,9 +92,13 @@ def inpaint_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
args.hf_model_id = hf_model_id
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
@@ -198,6 +205,7 @@ def inpaint_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
@@ -210,7 +218,9 @@ def inpaint_inf(
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output, status_label(
"Inpaint", i + 1, batch_count, batch_size
)
return generated_imgs, text_output
@@ -249,7 +259,9 @@ def inpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["image"])
mask = decode_base64_to_image(InputData["mask"])
@@ -270,7 +282,7 @@ def inpaint_api(
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
@@ -281,6 +293,10 @@ def inpaint_api(
lora_hf_id="",
ondemand=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -298,30 +314,52 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
# janky fix for overflowing text
inpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
inpaint_model_info = (
f"Custom Model Path: {inpaint_model_info}"
)
inpaint_custom_model = gr.Dropdown(
label=f"Models",
info=inpaint_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files()
+ get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
hf_model_id = gr.Textbox(
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
inpaint_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
inpaint_vae_info = f"VAE Path: {inpaint_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=inpaint_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -333,13 +371,13 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
@@ -348,19 +386,29 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
source="upload",
tool="sketch",
type="pil",
).style(height=350)
height=350,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
inpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=inpaint_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -474,10 +522,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
@@ -488,19 +536,18 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
inpaint_status = gr.Textbox(visible=False)
with gr.Row():
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
inpaint_sendto_outpaint = gr.Button(
@@ -526,8 +573,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
inpaint_custom_model,
inpaint_hf_model_id,
custom_vae,
precision,
device,
@@ -538,13 +585,20 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
lora_hf_id,
ondemand,
],
outputs=[inpaint_gallery, std_output],
show_progress=args.progress_bar,
outputs=[inpaint_gallery, std_output, inpaint_status],
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=inpaint_status,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],

View File

@@ -24,15 +24,25 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
# janky fix for overflowing text
train_lora_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
train_lora_model_info = (
f"Custom Model Path: {train_lora_model_info}"
)
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=train_lora_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -43,22 +53,33 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models "
"dropdown on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
# janky fix for overflowing text
train_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
train_lora_info = f"LoRA Path: {train_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights to initialize weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA weights to initialize weights",
info=train_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use a "
"standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID to initialize weights",
lines=3,
@@ -74,7 +95,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
@@ -159,10 +180,10 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
train_lora = gr.Button("Train LoRA")
@@ -215,7 +236,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
),
],
outputs=[std_output],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
prompt_submit = prompt.submit(**kwargs)

View File

@@ -0,0 +1,160 @@
import os
import gradio as gr
import requests
from io import BytesIO
from PIL import Image
def get_hf_list(num_of_models=20):
path = "https://huggingface.co/api/models"
params = {
"search": "stable-diffusion",
"sort": "downloads",
"direction": "-1",
"limit": {num_of_models},
"full": "true",
}
response = requests.get(path, params=params)
return response.json()
def get_civit_list(num_of_models=50):
path = (
f"https://civitai.com/api/v1/models?limit="
f"{num_of_models}&types=Checkpoint"
)
headers = {"Content-Type": "application/json"}
raw_json = requests.get(path, headers=headers).json()
models = list(raw_json.items())[0][1]
safe_models = [
safe_model for safe_model in models if not safe_model["nsfw"]
]
version_id = 0 # Currently just using the first version.
safe_models = [
safe_model
for safe_model in safe_models
if safe_model["modelVersions"][version_id]["files"][0]["metadata"][
"format"
]
== "SafeTensor"
]
first_version_models = []
for model_iter in safe_models:
# The modelVersion would only keep the version name.
if (
model_iter["modelVersions"][version_id]["images"][0]["nsfw"]
!= "None"
):
continue
model_iter["modelVersions"][version_id]["modelName"] = model_iter[
"name"
]
model_iter["modelVersions"][version_id]["rating"] = model_iter[
"stats"
]["rating"]
model_iter["modelVersions"][version_id]["favoriteCount"] = model_iter[
"stats"
]["favoriteCount"]
model_iter["modelVersions"][version_id]["downloadCount"] = model_iter[
"stats"
]["downloadCount"]
first_version_models.append(model_iter["modelVersions"][version_id])
return first_version_models
def get_image_from_model(model_json):
model_id = model_json["modelId"]
image = None
for img_info in model_json["images"]:
if img_info["nsfw"] == "None":
image_url = model_json["images"][0]["url"]
response = requests.get(image_url)
image = BytesIO(response.content)
break
return image
with gr.Blocks() as model_web:
with gr.Row():
model_source = gr.Radio(
value=None,
choices=["Hugging Face", "Civitai"],
type="value",
label="Model Source",
)
model_number = gr.Slider(
1,
100,
value=10,
step=1,
label="Number of models",
interactive=True,
)
# TODO: add more filters
get_model_btn = gr.Button(value="Get Models")
hf_models = gr.Dropdown(
label="Hugging Face Model List",
choices=None,
value=None,
visible=False,
)
# TODO: select and SendTo
civit_models = gr.Gallery(
label="Civitai Model Gallery",
value=None,
interactive=True,
visible=False,
)
with gr.Row(visible=False) as sendto_btns:
modelmanager_sendto_txt2img = gr.Button(value="SendTo Txt2Img")
modelmanager_sendto_img2img = gr.Button(value="SendTo Img2Img")
modelmanager_sendto_inpaint = gr.Button(value="SendTo Inpaint")
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
def get_model_list(model_source, model_number):
if model_source == "Hugging Face":
hf_model_list = get_hf_list(model_number)
models = []
for model in hf_model_list:
# TODO: add model info
models.append(f'{model["modelId"]}')
return (
gr.Dropdown.update(choices=models, visible=True),
gr.Gallery.update(value=None, visible=False),
gr.Row.update(visible=True),
)
elif model_source == "Civitai":
civit_model_list = get_civit_list(model_number)
models = []
for model in civit_model_list:
image = get_image_from_model(model)
if image is None:
continue
# TODO: add model info
models.append(
(Image.open(image), f'{model["files"][0]["downloadUrl"]}')
)
return (
gr.Dropdown.update(value=None, choices=None, visible=False),
gr.Gallery.update(value=models, visible=True),
gr.Row.update(visible=False),
)
else:
return (
gr.Dropdown.update(value=None, choices=None, visible=False),
gr.Gallery.update(value=None, visible=False),
gr.Row.update(visible=False),
)
get_model_btn.click(
fn=get_model_list,
inputs=[model_source, model_number],
outputs=[
hf_models,
civit_models,
sendto_btns,
],
)

View File

@@ -1,8 +1,6 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
import base64
@@ -23,11 +21,13 @@ from apps.stable_diffusion.src import (
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
@@ -91,9 +91,13 @@ def outpaint_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
args.hf_model_id = hf_model_id
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
@@ -207,6 +211,7 @@ def outpaint_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
@@ -219,9 +224,11 @@ def outpaint_inf(
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output, status_label(
"Outpaint", i + 1, batch_count, batch_size
)
return generated_imgs, text_output
return generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
@@ -258,7 +265,9 @@ def outpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = outpaint_inf(
@@ -281,7 +290,7 @@ def outpaint_api(
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
@@ -292,6 +301,10 @@ def outpaint_api(
lora_hf_id="",
ondemand=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -309,30 +322,52 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
# janky fix for overflowing text
outpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
outpaint_model_info = (
f"Custom Model Path: {outpaint_model_info}"
)
outpaint_custom_model = gr.Dropdown(
label=f"Models",
info=outpaint_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files()
+ get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
hf_model_id = gr.Textbox(
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
outpaint_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
outpaint_vae_info = f"VAE Path: {outpaint_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=outpaint_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -344,31 +379,42 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
outpaint_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
label="Input Image",
type="pil",
height=300,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
outpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=outpaint_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -504,10 +550,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
@@ -518,19 +564,17 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
outpaint_status = gr.Textbox(visible=False)
with gr.Row():
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -557,8 +601,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
outpaint_custom_model,
outpaint_hf_model_id,
custom_vae,
precision,
device,
@@ -569,13 +613,20 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
lora_hf_id,
ondemand,
],
outputs=[outpaint_gallery, std_output],
show_progress=args.progress_bar,
outputs=[outpaint_gallery, std_output, outpaint_status],
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Outpaint", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=outpaint_status,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],

View File

@@ -0,0 +1,491 @@
import glob
import gradio as gr
import os
import subprocess
import sys
from PIL import Image
from apps.stable_diffusion.src import args
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
# -- Functions for file, directory and image info querying
output_dir = get_generated_imgs_path()
def outputgallery_filenames(subdir) -> list[str]:
new_dir_path = os.path.join(output_dir, subdir)
if os.path.exists(new_dir_path):
filenames = [
glob.glob(new_dir_path + "/" + ext)
for ext in ("*.png", "*.jpg", "*.jpeg")
]
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
else:
return []
def output_subdirs() -> list[str]:
# Gets a list of subdirectories of output_dir and below, as relative paths.
relative_paths = [
os.path.relpath(entry[0], output_dir)
for entry in os.walk(
output_dir, followlinks=args.output_gallery_followlinks
)
]
# It is less confusing to always including the subdir that will take any
# images generated today even if it doesn't exist yet
if get_generated_imgs_todays_subdir() not in relative_paths:
relative_paths.append(get_generated_imgs_todays_subdir())
# sort subdirectories so that the date named ones we probably
# created in this or previous sessions come first, sorted with the most
# recent first. Other subdirs are listed after.
generated_paths = sorted(
[path for path in relative_paths if path.isnumeric()], reverse=True
)
result_paths = generated_paths + sorted(
[
path
for path in relative_paths
if (not path.isnumeric()) and path != "."
]
)
return result_paths
# --- Define UI layout for Gradio
with gr.Blocks() as outputgallery_web:
nod_logo = Image.open(nodlogo_loc)
with gr.Row(elem_id="outputgallery_gallery"):
# needed to workaround gradio issue:
# https://github.com/gradio-app/gradio/issues/2907
dev_null = gr.Textbox("", visible=False)
gallery_files = gr.State(value=[])
subdirectory_paths = gr.State(value=[])
with gr.Column(scale=6):
logo = gr.Image(
label="Getting subdirectories...",
value=nod_logo,
interactive=False,
visible=True,
show_label=True,
elem_id="top_logo",
elem_classes="logo_centered",
)
gallery = gr.Gallery(
label="",
value=gallery_files.value,
visible=False,
show_label=True,
columns=2,
)
with gr.Column(scale=4):
with gr.Box():
with gr.Row():
with gr.Column(
scale=15,
min_width=160,
elem_id="output_subdir_container",
):
subdirectories = gr.Dropdown(
label=f"Subdirectories of {output_dir}",
type="value",
choices=subdirectory_paths.value,
value="",
interactive=True,
elem_classes="dropdown_no_container",
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
open_subdir = gr.Button(
variant="secondary",
value="\U0001F5C1", # unicode open folder
interactive=False,
size="sm",
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
refresh = gr.Button(
variant="secondary",
value="\u21BB", # unicode clockwise arrow circle
size="sm",
)
image_columns = gr.Slider(
label="Columns shown", value=4, minimum=1, maximum=16, step=1
)
outputgallery_filename = gr.Textbox(
label="Filename",
value="None",
interactive=False,
show_copy_button=True,
)
with gr.Accordion(
label="Parameter Information", open=False
) as parameters_accordian:
image_parameters = gr.DataFrame(
headers=["Parameter", "Value"],
col_count=2,
wrap=True,
elem_classes="output_parameters_dataframe",
value=[["Status", "No image selected"]],
)
with gr.Accordion(label="Send To", open=True):
with gr.Row():
outputgallery_sendto_txt2img = gr.Button(
value="Txt2Img",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_img2img = gr.Button(
value="Img2Img",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_inpaint = gr.Button(
value="Inpaint",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_outpaint = gr.Button(
value="Outpaint",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_upscaler = gr.Button(
value="Upscaler",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
# --- Event handlers
def on_clear_gallery():
return [
gr.Gallery.update(
value=[],
visible=False,
),
gr.Image.update(
visible=True,
),
]
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
new_label = (
f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
)
return [
new_images,
gr.Gallery.update(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image.update(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_open_subdir(subdir):
subdir_path = os.path.normpath(os.path.join(output_dir, subdir))
if os.path.isdir(subdir_path):
if sys.platform == "linux":
subprocess.run(["xdg-open", subdir_path])
elif sys.platform == "darwin":
subprocess.run(["open", subdir_path])
elif sys.platform == "win32":
os.startfile(subdir_path)
def on_refresh(current_subdir: str) -> list:
# get an up-to-date subdirectory list
refreshed_subdirs = output_subdirs()
# get the images using either the current subdirectory or the most
# recent valid one
new_subdir = (
current_subdir
if current_subdir in refreshed_subdirs
else refreshed_subdirs[0]
)
new_images = outputgallery_filenames(new_subdir)
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, new_subdir)}"
)
return [
gr.Dropdown.update(
choices=refreshed_subdirs,
value=new_subdir,
),
refreshed_subdirs,
new_images,
gr.Gallery.update(
value=new_images, label=new_label, visible=len(new_images) > 0
),
gr.Image.update(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_new_image(subdir, subdir_paths, status) -> list:
# prevent error triggered when an image generates before the tab
# has even been selected
subdir_paths = (
subdir_paths
if len(subdir_paths) > 0
else [get_generated_imgs_todays_subdir()]
)
# only update if the current subdir is the most recent one as
# new images only go there
if subdir_paths[0] == subdir:
new_images = outputgallery_filenames(subdir)
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, subdir)} - {status}"
)
return [
new_images,
gr.Gallery.update(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image.update(
label=new_label,
visible=len(new_images) == 0,
),
]
else:
# otherwise change nothing,
# (only untyped gradio gr.update() does this)
return [gr.update(), gr.update(), gr.update()]
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
# evt.index is an index into the full list of filenames for
# the current subdirectory
filename = images[evt.index]
params = displayable_metadata(filename)
if params:
if params["source"] == "missing":
return [
"Could not find this image file, refresh the gallery and update the images",
[["Status", "File missing"]],
]
else:
return [
filename,
list(map(list, params["parameters"].items())),
]
return [
filename,
[["Status", "No parameters found"]],
]
def on_outputgallery_filename_change(filename: str) -> list:
exists = filename != "None" and os.path.exists(filename)
return [
# disable or enable each of the sendto button based on whether
# an image is selected
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
]
# The time first our tab is selected we need to do an initial refresh
# to populate the subdirectory select box and the images from the most
# recent subdirectory.
#
# We do it at this point rather than setting this up in the controls'
# definitions as when you refresh the browser you always get what was
# *initially* set, which won't include any new subdirectories or images
# that might have created since the application was started. Doing it
# this way means a browser refresh/reload always gets the most
# up-to-date data.
def on_select_tab(subdir_paths, request: gr.Request):
local_client = request.headers["host"].startswith(
"127.0.0.1:"
) or request.headers["host"].startswith("localhost:")
if len(subdir_paths) == 0:
return on_refresh("") + [gr.update(interactive=local_client)]
else:
return (
# Change nothing, (only untyped gr.update() does this)
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
# support things set with .style, nor the elem_classes kwarg, so we have
# to directly set things up via JavaScript if we want the client to take
# notice of our changes to the number of columns after it decides to put
# them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
setTimeout(() => {{
required_style = "auto ".repeat(new_cols).trim();
gallery = document.querySelector('#outputgallery_gallery .grid-container');
if (gallery) {{
gallery.style.gridTemplateColumns = required_style
}}
}}, {timeout_length});
return []; // prevents console error from gradio
}}
"""
# --- Wire handlers up to the actions
# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
# replacement phase.
clear_gallery = dict(
fn=on_clear_gallery,
inputs=None,
outputs=[gallery, logo],
queue=False,
)
image_columns.change(**set_gallery_columns_immediate)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
open_subdir.click(
on_open_subdir, inputs=[subdirectories], queue=False
).then(**set_gallery_columns_immediate)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
).then(**set_gallery_columns_delayed)
outputgallery_filename.change(
on_outputgallery_filename_change,
[outputgallery_filename],
[
outputgallery_sendto_txt2img,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
outputgallery_sendto_upscaler,
],
queue=False,
)
# We should have been given the .select function for our tab, so set it up
def outputgallery_tab_select(select):
select(
fn=on_select_tab,
inputs=[subdirectory_paths],
outputs=[
subdirectories,
subdirectory_paths,
gallery_files,
gallery,
logo,
open_subdir,
],
queue=False,
).then(**set_gallery_columns_immediate)
# We should have been passed a list of components on other tabs that update
# when a new image has generated on that tab, so set things up so the user
# will see that new image if they are looking at today's subdirectory
def outputgallery_watch(components: gr.Textbox):
for component in components:
component.change(
on_new_image,
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)

View File

@@ -0,0 +1,200 @@
import gradio as gr
import torch
import os
from pathlib import Path
from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
start_message = (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
)
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
sharkModel = 0
sharded_model = 0
vicuna_model = 0
start_message_vicuna = (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
)
past_key_values = None
def chat(curr_system_message, history, model, device, precision):
print(f"In chat for {model}")
global sharded_model
global past_key_values
global vicuna_model
if "vicuna" in model:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
curr_system_message = start_message_vicuna
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = UnshardedVicuna(
"vicuna",
hf_model_path=model,
device=device,
precision=precision,
)
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
prompt = messages.strip()
print("prompt = ", prompt)
for partial_text in vicuna_model.generate(prompt):
history[-1][1] = partial_text
yield history
return history
# else Model is StableLM
global sharkModel
from apps.language_models.src.pipelines.stablelm_pipeline import (
SharkStableLM,
)
if sharkModel == 0:
# max_new_tokens=512
shark_slm = SharkStableLM(
"StableLM"
) # pass elements from UI as required
# Construct the input message string for the model by concatenating the
# current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
curr_system_message = start_message
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
generate_kwargs = dict(prompt=messages)
words_list = shark_slm.generate(**generate_kwargs)
partial_text = ""
for new_text in words_list:
# print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to clean up the message textbox and the updated
# conversation history
yield history
return words_list
with gr.Blocks(title="Chatbot") as stablelm_chat:
with gr.Row():
model = gr.Dropdown(
label="Select Model",
value="TheBloke/vicuna-7B-1.1-HF",
choices=[
"stabilityai/stablelm-tuned-alpha-3b",
"TheBloke/vicuna-7B-1.1-HF",
],
)
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
if enabled
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"int4",
"int8",
"fp16",
"fp32",
],
visible=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
system_msg = gr.Textbox(
start_message, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)

View File

@@ -1,4 +1,3 @@
from pathlib import Path
import os
import torch
import time
@@ -17,7 +16,8 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.utils.png_metadata import import_png_metadata
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
@@ -27,10 +27,14 @@ from apps.stable_diffusion.src import (
save_output_img,
prompt_examples,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
@@ -83,9 +87,13 @@ def txt2img_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty",
)
args.hf_model_id = hf_model_id
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
@@ -131,6 +139,7 @@ def txt2img_inf(
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
@@ -187,6 +196,7 @@ def txt2img_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
@@ -199,9 +209,11 @@ def txt2img_inf(
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output, status_label(
"Text-to-Image", i + 1, batch_count, batch_size
)
return generated_imgs, text_output
return generated_imgs, text_output, ""
def encode_pil_to_base64(images):
@@ -227,7 +239,9 @@ def txt2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
res = txt2img_inf(
InputData["prompt"],
@@ -254,6 +268,10 @@ def txt2img_api(
lora_hf_id="",
ondemand=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -271,32 +289,50 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
# janky fix for overflowing text
t2i_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
t2i_model_info = (
f"Custom Model Path: {t2i_model_info}"
)
txt2img_custom_model = gr.Dropdown(
label=f"Models",
info=t2i_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the dropdown "
"on the left and enter model ID here.",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model "
"download URL.",
lines=3,
)
# janky fix for overflowing text
t2i_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
t2i_vae_info = f"VAE Path: {t2i_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"VAE Models",
info=t2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -305,7 +341,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
+ get_custom_model_files("vae"),
)
with gr.Column(scale=1, min_width=170):
png_info_img = gr.Image(
txt2img_png_info_img = gr.Image(
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
@@ -317,26 +353,35 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=t2i_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -443,10 +488,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
@@ -465,19 +510,17 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
txt2img_status = gr.Textbox(visible=False)
with gr.Row():
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -501,8 +544,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
txt2img_custom_model,
txt2img_hf_model_id,
custom_vae,
precision,
device,
@@ -513,22 +556,30 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
lora_hf_id,
ondemand,
],
outputs=[txt2img_gallery, std_output],
show_progress=args.progress_bar,
outputs=[txt2img_gallery, std_output, txt2img_status],
show_progress="minimal" if args.progress_bar else "none",
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Text-to-Image", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=txt2img_status,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
png_info_img.change(
txt2img_png_info_img.change(
fn=import_png_metadata,
inputs=[
png_info_img,
txt2img_png_info_img,
prompt,
negative_prompt,
steps,
@@ -537,11 +588,14 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
seed,
width,
height,
custom_model,
hf_model_id,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
png_info_img,
txt2img_png_info_img,
prompt,
negative_prompt,
steps,
@@ -550,7 +604,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
seed,
width,
height,
custom_model,
hf_model_id,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
)

View File

@@ -1,8 +1,6 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
import base64
@@ -17,16 +15,16 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generated_imgs_path
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
@@ -66,6 +64,9 @@ def upscaler_inf(
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
@@ -87,9 +88,13 @@ def upscaler_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
args.hf_model_id = hf_model_id
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
@@ -198,26 +203,50 @@ def upscaler_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
save_output_img(high_res_img, img_seed, extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(high_res_img, img_seed, extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log, status_label(
"Upscaler", current_batch + 1, batch_count, batch_size
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += (
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
text_output += (
f"\nsteps={steps}, "
f"noise_level={noise_level}, "
f"guidance_scale={guidance_scale}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={height}x{width}, "
f"batch_count={batch_count}, "
f"batch_size={batch_size}, "
f"max_length={args.max_length}"
)
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
yield generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
@@ -254,7 +283,9 @@ def upscaler_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = upscaler_inf(
@@ -284,6 +315,9 @@ def upscaler_api(
lora_hf_id="",
ondemand=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -301,30 +335,52 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
# janky fix for overflowing text
upscaler_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
upscaler_model_info = (
f"Custom Model Path: {upscaler_model_info}"
)
upscaler_custom_model = gr.Dropdown(
label=f"Models",
info=upscaler_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-x4-upscaler",
choices=["None"]
+ get_custom_model_files()
+ get_custom_model_files(
custom_checkpoint_type="upscaler"
)
+ predefined_upscaler_models,
)
hf_model_id = gr.Textbox(
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
upscaler_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
upscaler_vae_info = f"VAE Path: {upscaler_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=upscaler_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -336,31 +392,42 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
upscaler_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
label="Input Image",
type="pil",
height=300,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
upscaler_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=upscaler_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -475,10 +542,10 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
lambda: -1,
inputs=[],
outputs=[seed],
_js="() => -1",
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
@@ -489,19 +556,18 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
upscaler_status = gr.Textbox(visible=False)
with gr.Row():
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -524,8 +590,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
upscaler_custom_model,
upscaler_hf_model_id,
custom_vae,
precision,
device,
@@ -536,13 +602,21 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
lora_hf_id,
ondemand,
],
outputs=[upscaler_gallery, std_output],
show_progress=args.progress_bar,
outputs=[upscaler_gallery, std_output, upscaler_status],
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Upscaler", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=upscaler_status,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -39,8 +39,16 @@ scheduler_list_cpu_only = [
"LMSDiscrete",
"KDPM2Discrete",
"DPMSolverMultistep",
"DPMSolverMultistep++",
"DPMSolverMultistepKarras",
"DPMSolverMultistepKarras++",
"EulerDiscrete",
"EulerAncestralDiscrete",
"DEISMultistep",
"KDPM2AncestralDiscrete",
"DPMSolverSinglestep",
"DDPM",
"HeunDiscrete",
]
scheduler_list = scheduler_list_cpu_only + [
"SharkEulerDiscrete",
@@ -50,6 +58,7 @@ predefined_models = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"xzuyn/PhotoMerge",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
@@ -58,6 +67,7 @@ predefined_models = [
predefined_paint_models = [
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
"xzuyn/PhotoMerge-inpainting",
]
predefined_upscaler_models = [
"stabilityai/stable-diffusion-x4-upscaler",
@@ -72,30 +82,37 @@ def resource_path(relative_path):
return os.path.join(base_path, relative_path)
def create_custom_models_folders():
dir = ["vae", "lora"]
if not args.ckpt_dir:
dir.insert(0, "models")
else:
if not os.path.isdir(args.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, "
f"{args.ckpt_dir} folder does not exists."
)
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
def get_custom_model_path(model="models"):
# If `--ckpt_dir` is provided it'd override the heirarchical folder
# structure in WebUI :-
# model
# models or args.ckpt_dir
# |___lora
# |___vae
sub_folder = "" if model == "models" else model
if args.ckpt_dir:
return Path(args.ckpt_dir)
match model:
case "models":
return Path(Path.cwd(), "models")
case "vae":
return Path(Path.cwd(), "models/vae")
case "lora":
return Path(Path.cwd(), "models/lora")
case _:
return ""
return Path(Path(args.ckpt_dir), sub_folder)
else:
return Path(Path.cwd(), "models/" + sub_folder)
def get_custom_model_pathfile(custom_model_name, model="models"):
return os.path.join(get_custom_model_path(model), custom_model_name)
def get_custom_model_files(model="models"):
def get_custom_model_files(model="models", custom_checkpoint_type=""):
ckpt_files = []
file_types = custom_model_filetypes
if model == "lora":
@@ -107,6 +124,28 @@ def get_custom_model_files(model="models"):
os.path.join(get_custom_model_path(model), extn)
)
]
match custom_checkpoint_type:
case "inpainting":
files = [
val
for val in files
if val.endswith("inpainting" + extn.removeprefix("*"))
]
case "upscaler":
files = [
val
for val in files
if val.endswith("upscaler" + extn.removeprefix("*"))
]
case _:
files = [
val
for val in files
if not (
val.endswith("inpainting" + extn.removeprefix("*"))
or val.endswith("upscaler" + extn.removeprefix("*"))
)
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)

View File

@@ -0,0 +1,9 @@
# functions for generating labels used in common by tabs across the UI
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
if batch_index < batch_count:
bs = f"x{batch_size}" if batch_size > 1 else ""
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
else:
return f"{tab_name} complete"

View File

@@ -1,31 +1,54 @@
import os
import tempfile
import gradio
from os import listdir
import shutil
from time import time
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
# Clear all gradio tmp images
def clear_gradio_tmp_imgs_folder():
if not os.path.exists(gradio_tmp_imgs_folder):
return
for fileName in listdir(gradio_tmp_imgs_folder):
# Delete tmp png files
if fileName.startswith("tmp") and fileName.endswith(".png"):
os.remove(gradio_tmp_imgs_folder + fileName)
def config_gradio_tmp_imgs_folder():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
def save_pil_to_file(pil_image, dir=None):
if not os.path.exists(gradio_tmp_imgs_folder):
os.mkdir(gradio_tmp_imgs_folder)
file_obj = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
print(
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
)
pil_image.save(file_obj)
return file_obj
# Clear all gradio tmp images from the last session
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
cleanup_start = time()
print(
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
)
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
print(
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
)
# Register save_pil_to_file override
gradio.processing_utils.save_pil_to_file = save_pil_to_file
# older SHARK versions had to workaround gradio bugs and stored things differently
else:
image_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
print(
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
)
cleanup_start = time()
for filename in image_files:
os.remove(shark_tmp + filename)
print(
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
)
else:
print("No temporary images files to clear.")

View File

@@ -0,0 +1,6 @@
from .png_metadata import (
import_png_metadata,
)
from .display import (
displayable_metadata,
)

View File

@@ -0,0 +1,45 @@
import csv
import os
from .format import humanize, humanizable
def csv_path(image_filename: str):
return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
def has_csv(image_filename: str) -> bool:
return os.path.exists(csv_path(image_filename))
def matching_filename(image_filename: str, row):
# we assume the final column of the csv has the original filename with full path and match that
# against the image_filename if we are given a list. Otherwise we assume a dict and and take
# the value of the OUTPUT key
return os.path.basename(image_filename) in (
row[-1] if isinstance(row, list) else row["OUTPUT"]
)
def parse_csv(image_filename: str):
csv_filename = csv_path(image_filename)
with open(csv_filename, "r", newline="") as csv_file:
# We use a reader or DictReader here for images_details.csv depending on whether we think it
# has headers or not. Having headers means less guessing of the format.
has_header = csv.Sniffer().has_header(csv_file.read(2048))
csv_file.seek(0)
reader = (
csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
)
matches = [
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
humanize(row)
for row in reader
if row
and (has_header or humanizable(row))
and matching_filename(image_filename, row)
]
return matches[0] if matches else {}

View File

@@ -0,0 +1,53 @@
import json
import os
from PIL import Image
from .png_metadata import parse_generation_parameters
from .exif_metadata import has_exif, parse_exif
from .csv_metadata import has_csv, parse_csv
from .format import compact, humanize
def displayable_metadata(image_filename: str) -> dict:
if not os.path.isfile(image_filename):
return {"source": "missing", "parameters": {}}
pil_image = Image.open(image_filename)
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
# and we go via that for SendTo, and is directly tied to the image)
if "parameters" in pil_image.info:
return {
"source": "png",
"parameters": compact(
parse_generation_parameters(pil_image.info["parameters"])
),
}
# we have a matching json file (next most likely to be accurate when it's there)
json_path = os.path.splitext(image_filename)[0] + ".json"
if os.path.isfile(json_path):
with open(json_path) as params_file:
return {
"source": "json",
"parameters": compact(
humanize(json.load(params_file), includes_filename=False)
),
}
# we have a CSV file so try that (can be different shapes, and it usually has no
# headers/param names so of the things we we *know* have parameters, it's the
# last resort)
if has_csv(image_filename):
params = parse_csv(image_filename)
if params: # we might not have found the filename in the csv
return {
"source": "csv",
"parameters": compact(params), # already humanized
}
# EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
if has_exif(image_filename):
return {"source": "exif", "parameters": parse_exif(pil_image)}
# we've got nothing
return None

View File

@@ -0,0 +1,52 @@
from PIL import Image
from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
def has_exif(image_filename: str) -> bool:
return True if Image.open(image_filename).getexif() else False
def parse_exif(pil_image: Image) -> dict:
img_exif = pil_image.getexif()
# See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594
# I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I
# I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a
# dependency
exif_tags = {
TAGS.get(key, key): str(val)
for (key, val) in img_exif.items()
if key in TAGS
and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
def try_get_ifd(ifd_id):
try:
return img_exif.get_ifd(ifd_id).items()
except KeyError:
return {}
ifd_tags = {
TAGS.get(key, key): str(val)
for ifd_id in IFD
for (key, val) in try_get_ifd(ifd_id)
if ifd_id != IFD.GPSInfo
and key in TAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
gps_tags = {
GPSTAGS.get(key, key): str(val)
for (key, val) in try_get_ifd(IFD.GPSInfo)
if key in GPSTAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
return {**exif_tags, **ifd_tags, **gps_tags}

View File

@@ -0,0 +1,143 @@
# As SHARK has evolved more columns have been added to images_details.csv. However, since
# no version of the CSV has any headers (yet) we don't actually have anything within the
# file that tells us which parameter each column is for. So this is a list of known patterns
# indexed by length which is what we're going to have to use to guess which columns are the
# right ones for the file we're looking at.
# The same ordering is used for JSON, but these do have key names, however they are not very
# human friendly, nor do they match up with the what is written to the .png headers
# So these are functions to try and get something consistent out the raw input from all
# these sources
PARAMS_FORMATS = {
9: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
10: {
"MODEL": "Model",
"VARIANT": "Variant",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
12: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
},
}
PARAMS_FORMAT_CURRENT = {
"VARIANT": "Model",
"VAE": "VAE",
"LORA": "LoRA",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
}
def compact(metadata: dict) -> dict:
# we don't want to alter the original dictionary
result = dict(metadata)
# discard the filename because we should already have it
if result.keys() & {"Filename"}:
result.pop("Filename")
# make showing the sizes more compact by using only one line each
if result.keys() & {"Size-1", "Size-2"}:
result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
elif result.keys() & {"Height", "Width"}:
result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
if result.keys() & {"Hires resize-1", "Hires resize-1"}:
hires_y = result.pop("Hires resize-1")
hires_x = result.pop("Hires resize-2")
if hires_x == 0 and hires_y == 0:
result["Hires resize"] = "None"
else:
result["Hires resize"] = f"{hires_y}x{hires_x}"
# remove VAE if it exists and is empty
if (result.keys() & {"VAE"}) and (
not result["VAE"] or result["VAE"] == "None"
):
result.pop("VAE")
# remove LoRA if it exists and is empty
if (result.keys() & {"LoRA"}) and (
not result["LoRA"] or result["LoRA"] == "None"
):
result.pop("LoRA")
return result
def humanizable(metadata: dict | list[str], includes_filename=True) -> dict:
lookup_key = len(metadata) + (0 if includes_filename else 1)
return lookup_key in PARAMS_FORMATS.keys()
def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
lookup_key = len(metadata) + (0 if includes_filename else 1)
# For lists we can only work based on the length, we have no other information
if isinstance(metadata, list):
if humanizable(metadata, includes_filename):
return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
else:
raise KeyError(
f"Humanize could not find the format for a parameter list of length {len(metadata)}"
)
# For dictionaries we try to use the matching length parameter format if
# available, otherwise we just use the current format which is assumed to
# have everything currently known about. Then we swap keys in the metadata
# that match keys in the format for the friendlier name that we have set
# in the format value
if isinstance(metadata, dict):
if humanizable(metadata, includes_filename):
format = PARAMS_FORMATS[lookup_key]
else:
format = PARAMS_FORMAT_CURRENT
return {
format[key]: metadata[key]
for key in format.keys()
if key in metadata.keys() and metadata[key]
}
raise TypeError("Can only humanize parameter lists or dictionaries")

View File

@@ -62,6 +62,82 @@ def parse_generation_parameters(x: str):
return res
def try_find_model_base_from_png_metadata(
file: str, folder: str = "models"
) -> str:
custom = ""
# Remove extension from file info
if file.endswith(".safetensors") or file.endswith(".ckpt"):
file = Path(file).stem
# Check for the file name match with one of the local ckpt or safetensors files
if Path(get_custom_model_pathfile(file + ".ckpt", folder)).is_file():
custom = file + ".ckpt"
if Path(
get_custom_model_pathfile(file + ".safetensors", folder)
).is_file():
custom = file + ".safetensors"
return custom
def find_model_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
png_hf_id = ""
png_custom = ""
if key in metadata:
model_file = metadata[key]
png_custom = try_find_model_base_from_png_metadata(model_file)
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if model_file in predefined_models:
png_custom = model_file
# If nothing had matched, check vendor/hf_model_id
if not png_custom and model_file.count("/"):
png_hf_id = model_file
# No matching model was found
if not png_custom and not png_hf_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% model_file
)
return png_custom, png_hf_id
def find_vae_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> str:
vae_custom = ""
if key in metadata:
vae_file = metadata[key]
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
# VAE input is optional, should not print or throw an error if missing
return vae_custom
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
lora_custom = ""
if key in metadata:
lora_file = metadata[key]
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
def import_png_metadata(
pil_data,
prompt,
@@ -74,40 +150,21 @@ def import_png_metadata(
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
png_hf_model_id = ""
png_custom_model = ""
if "Model" in metadata:
# Remove extension from model info
if metadata["Model"].endswith(".safetensors") or metadata[
"Model"
].endswith(".ckpt"):
metadata["Model"] = Path(metadata["Model"]).stem
# Check for the model name match with one of the local ckpt or safetensors files
if Path(
get_custom_model_pathfile(metadata["Model"] + ".ckpt")
).is_file():
png_custom_model = metadata["Model"] + ".ckpt"
if Path(
get_custom_model_pathfile(metadata["Model"] + ".safetensors")
).is_file():
png_custom_model = metadata["Model"] + ".safetensors"
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if metadata["Model"] in predefined_models:
png_custom_model = metadata["Model"]
# If nothing had matched, check vendor/hf_model_id
if not png_custom_model and metadata["Model"].count("/"):
png_hf_model_id = metadata["Model"]
# No matching model was found
if not png_custom_model and not png_hf_model_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% metadata["Model"]
)
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
@@ -115,12 +172,24 @@ def import_png_metadata(
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
custom_model = "None"
hf_model_id = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
if "Prompt" in metadata:
prompt = metadata["Prompt"]
if "Sampler" in metadata:
@@ -149,4 +218,7 @@ def import_png_metadata(
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
)

View File

@@ -40,7 +40,7 @@ cmake --build build/
*Prepare the model*
```bash
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
```
*Prepare the input*
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
```bash
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
```
VAE and Autoencoder are also available
```bash
# VAE
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
# CLIP Autoencoder
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
```

View File

@@ -21,7 +21,7 @@ endif()
# Compile mnist.mlir to mnist.vmfb.
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
set(_COMPILE_ARGS)
list(APPEND _COMPILE_ARGS "--iree-input-type=mhlo")
list(APPEND _COMPILE_ARGS "--iree-input-type=auto")
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
list(APPEND _COMPILE_ARGS "-o")

View File

@@ -24,7 +24,9 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
width=150,
height=100,
)
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
prompt_data = dict()
@@ -37,7 +39,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_body"):
# TODO: add ability to search image by typing
with gr.Column(scale=1, min_width=600):
image = gr.Image(type="filepath").style(height=512)
image = gr.Image(type="filepath", height=512)
with gr.Column(scale=1, min_width=600):
prompts = gr.Dropdown(

View File

@@ -1,3 +1,3 @@
# SHARK Annotator
gradio==3.15.0
gradio==3.34.0
jsonlines

View File

@@ -14,4 +14,4 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79
include = '\.pyi?$'
exclude = "apps/language_models/scripts/vicuna.py"

View File

@@ -16,7 +16,7 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
diffusers
scipy
ftfy
gradio
@@ -26,7 +26,14 @@ safetensors
opencv-python
scikit-image
pytorch_lightning # for runwayml models
tk
pywebview
sentencepiece
py-cpuinfo
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller
# low precision vicuna
brevitas @ git+https://github.com/Xilinx/brevitas.git@llm

243
rest_api_tests/api_test.py Normal file
View File

@@ -0,0 +1,243 @@
import requests
from PIL import Image
import base64
from io import BytesIO
def upscaler_test():
# Define values here
prompt = ""
negative_prompt = ""
seed = 2121991605
height = 512
width = 512
steps = 50
noise_level = 10
cfg_scale = 7
image_path = r"./rest_api_tests/dog.png"
# Converting Image to base64
img_file = open(image_path, "rb")
init_images = [
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
]
url = "http://127.0.0.1:8080/sdapi/v1/upscaler"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"height": height,
"width": width,
"steps": steps,
"noise_level": noise_level,
"cfg_scale": cfg_scale,
"init_images": init_images,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"response from server was : {res.status_code}")
def img2img_test():
# Define values here
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
denoising_strength = 0.75
cfg_scale = 7
image_path = r"./rest_api_tests/dog.png"
# Converting Image to Base64
img_file = open(image_path, "rb")
init_images = [
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
]
url = "http://127.0.0.1:8080/sdapi/v1/img2img"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"init_images": init_images,
"height": height,
"width": width,
"steps": steps,
"denoising_strength": denoising_strength,
"cfg_scale": cfg_scale,
"seed": seed,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"response from server was : {res.status_code}")
# NOTE Uncomment below to save the picture
# print("Extracting response object")
# response_obj = res.json()
# img_b64 = response_obj.get("images", [False])[0] or response_obj.get(
# "image"
# )
# img_b2 = base64.b64decode(img_b64.replace("data:image/png;base64,", ""))
# im_file = BytesIO(img_b2)
# response_img = Image.open(im_file)
# print("Saving Response Image to: response_img")
# response_img.save(r"rest_api_tests/response_img.png")
def inpainting_test():
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
noise_level = 10
cfg_scale = 7
is_full_res = False
full_res_padding = 32
image_path = r"./rest_api_tests/dog.png"
img_file = open(image_path, "rb")
image = (
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
)
img_file = open(image_path, "rb")
mask = (
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
)
url = "http://127.0.0.1:8080/sdapi/v1/inpaint"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"image": image,
"mask": mask,
"height": height,
"width": width,
"steps": steps,
"noise_level": noise_level,
"cfg_scale": cfg_scale,
"seed": seed,
"is_full_res": is_full_res,
"full_res_padding": full_res_padding,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[Inpainting] response from server was : {res.status_code}")
def outpainting_test():
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
cfg_scale = 7
color_variation = 0.2
noise_q = 0.2
directions = ["up", "down", "right", "left"]
pixels = 32
mask_blur = 64
image_path = r"./rest_api_tests/dog.png"
# Converting Image to Base64
img_file = open(image_path, "rb")
init_images = [
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
]
url = "http://127.0.0.1:8080/sdapi/v1/outpaint"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"height": height,
"width": width,
"steps": steps,
"cfg_scale": cfg_scale,
"color_variation": color_variation,
"noise_q": noise_q,
"directions": directions,
"pixels": pixels,
"mask_blur": mask_blur,
"init_images": init_images,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[Outpaint] response from server was : {res.status_code}")
def txt2img_test():
prompt = "Paint a rabbit in a top hate"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
cfg_scale = 7
url = "http://127.0.0.1:8080/sdapi/v1/txt2img"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"height": height,
"width": width,
"steps": steps,
"cfg_scale": cfg_scale,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[txt2img] response from server was : {res.status_code}")
if __name__ == "__main__":
txt2img_test()
img2img_test()
upscaler_test()
inpainting_test()
outpainting_test()

BIN
rest_api_tests/dog.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

View File

@@ -39,7 +39,7 @@ setup(
install_requires=[
"numpy",
"PyYAML",
"torch-mlir>=20221021.633",
"torch-mlir==20230620.875",
]
+ backend_deps,
)

View File

@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --pre torch-mlir==20230620.875 torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html

View File

@@ -2,9 +2,10 @@
# Sets up a venv suitable for running samples.
# e.g:
# ./setup_venv.sh #setup a default $PYTHON3 shark.venv
# Environment Variables by the script.
# Environment variables used by the script.
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
# VENV_DIR=myshark.venv #create a venv called myshark.venv
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
# IMPORTER=1 #Install importer deps
# BENCHMARK=1 #Install benchmark deps
@@ -26,15 +27,22 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
echo "Python: $PYTHON"
echo "Python version: $PYTHON_VERSION_X_Y"
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
if [ "$PYTHON_VERSION_X_Y" != "3.11" ]; then
echo "Error: Python version 3.11 is required."
exit 1
fi
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
fi
fi
Red=`tput setaf 1`
@@ -80,7 +88,7 @@ if [ "$torch_mlir_bin" = true ]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
$PYTHON -m pip install --pre torch-mlir==20230620.875 -f https://llvm.github.io/torch-mlir/package-index/
if [ $? -eq 0 ];then
echo "Successfully Installed torch-mlir"
else
@@ -147,8 +155,9 @@ if [[ ! -z "${ONNX}" ]]; then
fi
fi
if [[ -z "${CONDA_PREFIX}" ]]; then
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
echo "${Green}Before running examples activate venv with:"
echo " ${Green}source $VENV_DIR/bin/activate"
fi
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@llm

View File

@@ -0,0 +1,28 @@
import importlib
import logging
from torch._dynamo import register_backend
log = logging.getLogger(__name__)
@register_backend
def shark(model, inputs, *, options):
try:
from shark.dynamo_backend.utils import SharkBackend
except ImportError:
log.exception(
"Unable to import SHARK - High Performance Machine Learning Distribution"
"Please install the right version of SHARK that matches the PyTorch version being used. "
"Refer to https://github.com/nod-ai/SHARK/ for details."
)
raise
return SharkBackend(model, inputs, options)
def has_shark():
try:
importlib.import_module("shark")
return True
except ImportError:
return False

View File

@@ -0,0 +1,154 @@
import functools
from typing import List, Optional
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads
from shark.shark_inference import SharkInference
from torch._decomp import get_decompositions
from torch.func import functionalize
import io
import torch_mlir
# TODO: Control decompositions.
def default_decompositions():
return get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
]
)
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
class SharkBackend:
def __init__(
self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
):
self.fx_g = fx_g
self.inputs = inputs
self.shark_module = None
self.device: str = options.get("device", "cpu")
self.was_unwrapped: bool = False
self.none_indices: list = []
self._modify_fx_g()
self.compile()
def _modify_fx_g(self):
self.none_indices = _remove_nones(self.fx_g)
self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)
def compile(self):
gm = make_fx(
functionalize(self.fx_g),
decomposition_table=default_decompositions(),
)(*self.inputs)
gm.graph.set_codegen(torch.fx.graph.CodeGen())
gm.recompile()
strip_overloads(gm)
ts_g = torch.jit.script(gm)
mlir_module = torch_mlir.compile(
ts_g, self.inputs, output_type="linalg-on-tensors"
)
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode,
device=self.device,
mlir_dialect="tm_tensor",
)
shark_module.compile(extra_args=[])
self.shark_module = shark_module
def __call__(self, *inputs):
np_inputs = [x.contiguous().detach().cpu().numpy() for x in inputs]
np_outs = self.shark_module("forward", np_inputs)
if self.was_unwrapped:
np_outs = [
np_outs,
]
if not isinstance(np_outs, list):
res = torch.from_numpy(np_outs)
return res
result = [torch.from_numpy(x) for x in np_outs]
for r_in in self.none_indices:
result.insert(r_in, None)
result = tuple(result)
return result

View File

@@ -1,70 +1,25 @@
import torch
import torch_mlir
import torch._dynamo as torchdynamo
from shark.sharkdynamo.utils import make_shark_compiler
import shark
import warnings, logging
warnings.simplefilter("ignore")
torchdynamo.config.log_level = logging.ERROR
def foo(x, a):
if x.shape[0] > 3:
return x + a
else:
return x + 3
torchdynamo.reset()
shark_options = {"device": "cpu"}
compiled = torch.compile(foo, backend="shark", options=shark_options)
input = torch.ones(4)
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
)
def foo(t):
return 2 * t
x = compiled(input, input)
example_input = torch.rand((2, 3))
x = foo(example_input)
print(x)
input = torch.ones(3)
torchdynamo.reset()
x = compiled(input, input)
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
)
def foo(a, b):
x = a / (a + 1)
if b.sum() < 0:
b = b * -1
return x * b
print(foo(torch.rand((2, 3)), -torch.rand((2, 3))))
torchdynamo.reset()
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
)
def foo(a):
for i in range(10):
a += 1.0
return a
print(foo(torch.rand((1, 2))))
torchdynamo.reset()
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
)
def test_unsupported_types(t, y):
return t, 2 * y
str_input = "hello"
tensor_input = torch.randn(2)
print(test_unsupported_types(str_input, tensor_input))
print(x)

View File

@@ -0,0 +1,72 @@
import torch
import torch_mlir
from shark.shark_inference import SharkInference
from shark.shark_compile import shark_compile_through_fx
from MEGABYTE_pytorch import MEGABYTE
import os
class MegaModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = MEGABYTE(
num_tokens=16000, # number of tokens
dim=(
512,
256,
), # transformer model dimension (512 for coarsest, 256 for fine in this example)
max_seq_len=(
1024,
4,
), # sequence length for global and then local. this can be more than 2
depth=(
6,
4,
), # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
dim_head=64, # dimension per head
heads=8, # number of attention heads
flash_attn=True, # use flash attention
)
def forward(self, input):
return self.model(input)
megaModel = MegaModel()
inputs = [torch.randint(0, 16000, (1, 1024, 4))]
# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
# 1. aten.alias
shark_module, _ = shark_compile_through_fx(
model=megaModel,
inputs=inputs,
extended_model_name="mega_shark",
is_f16=False,
f16_input_mask=None,
save_dir=os.getcwd(),
debug=False,
generate_or_load_vmfb=True,
extra_args=[],
device="cuda",
mlir_dialect="tm_tensor",
)
# logits = model(x)
def print_output_info(output, msg):
print("\n", msg)
print("\n\t", output.shape)
ans = shark_module("forward", inputs)
print_output_info(torch.from_numpy(ans), "SHARK's output")
ans = megaModel.forward(*inputs)
print_output_info(ans, "ORIGINAL Model's output")
# and sample from the logits accordingly
# or you can use the generate function
# NEED TO LOOK AT THIS LATER IF REQUIRED IN SHARK.
# sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)

View File

@@ -0,0 +1,73 @@
from transformers import AutoTokenizer, FlaxAutoModel
import torch
import jax
from typing import Union, Dict, List, Any
import numpy as np
from shark.shark_inference import SharkInference
import io
NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
def convert_torch_tensor_tree_to_numpy(
tree: Union[torch.tensor, Dict[str, torch.tensor], List[torch.tensor]]
) -> NumpyTree:
return jax.tree_util.tree_map(
lambda torch_tensor: torch_tensor.cpu().detach().numpy(), tree
)
def convert_int64_to_int32(tree: NumpyTree) -> NumpyTree:
return jax.tree_util.tree_map(
lambda tensor: np.array(tensor, dtype=np.int32)
if tensor.dtype == np.int64
else tensor,
tree,
)
def get_sample_input():
tokenizer = AutoTokenizer.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased"
)
inputs_torch = tokenizer("Hello, World!", return_tensors="pt")
return convert_int64_to_int32(
convert_torch_tensor_tree_to_numpy(inputs_torch.data)
)
def get_jax_model():
return FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree):
model_mlir = jax.jit(jax_model).lower(**sample_input).compiler_ir()
byte_stream = io.BytesIO()
model_mlir.operation.write_bytecode(file=byte_stream)
return byte_stream.getvalue()
def assert_array_list_allclose(x, y, *args, **kwargs):
assert len(x) == len(y)
for a, b in zip(x, y):
np.testing.assert_allclose(
np.asarray(a), np.asarray(b), *args, **kwargs
)
sample_input = get_sample_input()
jax_model = get_jax_model()
mlir = export_jax_to_mlir(jax_model, sample_input)
# Compile and load module.
shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
shark_inference.compile()
# Run main function.
result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
# Run JAX model.
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]
# Verify result.
assert_array_list_allclose(result, reference_result, atol=1e-5)

View File

@@ -0,0 +1,6 @@
flax
jax[cpu]
nodai-SHARK
orbax
transformers
torch

View File

@@ -70,11 +70,11 @@ mlir_model, func_name, inputs, golden_out = download_model(
"resnet50", frontend="torch"
)
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
shark_module = SharkInference(mlir_model, mlir_dialect="linalg")
shark_module.compile()
path = shark_module.save_module()
shark_module.load_module(path)
result = shark_module.forward((img.detach().numpy(),))
result = shark_module("forward", (img.detach().numpy(),))
print("The top 3 results obtained via shark_runner is:")
print(top3_possibilities(torch.from_numpy(result)))

View File

@@ -1,4 +1,4 @@
# Copyright 2020 The Nod Team. All rights reserved.
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -45,10 +45,15 @@ def run_cmd(cmd, debug=False):
def iree_device_map(device):
uri_parts = device.split("://", 2)
iree_driver = (
_IREE_DEVICE_MAP[uri_parts[0]]
if uri_parts[0] in _IREE_DEVICE_MAP
else uri_parts[0]
)
if len(uri_parts) == 1:
return _IREE_DEVICE_MAP[uri_parts[0]]
return iree_driver
else:
return f"{_IREE_DEVICE_MAP[uri_parts[0]]}://{uri_parts[1]}"
return f"{iree_driver}://{uri_parts[1]}"
def get_supported_device_list():
@@ -57,9 +62,12 @@ def get_supported_device_list():
_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"AMD-AIE": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "level_zero",
}
@@ -68,14 +76,17 @@ _IREE_DEVICE_MAP = {
def iree_target_map(device):
if "://" in device:
device = device.split("://")[0]
return _IREE_TARGET_MAP[device]
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
"AMD-AIE": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "opencl-spirv",
}
@@ -92,11 +103,13 @@ def check_device_drivers(device):
subprocess.check_output("nvidia-smi")
except Exception:
return True
elif device in ["metal", "vulkan"]:
elif device in ["vulkan"]:
try:
subprocess.check_output("vulkaninfo")
except Exception:
return True
elif device == "metal":
return False
elif device in ["intel-gpu"]:
try:
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
@@ -110,10 +123,8 @@ def check_device_drivers(device):
subprocess.check_output("rocminfo")
except Exception:
return True
# Unknown device.
else:
return True
# Unknown device. We assume drivers are installed.
return False

View File

@@ -1,4 +1,4 @@
# Copyright 2020 The Nod Team. All rights reserved.
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,15 +14,19 @@
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import numpy as np
import os
import re
import tempfile
from pathlib import Path
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
@@ -30,19 +34,39 @@ def get_iree_device_args(device, extra_args=[]):
f"Specific device selection only supported for vulkan now."
f"Proceeding with {device} as device."
)
device_num = device_uri[1]
else:
device_num = 0
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
return get_iree_cpu_args()
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
return (
get_iree_cpu_args()
+ data_tiling_flag
+ u_kernel_flag
+ stack_size_flag
)
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device_uri[0] in ["metal", "vulkan"]:
if device_uri[0] == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(extra_args=extra_args)
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args
return get_iree_metal_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
@@ -54,10 +78,9 @@ def get_iree_device_args(device, extra_args=[]):
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
return ["--iree-llvmcpu-target-cpu-features=host"]
elif frontend in ["tensorflow", "tf", "mhlo"]:
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
return [
"--iree-llvmcpu-target-cpu-features=host",
"--iree-mhlo-demote-i64-to-i32=false",
"--iree-flow-demote-i64-to-i32",
]
else:
@@ -170,8 +193,10 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
vmfb_file.close()
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
vm_module = ireert.VmModule.from_buffer(
config.vm_instance,
flatbuffer_blob,
warn_if_copy=False,
)
benchmark_cl = build_benchmark_args_non_tensor_input(
@@ -259,8 +284,8 @@ def compile_module_to_flatbuffer(
args += extra_args
if frontend in ["tensorflow", "tf"]:
input_type = "mhlo"
elif frontend in ["mhlo", "tosa"]:
input_type = "auto"
elif frontend in ["stablehlo", "tosa"]:
input_type = frontend
elif frontend in ["tflite", "tflite-tosa"]:
input_type = "tosa"
@@ -302,15 +327,72 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
vm_module = ireert.VmModule.from_buffer(
config.vm_instance, flatbuffer_blob, warn_if_copy=False
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
ModuleCompiled = ctx.modules.module
ModuleCompiled = getattr(ctx.modules, vm_module.name)
return ModuleCompiled, config
def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
):
instance = ireert.VmInstance()
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=[],
)
# First get configs.
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(mmaped_vmfb)
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
return mmaped_vmfb, config, temp_file_to_unlink
def get_iree_compiled_module(
module,
device: str,
@@ -318,19 +400,58 @@ def get_iree_compiled_module(
model_config_path: str = None,
extra_args: list = [],
device_idx: int = None,
mmap: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args
)
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
# Got to find a cleaner way to unlink/delete the temporary file since
# we're setting delete=False when creating NamedTemporaryFile. That's why
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
if mmap:
print(f"Will load the compiled module as a mmapped temporary file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx
)
else:
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
)
ret_params = {
"vmfb": vmfb,
"config": config,
"temp_file_to_unlink": temp_file_to_unlink,
}
return ret_params
def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
def load_flatbuffer(
flatbuffer_path: str,
device: str,
device_idx: int = None,
mmap: bool = False,
):
temp_file_to_unlink = None
if mmap:
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_path, device, device_idx
)
else:
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
)
ret_params = {
"vmfb": vmfb,
"config": config,
"temp_file_to_unlink": temp_file_to_unlink,
}
return ret_params
def export_iree_module_to_vmfb(
@@ -361,7 +482,7 @@ def export_iree_module_to_vmfb(
def export_module_to_mlir_file(module, frontend, directory: str):
# TODO: write proper documentation.
mlir_str = module
if frontend in ["tensorflow", "tf", "mhlo", "tflite"]:
if frontend in ["tensorflow", "tf", "mhlo", "stablehlo", "tflite"]:
mlir_str = module.decode("utf-8")
elif frontend in ["pytorch", "torch"]:
mlir_str = module.operation.get_asm()

View File

@@ -16,6 +16,7 @@
import subprocess
import platform
from shark.parser import shark_args
def get_cpu_count():
@@ -44,4 +45,18 @@ def get_iree_cpu_args():
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
return [f"--iree-llvmcpu-target-triple={target_triple}"]
return [
f"--iree-llvmcpu-target-triple={target_triple}",
]
# Get iree runtime flags for cpu
def get_iree_cpu_rt_args():
default = get_cpu_count()
default = default if default <= 8 else default - 2
cpu_count = (
default
if shark_args.task_topology_max_group_count is None
else shark_args.task_topology_max_group_count
)
return [f"--task_topology_max_group_count={cpu_count}"]

View File

@@ -0,0 +1,121 @@
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# All the iree_vulkan related functionalities go here.
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_metal_device_name(device_num=0):
iree_device_dump = run_cmd("iree-run-module --dump_devices")
iree_device_dump = iree_device_dump[0].split("\n\n")
metal_device_list = [
s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s
]
if len(metal_device_list) == 0:
raise ValueError("No device name found in device dump!")
if len(metal_device_list) > 1:
print("Following devices found:")
for i, dname in enumerate(metal_device_list):
print(f"{i}. {dname}")
print(f"Choosing device: {metal_device_list[device_num]}")
return metal_device_list[device_num]
def get_os_name():
if platform.startswith("linux"):
return "linux"
elif platform == "darwin":
return "macos"
elif platform == "win32":
return "windows"
else:
print("Cannot detect OS type, defaulting to linux.")
return "linux"
def get_metal_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.
Args:
device_name (str): name of the hardware device to be used with vulkan
Returns:
str or None: target triple or None if no match found for given name
"""
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"
else:
triple = None
return triple
def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
for flag in extra_args:
if "-iree-metal-target-platform=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
if device_name == "" or device_name == [] or device_name is None:
metal_device = get_metal_device_name(device_num=device_num)
else:
metal_device = device_name
triple = get_metal_target_triple(metal_device)
if triple is not None:
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
)
return f"-iree-metal-target-platform={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
)
print(f"Target : {metal_device}")
return None
def get_iree_metal_args(device_num=0, extra_args=[]):
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
res_metal_flag = []
metal_triple_flag = None
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
metal_triple_flag = arg
break
if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)
if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag
def set_iree_metal_runtime_flags(flags):
for flag in flags:
ireert.flags.parse_flags(flag)
return

View File

@@ -117,7 +117,8 @@ def get_extensions(triple):
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
ext.append("VK_NV_cooperative_matrix")
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
ext.append("VK_KHR_shader_integer_dot_product")
return make_ext_list(ext_list=ext)
@@ -133,9 +134,9 @@ def get_vendor(triple):
return "Apple"
if arch in ["arc", "UHD"]:
return "Intel"
if arch in ["turing", "ampere"]:
if arch in ["turing", "ampere", "pascal"]:
return "NVIDIA"
if arch == "ardeno":
if arch == "adreno":
return "Qualcomm"
if arch == "cpu":
if product == "swiftshader":
@@ -151,7 +152,7 @@ def get_device_type(triple):
return "Unknown"
if arch == "cpu":
return "CPU"
if arch in ["turing", "ampere", "arc"]:
if arch in ["turing", "ampere", "arc", "pascal"]:
return "DiscreteGPU"
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]:
if product == "ivega10":
@@ -228,6 +229,7 @@ def get_vulkan_target_capabilities(triple):
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["shaderIntegerDotProduct"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
@@ -236,12 +238,12 @@ def get_vulkan_target_capabilities(triple):
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
if arch == "rdna3":
# TODO: Get scope value
cap["coopmatCases"] = [
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
]
if product == "rx5700xt":
cap["storagePushConstant16"] = False
cap["storagePushConstant8"] = False
@@ -274,7 +276,7 @@ def get_vulkan_target_capabilities(triple):
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["shaderIntegerDotProduct"] = True
cap["storagePushConstant16"] = False
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
@@ -305,6 +307,7 @@ def get_vulkan_target_capabilities(triple):
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["shaderIntegerDotProduct"] = False
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
@@ -367,6 +370,7 @@ def get_vulkan_target_capabilities(triple):
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = False
cap["shaderIntegerDotProduct"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
@@ -389,6 +393,40 @@ def get_vulkan_target_capabilities(triple):
"ShuffleRelative",
]
elif arch in ["pascal"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1536
cap["maxComputeWorkGroupSize"] = [1536, 1024, 64]
cap["subgroupSize"] = 32
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = False
cap["shaderFloat64"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["shaderIntegerDotProduct"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch in ["ampere", "turing"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1024
@@ -413,6 +451,7 @@ def get_vulkan_target_capabilities(triple):
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["shaderIntegerDotProduct"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True

View File

@@ -21,7 +21,7 @@ from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_vulkan_device_name():
def get_vulkan_device_name(device_num=0):
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
@@ -31,8 +31,8 @@ def get_vulkan_device_name():
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing first one: {vulkaninfo_list[0]}")
return vulkaninfo_list[0]
print(f"Choosing device: {vulkaninfo_list[device_num]}")
return vulkaninfo_list[device_num]
def get_os_name():
@@ -114,19 +114,24 @@ def get_vulkan_target_triple(device_name):
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"
# Adreno Targets
elif all(x in device_name for x in ("Adreno", "740")):
triple = f"adreno-a740-{system_os}"
else:
triple = None
return triple
def get_vulkan_triple_flag(device_name="", extra_args=[]):
def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
for flag in extra_args:
if "-iree-vulkan-target-triple=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
if device_name == "" or device_name == [] or device_name is None:
vulkan_device = get_vulkan_device_name()
vulkan_device = get_vulkan_device_name(device_num=device_num)
else:
vulkan_device = device_name
triple = get_vulkan_target_triple(vulkan_device)
@@ -144,7 +149,7 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
return None
def get_iree_vulkan_args(extra_args=[]):
def get_iree_vulkan_args(device_num=0, extra_args=[]):
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
@@ -156,7 +161,9 @@ def get_iree_vulkan_args(extra_args=[]):
break
if vulkan_triple_flag is None:
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
vulkan_triple_flag = get_vulkan_triple_flag(
device_num=device_num, extra_args=extra_args
)
if vulkan_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)

View File

@@ -30,8 +30,8 @@ import os
import sys
from typing import Dict, List
import iree.compiler._mlir_libs
from iree.compiler import ir
from iree.compiler.transforms import ireec as ireec_trans
def model_annotation(
@@ -311,11 +311,18 @@ def add_attributes(op: ir.Operation, config: List[Dict]):
split_k = config["split_k"]
elif "SPIRV" in config["pipeline"]:
pipeline = config["pipeline"]
tile_sizes = [
config["work_group_tile_sizes"],
config["parallel_tile_sizes"],
config["reduction_tile_sizes"],
]
if pipeline == "SPIRVMatmulPromoteVectorize":
tile_sizes = [
config["work_group_tile_sizes"]
+ [config["reduction_tile_sizes"][-1]],
]
else:
tile_sizes = [
config["work_group_tile_sizes"],
config["parallel_tile_sizes"],
config["reduction_tile_sizes"],
]
workgroup_size = config["work_group_sizes"]
if "vector_tile_sizes" in config.keys():
tile_sizes += [config["vector_tile_sizes"]]
@@ -409,7 +416,6 @@ def shape_list_to_string(input):
def create_context() -> ir.Context:
context = ir.Context()
ireec_trans.register_all_dialects(context)
context.allow_unregistered_dialects = True
return context

View File

@@ -119,5 +119,11 @@ parser.add_argument(
"to augment the base device allocator",
choices=["debug", "caching"],
)
parser.add_argument(
"--task_topology_max_group_count",
type=str,
default=None,
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
)
shark_args, unknown = parser.parse_known_args()

99
shark/shark_compile.py Normal file
View File

@@ -0,0 +1,99 @@
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
shark_module = None
if os.path.isfile(vmfb_path):
shark_module = SharkInference(
None,
device=device,
mlir_dialect=mlir_dialect,
)
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
def compile_module(
shark_module, extended_model_name, generate_vmfb, extra_args=[]
):
if generate_vmfb:
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
if os.path.isfile(vmfb_path):
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
print(
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
)
path = shark_module.save_module(
os.getcwd(), extended_model_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
def shark_compile_through_fx(
model,
inputs,
extended_model_name,
is_f16=False,
f16_input_mask=None,
save_dir=tempfile.gettempdir(),
debug=False,
generate_or_load_vmfb=True,
extra_args=[],
device=None,
mlir_dialect="tm_tensor",
):
if generate_or_load_vmfb:
shark_module = load_vmfb(
extended_model_name=extended_model_name,
device=device,
mlir_dialect=mlir_dialect,
extra_args=extra_args,
)
if shark_module:
return (
shark_module,
None,
)
from shark.parser import shark_args
if "cuda" in device:
shark_args.enable_tf32 = True
(
mlir_module,
_,
) = import_with_fx(
model=model,
inputs=inputs,
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
)
shark_module = SharkInference(
mlir_module,
device=device,
mlir_dialect=mlir_dialect,
)
return (
compile_module(
shark_module,
extended_model_name,
generate_vmfb=generate_or_load_vmfb,
extra_args=extra_args,
),
mlir_module,
)

View File

@@ -60,10 +60,15 @@ def download_public_file(
else:
continue
destination_filename = os.path.join(destination_folder_name, blob_name)
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
else:
destination_filename = os.path.join(
destination_folder_name, blob_name
)
if os.path.isdir(destination_filename):
continue
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
input_type_to_np_dtype = {
@@ -196,7 +201,7 @@ def download_model(
tank_url=None,
frontend=None,
tuned=None,
import_args=None,
import_args={"batch_size": 1},
):
model_name = model_name.replace("/", "_")
dyn_str = "_dynamic" if dynamic else ""
@@ -210,6 +215,9 @@ def download_model(
+ "_BS"
+ str(import_args["batch_size"])
)
elif any(model in model_name for model in ["clip", "unet", "vae"]):
# TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation.
model_dir_name = model_name
else:
model_dir_name = model_name + "_" + frontend
model_dir = os.path.join(WORKDIR, model_dir_name)
@@ -270,6 +278,9 @@ def download_model(
tuned_str = "" if tuned is None else "_" + tuned
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
filename = os.path.join(model_dir, model_name + suffix)
print(
f"Verifying that model artifacts were downloaded successfully to {filename}..."
)
if not os.path.exists(filename):
from tank.generate_sharktank import gen_shark_files

View File

@@ -0,0 +1,206 @@
from typing import Any, Dict, List, Tuple
from collections import defaultdict
from shark.shark_importer import import_with_fx
import torchvision.models as models
import copy
import io
import numpy as np
import sys
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
import torch_mlir
def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
mlir_module = torch_mlir.compile(
fx_g, inputs, output_type="linalg-on-tensors"
)
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode,
device=device,
mlir_dialect="tm_tensor",
)
shark_module.compile(extra_args=[])
return shark_module
def _make_single_op_gm(node, captured_val, compiled_graph):
"""Make a GraphModule that just executes the given node."""
g = torch.fx.Graph()
env = {}
inputs = []
for arg in node.args:
if arg and hasattr(arg, "name"):
env[arg.name] = g.placeholder(arg.name)
if isinstance(captured_val[arg.name], (list, tuple)):
for val in captured_val[arg.name]:
inputs.append(val)
else:
inputs.append(captured_val[arg.name])
call = g.node_copy(node, lambda n: env[n.name])
g.output(call)
g.lint()
single_node = torch.fx.GraphModule(torch.nn.Module(), g)
compiled_module = shark_backend(single_node, inputs)
compiled_graph[node.name] = {
"module": compiled_module,
"inputs": [i for i in env],
"result": None,
}
return
def compiled_graph(gm: torch.fx.GraphModule, attr_info):
compiled_graph = {}
g = gm.graph
for node in g.nodes:
if node.op == "call_function":
if not (
node.target in [torch.ops.aten.empty]
or node.name.startswith("getitem")
):
_make_single_op_gm(node, attr_info, compiled_graph)
# Currently torch.aten.empty has an compilation issue, so running natively.
elif node.target in [torch.ops.aten.empty]:
compiled_graph[node.name] = {
"target": node.target,
"args": node.args,
"kwargs": node.kwargs,
"result": None,
}
# Get item is a simple case takes a tuple and return the tensor at a particular index.
elif node.name.startswith("getitem"):
compiled_graph[node.name] = {
"input": node.args[0].name,
"pos": node.args[1],
"result": None,
}
return compiled_graph
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env: Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target: str):
target_atoms = target.split(".")
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == "placeholder":
result = next(args_iter)
elif node.op == "get_attr":
result = fetch_attr(node.target)
elif node.op == "call_function":
result = node.target(
*load_arg(node.args), **load_arg(node.kwargs)
)
elif node.op == "call_method":
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == "call_module":
result = self.modules[node.target](
*load_arg(node.args), **load_arg(node.kwargs)
)
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return env
# return load_arg(self.graph.result)
resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
input = (torch.randn(1, 3, 224, 224),)
print(resnet18(input[0]))
fx_graph = import_with_fx(resnet18, input, mlir_type="fx")
shape_prop = ShapeProp(fx_graph)
x = shape_prop.propagate(input[0])
shark_graph = compiled_graph(fx_graph, x)
for key in shark_graph:
if key.startswith("getitem"):
input_val = shark_graph[key]["input"]
pos = shark_graph[key]["pos"]
if input_val not in shark_graph:
shark_graph[key]["result"] = x[input_val][pos].detach()
else:
shark_graph[key]["result"] = shark_graph[input_val]["result"][
pos
].detach()
elif key.startswith("empty"):
operator = shark_graph[key]["target"]
args = shark_graph[key]["args"]
kwargs = shark_graph[key]["kwargs"]
shark_graph[key]["result"] = operator(*args, **kwargs).detach()
else:
input_val = shark_graph[key]["inputs"]
input_tensors = []
for input in input_val:
if input not in shark_graph:
input_tensors.append(x[input].detach())
else:
input_tensors.append(shark_graph[input]["result"])
val = shark_graph[key]["module"]("forward", input_tensors)
if isinstance(val, (tuple, list)):
list_val = []
for v in val:
list_val.append(torch.from_numpy(v))
shark_graph[key]["result"] = list_val
else:
shark_graph[key]["result"] = torch.from_numpy(val)
print(shark_graph)

View File

@@ -0,0 +1,105 @@
import re
import json
import torch_mlir
from iree.compiler import compile_str
from shark.shark_importer import import_with_fx, get_f16_inputs
class GenerateConfigFile:
def __init__(
self,
model,
num_sharding_stages: int,
sharding_stages_id: list[str],
model_input=None,
config_file_path="model_config.json",
):
self.model = model
self.num_sharding_stages = num_sharding_stages
self.sharding_stages_id = sharding_stages_id
assert self.num_sharding_stages == len(
self.sharding_stages_id
), "Number of sharding stages should be equal to the list of their ID"
self.model_input = model_input
self.config_file_path = config_file_path
def split_into_dispatches(
self,
backend,
fx_tracing_required=True,
f16_model=False,
torch_mlir_tracing=False,
):
graph_for_compilation = self.model
if fx_tracing_required:
graph_for_compilation = import_with_fx(
self.model,
self.model_input,
is_f16=f16_model,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
module = torch_mlir.compile(
graph_for_compilation,
(self.model_input),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=torch_mlir_tracing,
verbose=False,
)
module = module.operation.get_asm(large_elements_limit=4)
compiled_module_str = str(
compile_str(
str(module),
target_backends=[backend],
extra_args=[
"--compile-to=flow",
"--mlir-elide-elementsattrs-if-larger=4",
],
)
)
substring_start_idx = [
m.start()
for m in re.finditer("flow.dispatch @", compiled_module_str)
]
dispatch_list = dict()
# dispatch_no is the 'i'th index of a dispatch out of n total dispatches of a model
# dispatch_id is the unique id of a dispatch, multiple instances of the same dispatch
# can occur in a model
for dispatch_no, substring_idx in enumerate(substring_start_idx):
dispatch_idx = (
compiled_module_str[substring_idx:]
.split(":")[0]
.split("@")[-1]
)
key = "dispatch_no_" + str(dispatch_no)
dispatch_list[key] = {n: "None" for n in self.sharding_stages_id}
dispatch_list[key]["dispatch_id"] = dispatch_idx
self.generate_json(dispatch_list)
def split_into_layers(self):
model_dictionary = dict()
for name, m in self.model.named_modules():
if name == "":
continue
# Remove non-leaf nodes from the config as they aren't an operation
substring_before_final_period = name.split(".")[:-1]
substring_before_final_period = ".".join(
substring_before_final_period
)
if substring_before_final_period in model_dictionary:
del model_dictionary[substring_before_final_period]
layer_dict = {n: "None" for n in self.sharding_stages_id}
model_dictionary[name] = layer_dict
self.generate_json(model_dictionary)
def generate_json(self, artifacts):
with open(self.config_file_path, "w") as outfile:
json.dump(artifacts, outfile)

View File

@@ -81,7 +81,7 @@ class SharkImporter:
# NOTE: The default function for torch is "forward" and tf-lite is "main".
def _torch_mlir(self, is_dynamic, tracing_required):
def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
from shark.torch_mlir_utils import get_torch_mlir_module
return get_torch_mlir_module(
@@ -90,6 +90,7 @@ class SharkImporter:
is_dynamic,
tracing_required,
self.return_str,
mlir_type,
)
def _tf_mlir(self, func_name, save_dir="."):
@@ -120,6 +121,7 @@ class SharkImporter:
tracing_required=False,
func_name="forward",
save_dir="./shark_tmp/",
mlir_type="linalg",
):
if self.frontend in ["torch", "pytorch"]:
if self.inputs == None:
@@ -127,7 +129,10 @@ class SharkImporter:
"Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
)
sys.exit(1)
return self._torch_mlir(is_dynamic, tracing_required), func_name
return (
self._torch_mlir(is_dynamic, tracing_required, mlir_type),
func_name,
)
if self.frontend in ["tf", "tensorflow"]:
return self._tf_mlir(func_name, save_dir), func_name
if self.frontend in ["tflite", "tf-lite"]:
@@ -143,14 +148,23 @@ class SharkImporter:
# Saves `function_name.npy`, `inputs.npz`, `golden_out.npz` and `model_name.mlir` in the directory `dir`.
def save_data(
self, dir, model_name, mlir_data, func_name, inputs, outputs
self,
dir,
model_name,
mlir_data,
func_name,
inputs,
outputs,
mlir_type="linalg",
):
import numpy as np
inputs_name = "inputs.npz"
outputs_name = "golden_out.npz"
func_file_name = "function_name"
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
model_name_mlir = (
model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
)
print(f"saving {model_name_mlir} to {dir}")
try:
inputs = [x.cpu().detach() for x in inputs]
@@ -186,19 +200,23 @@ class SharkImporter:
dir=tempfile.gettempdir(),
model_name="model",
golden_values=None,
mlir_type="linalg",
):
if self.inputs == None:
print(
f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
)
sys.exit(1)
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
model_name_mlir = (
model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
)
artifact_path = os.path.join(dir, model_name_mlir)
imported_mlir = self.import_mlir(
is_dynamic,
tracing_required,
func_name,
save_dir=artifact_path,
mlir_type=mlir_type,
)
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
# TODO: Check for multiple outputs.
@@ -224,6 +242,7 @@ class SharkImporter:
imported_mlir[1],
self.inputs,
golden_out,
mlir_type,
)
return (
imported_mlir,
@@ -293,7 +312,48 @@ def get_f16_inputs(inputs, is_f16, f16_input_mask):
return tuple(f16_masked_inputs)
def transform_fx(fx_g):
# Upcasts the block/list of ops.
def add_upcast(fx_g):
import torch
for node in fx_g.graph.nodes:
if node.target in [torch.ops.aten.mul]:
# This is a very strict check.
if hasattr(node.args[1], "target"):
if (
node.args[1].target in [torch.ops.aten.rsqrt]
and node.args[1].args[0].target in [torch.ops.aten.add]
and node.args[1].args[0].args[0].target
in [torch.ops.aten.mean]
and node.args[1].args[0].args[0].args[0].target
in [torch.ops.aten.pow]
):
print("found an upcasting block let's upcast it.")
pow_node = node.args[1].args[0].args[0].args[0]
mul_node = node
with fx_g.graph.inserting_before(pow_node):
lhs = pow_node.args[0]
upcast_lhs = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(lhs,),
kwargs={"dtype": torch.float32},
)
pow_node.args = (upcast_lhs, pow_node.args[1])
with fx_g.graph.inserting_before(mul_node):
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(mul_node,),
kwargs={"dtype": torch.float16},
)
mul_node.append(new_node)
mul_node.replace_all_uses_with(new_node)
new_node.args = (mul_node,)
new_node.kwargs = {"dtype": torch.float16}
fx_g.graph.lint()
def transform_fx(fx_g, quantized=False):
import torch
kwargs_dict = {
@@ -301,14 +361,62 @@ def transform_fx(fx_g):
"device": torch.device(type="cpu"),
"pin_memory": False,
}
kwargs_dict1 = {
"dtype": torch.float16,
}
for node in fx_g.graph.nodes:
if node.op == "call_function":
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
if quantized:
continue
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.zeros,
torch.ops.aten.zeros_like,
]:
node.kwargs = kwargs_dict
if node.kwargs.get("dtype") == torch.float32:
node.kwargs = kwargs_dict
# Vicuna
if node.target in [
torch.ops.aten._to_copy,
]:
if node.kwargs.get("dtype") == torch.float32:
node.kwargs = kwargs_dict1
if node.target in [
torch.ops.aten.masked_fill,
]:
if node.args[2] > torch.finfo(torch.half).max:
max_val = torch.finfo(torch.half).max
node.args = (node.args[0], node.args[1], max_val)
elif node.args[2] < torch.finfo(torch.half).min:
min_val = torch.finfo(torch.half).min
node.args = (node.args[0], node.args[1], min_val)
if node.target in [
torch.ops.aten.full,
]:
if node.args[1] > torch.finfo(torch.half).max:
max_val = torch.finfo(torch.half).max
node.args = (node.args[0], max_val)
node.kwargs = kwargs_dict
elif node.args[1] < torch.finfo(torch.half).min:
min_val = torch.finfo(torch.half).min
node.args = (node.args[0], min_val)
node.kwargs = kwargs_dict
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
if node.target in [torch.ops.aten.var_mean]:
with fx_g.graph.inserting_before(node):
@@ -318,6 +426,7 @@ def transform_fx(fx_g):
kwargs={},
)
node.args = (new_node, node.args[1])
if node.name.startswith("getitem"):
with fx_g.graph.inserting_before(node):
if node.args[0].target in [torch.ops.aten.var_mean]:
@@ -330,16 +439,14 @@ def transform_fx(fx_g):
node.replace_all_uses_with(new_node)
new_node.args = (node,)
new_node.kwargs = {"dtype": torch.float16}
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
# Required for cuda debugging.
# for node in fx_g.graph.nodes:
# if node.op == "call_function":
# if node.kwargs.get("device") == torch.device(type="cpu"):
# new_kwargs = node.kwargs.copy()
# new_kwargs["device"] = torch.device(type="cuda")
# node.kwargs = new_kwargs
fx_g.graph.lint()
@@ -381,6 +488,7 @@ def flatten_training_input(inputs):
return tuple(flattened_input)
# TODO: get rid of is_f16 by using precision
# Applies fx conversion to the model and imports the mlir.
def import_with_fx(
model,
@@ -392,10 +500,31 @@ def import_with_fx(
return_str=False,
save_dir=tempfile.gettempdir(),
model_name="model",
mlir_type="linalg",
is_dynamic=False,
tracing_required=False,
precision="fp32",
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
golden_values = None
if debug:
@@ -403,24 +532,97 @@ def import_with_fx(
golden_values = model(*inputs)
except:
golden_values = None
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
# TODO: Control the decompositions.
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
]
),
)(*inputs)
decomps_list = [
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
]
if precision in ["int4", "int8"]:
export_context_manager = brevitas_layer_export_mode
export_class = block_quant_layer_level_manager(
export_handlers=[LinearWeightBlockQuantHandlerFwd]
)
with export_context_manager(model, export_class):
fx_g = brevitas_make_fx(
model,
decomposition_table=get_decompositions(decomps_list),
)(*inputs)
transform_fx(fx_g, quantized=True)
replace_call_fn_target(
fx_g,
src=matmul_rhs_group_quant_placeholder,
target=torch.ops.brevitas.matmul_rhs_group_quant,
)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
else:
fx_g = make_fx(
model,
decomposition_table=get_decompositions(decomps_list),
)(*inputs)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
@@ -441,13 +643,21 @@ def import_with_fx(
if is_f16:
fx_g = fx_g.half()
transform_fx(fx_g)
# TODO: Have to make it more generic.
add_upcast(fx_g)
fx_g.recompile()
if mlir_type == "fx":
return fx_g
if training:
change_fx_graph_return_to_tuple(fx_g)
inputs = flatten_training_input(inputs)
ts_graph = torch.jit.script(fx_g)
if mlir_type == "torchscript":
return ts_graph
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
mlir_importer = SharkImporter(
ts_graph,
@@ -458,7 +668,12 @@ def import_with_fx(
if debug: # and not is_f16:
(mlir_module, func_name), _, _ = mlir_importer.import_debug(
dir=save_dir, model_name=model_name, golden_values=golden_values
dir=save_dir,
model_name=model_name,
golden_values=golden_values,
mlir_type=mlir_type,
is_dynamic=is_dynamic,
tracing_required=tracing_required,
)
return mlir_module, func_name

View File

@@ -48,6 +48,8 @@ class SharkInference:
Refer to {https://mlir.llvm.org/docs/Dialects/}
is_benchmark: bool
Whether this SharkInference module should be benchmark-enabled.
mmap: bool
Whether to load/run vmfb using mmap. It's `True` by default.
Methods
-------
@@ -70,6 +72,7 @@ class SharkInference:
dispatch_benchmark: str = None,
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
device_idx: int = None,
mmap: bool = True,
):
self.mlir_module = mlir_module
self.device = shark_args.device if device == "none" else device
@@ -88,6 +91,7 @@ class SharkInference:
)
self.shark_runner = None
self.mmap = mmap
def compile(self, extra_args=[]):
if self.dispatch_benchmarks is not None:
@@ -201,12 +205,14 @@ class SharkInference:
compile_vmfb=False,
extra_args=extra_args,
)
(
self.shark_runner.iree_compilation_module,
self.shark_runner.iree_config,
) = load_flatbuffer(
params = load_flatbuffer(
path,
self.device,
self.device_idx,
mmap=self.mmap,
)
self.shark_runner.iree_compilation_module = params["vmfb"]
self.shark_runner.iree_config = params["config"]
self.shark_runner.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
return

View File

@@ -25,7 +25,14 @@ import sys
# supported dialects by the shark-runtime.
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite", "tm_tensor"}
supported_dialects = {
"linalg",
"auto",
"stablehlo",
"tosa",
"tf-lite",
"tm_tensor",
}
class SharkRunner:
@@ -78,16 +85,17 @@ class SharkRunner:
if compile_vmfb == True:
# Compile the module to get the .vmfb.
(
self.iree_compilation_module,
self.iree_config,
) = get_iree_compiled_module(
params = get_iree_compiled_module(
self.mlir_module,
self.device,
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
)
self.iree_compilation_module = params["vmfb"]
self.iree_config = params["config"]
self.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
def run(self, function_name, inputs: tuple, send_to_host=False):
return get_results(

View File

@@ -59,6 +59,7 @@ class SharkTrainer:
"torch",
"tensorflow",
"tf",
"stablehlo",
"mhlo",
"linalg",
"tosa",
@@ -84,7 +85,7 @@ class SharkTrainer:
"tm_tensor",
extra_args=extra_args,
)
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
elif self.frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
self.shark_runner = SharkRunner(
self.model,
self.input,

View File

@@ -1,11 +0,0 @@
1. Install torchdynamo
- `git clone https://github.com/pytorch/torchdynamo.git`
- `cd torchdynamo`
- `python -m pip install -r requirements.txt`
- `python setup.py develop`
2. Install functorch
- `python -m pip install -v "git+https://github.com/pytorch/pytorch.git@$(python -c "import torch.version; print(torch.version.git_version)")#subdirectory=functorch"`
3. Run examples.
- `python shark/examples/shark_dynamo/basic_examples.py`

View File

@@ -1,163 +0,0 @@
import functools
import time
from typing import List, Optional
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads
from shark.shark_inference import SharkInference
from torch._decomp import get_decompositions
import torch_mlir
# TODO: Control decompositions.
def default_decompositions():
return get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
)
def timeit(*, append_time_to: Optional[List] = None):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time_ns()
result = func(*args, **kwargs)
end_time = time.time_ns()
if append_time_to is not None:
append_time_to.append(end_time - start_time)
return result
return wrapper
return decorator
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def make_shark_compiler(use_tracing: bool, device: str, verbose=False):
def compiler(
fx_graph: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
):
"""Compile GraphModule using torch-mlir + SHARK."""
if verbose:
print("Compiling graph...")
if _returns_nothing(fx_graph):
return fx_graph
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
fx_graph = make_fx(
fx_graph, decomposition_table=default_decompositions()
)(*example_inputs)
strip_overloads(fx_graph)
if verbose:
print("torch.fx graph:")
print(fx_graph.graph)
ts_compiler = torch.jit.trace if use_tracing else torch.jit.script
ts_graph = ts_compiler(fx_graph, example_inputs)
if verbose:
torch_mlir_module = torch_mlir.compile(
ts_graph,
example_inputs,
output_type=torch_mlir.OutputType.TORCH,
)
print("\n\ntorch-mlir backend contract graph:")
print(torch_mlir_module)
linalg_module = torch_mlir.compile(
ts_graph,
example_inputs,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
)
import io
bytecode_stream = io.BytesIO()
linalg_module.operation.write_bytecode(bytecode_stream)
mlir_module = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module, mlir_dialect="linalg", device=device
)
shark_module.compile()
def forward(*inputs):
result = shark_module("forward", inputs)
result = tuple() if result is None else result
return (result,) if was_unwrapped else result
return forward
return compiler
def check_results(compiled_results, eager_results):
for compiled_result, eager_result in zip(compiled_results, eager_results):
if not torch.allclose(
compiled_result.to("cpu"), eager_result.to("cpu"), atol=1e-5
):
print("Compiled result does not match eager result")
return
print("Compiled result matches eager result!")
def print_time_stats(times):
times_tensor = torch.tensor(times)
def quantile_ms(q):
return torch.quantile(times_tensor.to(float), q).item() / 1e6
print(f"Median: {quantile_ms(0.5)} ms")
print(f"10%ile: {quantile_ms(0.1)} ms")
print(f"90%ile: {quantile_ms(0.9)} ms")
print(f"Total: {torch.sum(times_tensor) / 1e6} ms")
print()

View File

@@ -19,6 +19,12 @@ import tempfile
from shark.parser import shark_args
import io
mlir_type_mapping_dict = {
"linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
"stablehlo": torch_mlir.OutputType.STABLEHLO,
"tosa": torch_mlir.OutputType.TOSA,
}
def get_module_name_for_asm_dump(module):
"""Gets a name suitable for an assembly dump.
@@ -57,6 +63,7 @@ def get_torch_mlir_module(
dynamic: bool,
jit_trace: bool,
return_str: bool = False,
mlir_type: str = "linalg",
):
"""Get the MLIR's linalg-on-tensors module from the torchscipt module."""
ignore_traced_shapes = False
@@ -70,10 +77,11 @@ def get_torch_mlir_module(
mlir_module = torch_mlir.compile(
module,
input,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
output_type=mlir_type_mapping_dict[mlir_type],
use_tracing=jit_trace,
ignore_traced_shapes=ignore_traced_shapes,
)
if return_str:
return mlir_module.operation.get_asm()
bytecode_stream = io.BytesIO()

Some files were not shown because too many files have changed in this diff Show More