Compare commits

...

285 Commits

Author SHA1 Message Date
Anush Elangovan
6d6a9dcae8 Revert "Revert "Enable --device_allocator=caching""
This reverts commit 41ee65b377.
2023-02-09 23:00:32 -08:00
Anush Elangovan
41ee65b377 Revert "Enable --device_allocator=caching"
This reverts commit 83fe477066.
2023-02-09 23:00:06 -08:00
Anush Elangovan
83fe477066 Enable --device_allocator=caching 2023-02-09 22:58:46 -08:00
yzhang93
4ca84ee4ee Revert "Delete unnecessary arg setting (#978)" (#985)
This reverts commit 83c69ecd49.
2023-02-09 16:44:26 -08:00
Ean Garvey
c28cc4c919 Fix local_tank_cache handling in shark_downloader. (#981) 2023-02-09 14:52:03 -06:00
yzhang93
e9864cb3f7 Modify the annotation OTF to return bytecode module (#980) 2023-02-08 14:29:43 -08:00
yzhang93
83c69ecd49 Delete unnecessary arg setting (#978) 2023-02-08 10:30:18 -08:00
Prashant Kumar
3595b4aaff Incorporate latest changes in the shark_dynamo backend. 2023-02-08 20:37:30 +05:30
Abhishek Varma
3a9cfe113a Fix SD restart error in exe file (#975)
-- This commit fixes SD restart error in exe file by creating
   variants.json in CWD instead of a relative path.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-08 06:14:08 -08:00
yzhang93
c9966127da Fix iree flags to be able to run on rdna2 (#972) 2023-02-07 16:39:32 -08:00
Ean Garvey
51300d33a7 Remove non-SD args from generate_sharktank.py (#970) 2023-02-07 13:29:55 -06:00
Gaurav Shukla
5af124c5a5 [SD] Add batch count in stable diffusion
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-07 23:26:46 +05:30
Abhishek Varma
eeb20b531a Fix restart SD session error + override args.use_tuned temporarily
-- This commit fixes the session restart error for SD.
-- It also overrides `args.use_tuned` for `import_mlir`, and sets
   `use_tuned` as `False`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-07 19:50:48 +05:30
cstueckrath
9dca842c22 Update .gitignore to exclude models (#967)
the models folder will be stashed along with other changes and most likely kill git doing so.
2023-02-07 01:48:36 -08:00
Ean Garvey
1eb9436836 Fix generate_sharktank args. 2023-02-07 14:06:07 +05:30
Ean Garvey
9604d9ce81 make --update_tank update only if hash mismatch 2023-02-07 14:06:07 +05:30
Ean Garvey
481d0553d8 Remove unnecessary repro_dir / shark_tmp usage 2023-02-07 14:06:07 +05:30
powderluv
60035cd63a Add css in exe (#963)
exe should now default to dark theme too
2023-02-06 15:26:08 -08:00
drumicube
d35f992ace Bring back the --runs options for the cmd command and fix wrong seed/model reported in json, csv and png (#962) 2023-02-06 15:16:50 -06:00
Daniel Garvey
157ae64f9d print to stdout for test visibility (#937)
Co-authored-by: dan <dan@nod-labs.com>
2023-02-06 01:03:27 -08:00
powderluv
ffa17f6057 Update sd_dark_theme.css 2023-02-06 01:01:50 -08:00
drumicube
d695a43e32 Make the dark theme default while launching web server (#954) 2023-02-05 07:25:45 -08:00
powderluv
01f6b4e6f0 Update README.md 2023-02-04 23:40:13 -08:00
yzhang93
7cf31a6ae4 Fix iree-benchmark flag names (#952) 2023-02-04 22:24:18 -08:00
Quinn Dawkins
fbd6224b04 Revert "Revert pipelines (#948)" (#951)
This reverts commit 8115b26079.
Additionally fixes img2col by adding detach elementwise from named op
passes.
2023-02-04 22:44:08 -05:00
powderluv
8115b26079 Revert pipelines (#948)
* Revert "[SD] Modify the flags to use --iree-preprocessing-pass-pipeline (#914)"

This reverts commit a783c089a9.

* Revert "Fix iree flags due to the change in shark-runtime (#944)"

This reverts commit 1d38d49162.
2023-02-04 07:09:51 -08:00
powderluv
820586ac68 Update README.md 2023-02-04 01:01:11 -08:00
powderluv
4a7441ed07 Update profiling_with_iree.md 2023-02-04 00:47:57 -08:00
powderluv
383741f284 Update stable_diffusion_amd.md 2023-02-04 00:40:47 -08:00
powderluv
2bbc4e0e9f Update README.md 2023-02-04 00:35:40 -08:00
powderluv
a7237244b0 Send users to the .exe file first 2023-02-04 00:30:32 -08:00
yzhang93
1d38d49162 Fix iree flags due to the change in shark-runtime (#944) 2023-02-03 21:34:02 -08:00
yzhang93
a783c089a9 [SD] Modify the flags to use --iree-preprocessing-pass-pipeline (#914)
* [SD] Modify the flags to use --iree-preprocessing-pass-pipeline

* Fix flags in sd_annotation
2023-02-03 15:08:02 -08:00
powderluv
e7907dc532 Disable tuned models for sm_89 (#943)
Looks like tuning on A100 doesn't necessarily translate to 40xx.
2023-02-03 14:30:46 -08:00
powderluv
394413679d Fix ckpt_dir (#939) 2023-02-03 12:54:19 -08:00
powderluv
37189f14cb roll to 492 2023-02-03 11:59:18 -08:00
powderluv
0b1ee81901 Minor webui changes (#938) 2023-02-03 11:26:45 -08:00
Gaurav Shukla
00cf73f9b8 [SD] Merge model id dropdown and .ckpt dropdown (#936)
- use_tuned is set to False for custom checkpoints.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-03 10:43:33 -08:00
Abhishek Varma
5a5f285493 [apps-SD] Prepone loading of vmfbs + restructure the SD pipeline
-- This commit prepones loading of vmfbs, if present, for all sub-models.
-- It also involves restructuring the SD pipeline to achieve the loading
   of vmfbs smoothly and postpones processing of checkpoint files only when
   required.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-03 20:21:24 +05:30
powderluv
7f2ea454b6 revert /base variants as they are different (#929)
sd2_1base is different from VAE base (for older cards)
2023-02-03 01:27:32 -08:00
Daniel Garvey
7c14002118 Map 2_1 to 2_1_base (#927)
* fix broken paths for older models

* adds a mapping from sd_2_1 to sd_2_1_base

we only have models in models_db for 2_1_base.
now that diffusers is fixed we can actually generate
2_1 itself, but until we add support for that in the tank
we should fetch 2_1_base for no-import_mlir

---------

Co-authored-by: dan <dan@nod-labs.com>
2023-02-02 19:03:19 -08:00
powderluv
3e9554f0a1 roll to 487 2023-02-02 19:02:39 -08:00
Daniel Garvey
e11ffec544 fix broken paths for older models (#926)
Co-authored-by: dan <dan@nod-labs.com>
2023-02-02 15:48:19 -08:00
powderluv
8a47ddbe99 Update models/ location in UI (#925)
default to png metadata on
2023-02-02 15:28:39 -08:00
powderluv
821108c7bd Fix models path (#924) 2023-02-02 15:16:00 -08:00
Gaurav Shukla
339738f8a3 [SD][web] Populate checkpoints as dropdown UI (#918)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-02 13:59:50 -08:00
powderluv
9b90672f63 Fix LLPC env var (#920) 2023-02-02 11:45:08 -08:00
Ean Garvey
ba07e94a5e disable Torch Inductor autotuner in benchmarks (#919) 2023-02-02 13:25:43 -06:00
aldesilv
b3fc0f29cc enable additional flags for tank test models (#866)
Co-authored-by: Alex <alexander@nod-labs.com>
2023-02-02 11:19:33 -08:00
Gaurav Shukla
5c7deb3611 [SD] Fix output image location (#917)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-02 09:50:37 -08:00
Daniel Garvey
15604e374f change bytecode model paths (#913)
Co-authored-by: dan <dan@nod-labs.com>
2023-02-02 11:12:13 -06:00
Abhishek Varma
7cfc0fa55b [APPS-SD] Fix a few bugs and bring it up to speed with SD CLI (#908) 2023-02-02 07:12:01 -08:00
Ean Garvey
a90812133b Enable pytests on Windows (#901) 2023-02-01 18:36:41 -06:00
powderluv
e26a70aa4f Drop old cli and webui (#911) 2023-02-01 13:13:46 -08:00
Daniel Garvey
6a32a4e26c move ci sd stuff to apps (#912)
Co-authored-by: dan <dan@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-02-01 12:15:07 -08:00
powderluv
e853abf98b Update stable_diffusion_amd.md 2023-02-01 11:11:58 -08:00
powderluv
51e81e6ef8 update main readme 2023-02-01 11:09:00 -08:00
powderluv
e355000ceb Drop torchvision 2023-02-01 10:26:37 -08:00
Daniel Garvey
e374074013 Windows test (#896)
* add generate_sharktank for stable_diffusion model defaults

* add windows test for sd

---------

Co-authored-by: dan <dan@nod-labs.com>
2023-02-01 12:03:54 -06:00
powderluv
81e3d1c2c6 switch to apps/ 2023-02-01 06:54:20 -08:00
powderluv
ab0cbb4475 Add PyInstaller for apps/ webui and cli (#909)
tested webui, cli and webui exe and cli exe
2023-02-01 06:51:27 -08:00
powderluv
1c64e40722 Add PyInstaller for apps/ (#907)
Build with pyinstaller.exe .\apps\stable_diffusion\web\shark_sd.spec

normal flow works. exe is missing a few json files
2023-02-01 06:04:49 -08:00
Evan Guan
8cafe56eb4 Added flags for metadata information. (#894) 2023-02-01 05:16:11 -08:00
Eliasj42
3eceeb7b23 fixed a bug that would sometimes cause intel-gpu to appear unsupported (#899)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-01-31 22:32:05 -08:00
powderluv
1a37675435 Revert "move beta to release (#898)" (#905)
This reverts commit 7edcaf5a06.
2023-01-31 20:31:41 -08:00
powderluv
198ebede8d Revert "replace new model_db.json (#902)" (#904)
This reverts commit 842adef29c.
2023-01-31 20:29:40 -08:00
Ean Garvey
a504903dd5 Fix formatting issues. (#903) 2023-02-01 09:12:45 +05:30
Daniel Garvey
842adef29c replace new model_db.json (#902) 2023-01-31 18:55:22 -08:00
Daniel Garvey
7edcaf5a06 move beta to release (#898)
Co-authored-by: dan <dan@nod-labs.com>
2023-01-31 17:14:08 -06:00
Gaurav Shukla
c124b76328 [SD] Reorganize the stable diffusion model. (#806)
The stable diffusion codebase has been reorganized to make it more
modular so that the same script can be used for web as well as cli,
instead of duplicating the whole codebase.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-31 14:42:41 -08:00
aldesilv
e9c744ee5d find rocm arch used in rocminfo (#893)
Co-authored-by: Alex <alexander@nod-labs.com>
2023-01-31 10:22:31 -08:00
Ean Garvey
83302930d8 Update generate_sharktank.py (#897) 2023-01-31 10:21:22 -08:00
Daniel Garvey
a4634632ba add generate_sharktank for stable_diffusion model defaults (#742)
Co-authored-by: dan <dan@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-31 09:44:54 -08:00
Abhishek Varma
d17e8dc5ad [NFC] Rename SD negative_prompts flag
-- This commit renames SD `negative-prompts` -> `negative_prompts` flag.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-31 21:38:59 +05:30
powderluv
9fe63de4d4 Pin macOS SDK to 216 2023-01-31 01:09:44 -08:00
Eliasj42
8111f8bf35 added ability to select gpu (#891)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-01-30 13:39:12 -08:00
Abhishek Varma
fcd62513cf [SD-CLI] Add support for .safetensors + Use diffusers pipeline to load SD
-- This commit uses `load_pipeline_from_original_stable_diffusion_ckpt`
   as exposed due to [Diffusers PR](https://github.com/huggingface/diffusers/pull/2019).
-- It also adds a support for the end users to use `.safetensors` along
   with `.ckpt` file.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-31 00:00:37 +05:30
Abhishek Varma
c3c701e654 Update requirements.txt + README.md of SD
-- This commit includes two python modules as part of requirements.txt.
-- It also updates README.md to also inclue `--no-use_tuned` for users to
   be able to try `hf_model_id` or `ckpt_loc` without any issue.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-30 14:12:54 +05:30
Daniel Garvey
6bf991edf6 adding more robust main.py testing (#889)
Co-authored-by: dan <dan@nod-labs.com>
2023-01-30 00:14:26 -08:00
yzhang93
9644e78545 Fix CUDA tuned model annotation (#880) 2023-01-27 11:35:18 -08:00
dymil
c911189ef0 Add note about latest RDNA3 driver support (#881)
Also tweak other wording
2023-01-27 09:39:19 -08:00
Abhishek Varma
1118b4b651 [SD-CLI] Clean up vmfbs if a retry method fails
-- This commit cleans up vmfb files generated as a result of retry method.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-27 21:55:36 +05:30
PhaneeshB
4be75d4418 fix seed values in SD json and filename 2023-01-27 18:40:26 +05:30
Ean Garvey
fb6beae27c Adds pytest-forked dependency to fix pytest memory accumulation issues. (#876)
* Minor improvements to test-models workflow

- cleaned up pytest command line args in Validate Models job scripts.
- Removed -s flag to provide more readable logs
- Changed shark_cache location to within github workspace and removed --update_tank flag from Linux workflows.

* Use pytest-forked for managing pytest memory usage.
2023-01-26 18:20:15 -06:00
yzhang93
fee73b0b63 Add SD model annotation on fly (#869)
* Add SD model annotation on fly

* Move tuned_compile_through_fx to utils

* Fix SD compilation flags
2023-01-26 11:46:36 -08:00
powderluv
9bbffa519e Add an option to respect LLPC env var (#875)
Also add OSX paths
2023-01-25 13:56:55 -08:00
jinchen62
c3a641f0ab Address TODOs for dataset annotator (#872)
- add args usage, pass gs_url by CL flag
- add support for no existing prompts
2023-01-25 09:28:23 -08:00
yzhang93
aafe7c4701 Add more cuda devices to use tuned model (#868) 2023-01-25 06:36:17 -08:00
Abhishek Varma
9a0b082cf8 [SD-CLI] Add batch_size command-line arg + prompt processing
-- This commit adds `batch_size` command-line arg.
-- It also involves replicating the prompt `batch_size` no. of times.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-25 19:21:25 +05:30
powderluv
8265e34a29 Add SHARK SD CLI tool (#870) 2023-01-24 23:14:32 -08:00
powderluv
8ef8ae097f Update to build 469 2023-01-24 22:16:13 -08:00
powderluv
c3d14293c0 Update sample results 2023-01-24 22:14:06 -08:00
powderluv
d55d8be504 Add signing of release builds 2023-01-24 21:32:21 -08:00
powderluv
03543030d3 use pefile 2023-01-24 18:35:51 -08:00
powderluv
fc6b474b92 Add ordlookup to requirements.txt 2023-01-24 18:30:16 -08:00
powderluv
a5db785dd7 checkoutv2 on windows 2023-01-24 18:23:22 -08:00
powderluv
1c1c5cd611 Build Windows nightly on 7950x 2023-01-24 16:21:56 -08:00
Abhishek Varma
6ed02f70ec [SD-CLI] Make using ckpt_loc and hf_model_id easier
-- Currently we require users to specify the base model on which the custom
   model (.ckpt) is tuned on. Even for running a HuggingFace repo-id, we
   require the users to go a tedious way of adding things to variants.json.

-- This commit aims to address the above issues and will be treated as a
   starting point for a series of design changes which makes using SHARK's SD
   easier.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-24 23:03:46 +05:30
Prashant Kumar
cb78cd8ac0 Add the support for the batch size parameter. 2023-01-24 22:33:13 +05:30
Ean Garvey
0c4590b45a Update generate_sharktank.py 2023-01-24 10:18:03 +05:30
jinchen62
d2e2ee6efa Add multiple prompts support for dataset annotator (#862) 2023-01-23 18:40:36 -08:00
powderluv
6a380a0b48 Add more nvidia cards 2023-01-23 17:07:45 -08:00
powderluv
e5d5acbf1f Remove torchvision requirements from web (#860) 2023-01-23 13:48:53 -08:00
powderluv
00e38abbf0 Add 4080 support 2023-01-23 09:56:34 -08:00
Abhishek Varma
e3e4ea5443 Update README.md
-- Make usage of `hf_model_id` clearer.
2023-01-23 23:25:23 +05:30
Prashant Kumar
a3e4ea3228 Remove the dependency of the torchvision. (#858)
Remove the dependency of torchvision library for the conversion
of tensor layout format to what PIL library expects.
2023-01-23 08:49:57 -08:00
powderluv
56f16d6baf Update SD readme 2023-01-23 06:51:54 -08:00
Abhishek Varma
7a55ab900e [SD-CLI] Fix CKPT script + add more variants + update README.md
-- This commit fixes CKPT script to rely on the previous CKPT to Diffusers
   script.
   TODO: Let go of the script once the CKPT is included in next release
         of diffusers.
-- It also adds many variants as part of `variants.json` and updates
   `README.md` to reflect change in default `hf_model_id`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-23 18:34:24 +05:30
Abhishek Varma
137643fe72 [SD-CLI] Update README.md of custom models to include hf_model_id 2023-01-23 11:37:13 +05:30
Anush Elangovan
d6e59c6241 black format comments 2023-01-22 16:34:40 -08:00
powderluv
458eb5d34c detect RX 7900 better 2023-01-22 16:32:27 -08:00
Erkin Alp Güney
8259f08864 Collapsibles for Win10 and Linux users (#851)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-22 09:50:33 -08:00
Prashant Kumar
b3ab0a1843 Add width and height support for the scheduler. 2023-01-22 23:16:50 +05:30
dependabot[bot]
f09f217478 Bump tensorflow from 2.10 to 2.10.1 (#853)
Bumps [tensorflow](https://github.com/tensorflow/tensorflow) from 2.10 to 2.10.1.
- [Release notes](https://github.com/tensorflow/tensorflow/releases)
- [Changelog](https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md)
- [Commits](https://github.com/tensorflow/tensorflow/compare/v2.10.0...v2.10.1)

---
updated-dependencies:
- dependency-name: tensorflow
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-01-22 06:40:17 -08:00
Daniel Garvey
e842c8c19b add main.py testing for sdiff (#836)
Co-authored-by: dan <dan@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-22 01:16:17 -08:00
powderluv
f6c3112d44 Revert "potential fix to pre-load DLL dir for torch-mlir (#848)" (#852)
This reverts commit 6c470d8131.
2023-01-22 00:09:35 -08:00
yzhang93
7059610632 Modify the default for --hf_model_id flag 2023-01-21 11:21:47 +05:30
powderluv
2d272930d9 Update to signed build 455 2023-01-20 16:50:42 -08:00
powderluv
6c470d8131 potential fix to pre-load DLL dir for torch-mlir (#848)
Doesn't regress the main.py script but system already pre-loaded
the DLL so needs more testing.
2023-01-20 14:48:45 -08:00
jinchen62
30b29ce8cd Add readme for dataset annotator (#847) 2023-01-20 01:03:33 -08:00
jinchen62
1a9933002f Add dataset annotation tool (#835) 2023-01-19 16:56:08 -08:00
stanley
c4a9365aa1 [Shark][Training] Refresh SharkTrainer to latest APIs. 2023-01-19 20:30:15 +00:00
Prashant Kumar
9d3af37104 bugfix related to the height width params. 2023-01-20 00:21:44 +05:30
Prashant Kumar
7b3d57cff7 Add height and width as args. 2023-01-19 23:43:29 +05:30
Abhishek Varma
a802270da9 [SD-CLI] Update README.md about variants.json 2023-01-19 22:46:54 +05:30
Abhishek Varma
dd194a8758 [SD-CLI] Reorder loading of opt_params when needed
-- This commit reorders loading of opt_params when `import_mlir` is not used.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 22:02:51 +05:30
Abhishek Varma
6de02de221 [SD-CLI] Make using custom models easier
-- This commit makes using custom models easier using a combination of
   `import_mlir`, `ckpt_loc` and `hf_model_id`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 22:02:36 +05:30
Abhishek Varma
85259750bf [SD-CLI] Fix variants.json mapping
-- This commit fixes variants.json's mapping.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 22:02:36 +05:30
Prashant Kumar
1249f0007d Remove args.variant and args.version with args.custom_model. 2023-01-19 19:55:12 +05:30
Abhishek Varma
db0514d3fa [SD-CLI] Fix get_model_configuration to use max_length
-- This commit fixes `get_model_configuration` to use `max_length`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 19:10:04 +05:30
Abhishek Varma
dce42a7fad [SD-CLI] Fix args.max_length range check
This commit fixes args.max_length range check.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 18:26:23 +05:30
Prashant Kumar
ec0b380194 Refactor shark_tank models and custom models.
The custom models shouldn't depend on shark_tank in anyway.
2023-01-19 13:56:11 +05:30
Ean Garvey
7f27b61c98 Update setup_venv.sh to install triton if BENCHMARK=1 2023-01-19 00:26:46 -06:00
Guy Nachshon
f0b3557b02 fix: replace malicious and deleted package (#833) 2023-01-18 13:41:05 -08:00
xzuyn
2a1d1c1001 make jpeg optimized and progressive (#820)
* GUI make jpeg optimized and progressive

* CLI make jpeg optimized and progressive
2023-01-17 16:35:36 -08:00
Abhishek Varma
df7eb80e5b [SD-CLI] Make custom_model take highest priority for generating models if present
-- This commit makes `custom_model` take highest priority for generating models if present.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-17 22:50:58 +05:30
Fraser Humphries
b9d947ce6f style: 🎨 Restore whitespace 2023-01-17 17:45:32 +05:30
Fraser Humphries
e6589d2454 fix: 🏗️ Add demo.css to spec file datas 2023-01-17 17:45:32 +05:30
Fraser Humphries
0f5ac6afcf fix: 🐛 resolve css file path relative to __file__
issues-816
2023-01-17 17:45:32 +05:30
Abhishek Varma
bc1bb1d188 [SD-CLI] Fix vmfb naming + update README.md for custom_model
-- This commit introduces a fix for .vmfb naming to strip away any
   non-alphanumeric characters from `custom_model` path.
-- It also updates the README.md to include the `custom_model` arg.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-17 16:27:54 +05:30
Abhishek Varma
3af2dd10ce [SD-CLI] Add CKPT support to update models irrespective of import_mlir flag
-- This commit adds CKPT support to update models irrespective of `import_mlir` flag.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-17 13:24:27 +05:30
yzhang93
dd22c65855 Add CUDA tuned models for SD variants (#814) 2023-01-16 09:38:27 -08:00
PhaneeshB
48137ced19 add png as default format 2023-01-16 18:37:36 +05:30
Phaneesh Barwaria
6eb47c12d1 add multi-run in single execution (#812) 2023-01-13 11:12:43 -08:00
Prashant Kumar
5a1fc6675a This PR adds --import-mlir for f16 tensors without cuda. 2023-01-13 22:19:53 +05:30
Prashant Kumar
6f80825814 Modify import_with_fx to import with dtype=f16. 2023-01-13 22:19:53 +05:30
PhaneeshB
f0dd48ed2a remaining disk space warning 2023-01-13 19:34:05 +05:30
Gaurav Shukla
15e2df0db0 [SD][web] Add a UI textbox to show the output location
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-13 19:33:04 +05:30
Fraser Humphries
4ad0109769 fix: 🐛 Extract demo css string to css file
fix: 🐛 Extract demo css string to css file

issues/807

fix: 🐛 Revert background colors
2023-01-13 16:42:05 +05:30
PhaneeshB
ee0009d4b8 pythonize uname for cpu target triple in windows 2023-01-12 22:39:49 +05:30
PhaneeshB
9d851c3346 small fixes 2023-01-12 22:32:24 +05:30
xzuyn
5d117af8ae Increase JPEG output quality & disable subsampling (#801)
* Increase JPEG output quality & disable subsampling

Increased to JPEG95 from the default JPEG75 which is way too compressed. Output image size is now ~100kb. Previously was ~20kb.

* Increase JPEG output quality & disable subsampling

Add jpeg quality increase on cli

* line length changes

* line length changes
2023-01-11 23:06:11 -08:00
yzhang93
bb41c2d15e Add VAE cuda tuned model (#796) 2023-01-11 14:15:03 -08:00
powderluv
eba138ee4a Revert "Change address for connection test (#785)" (#797)
This reverts commit 187f0fa70c.
2023-01-11 12:01:37 -08:00
Gaurav Shukla
3b2bbb74f8 [SD][web] Add support for saving generated images
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-11 22:47:32 +05:30
fokin33
dbc0f81211 Add simple telegram bot (#787) 2023-01-11 09:20:23 -06:00
mariecwhite
d0b613d22e Enable Torch-Inductor Part 2 2023-01-10 20:15:29 -08:00
Ean Garvey
72f29b67d5 Add Resnet50 fp16 variant to pytests. (#760) 2023-01-10 16:31:11 -08:00
Quinn Dawkins
9570045cc3 Fix tuned model selection for non-vulkan devices (#792) 2023-01-10 19:04:21 -05:00
Phaneesh Barwaria
e4efdb5cbb add json data for each image (#790)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-10 13:13:07 -08:00
calcifer11
187f0fa70c Change address for connection test (#785)
Some ISP's (like mine) reserves 1.1.1.1 for internal testing, meaning _internet_connected(); needlessly retries for a minute until it fails even though my connection is fine.
Propose 8.8.8.8 instead as this is also publically available and not normally blocked by ISPs.
2023-01-10 10:51:30 -08:00
Gaurav Shukla
472185c3e4 [SD][web] Fix device key error
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-10 20:51:01 +05:30
Gaurav Shukla
f94a571773 [SD] Update spec file to include model_config.json
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-10 20:38:10 +05:30
mariecwhite
183e447d35 Enable Torch Inductor (#784) 2023-01-10 20:57:58 +11:00
xzuyn
12f844d93a Git pull through argument in setup_venv (#623) 2023-01-09 15:42:13 -08:00
yzhang93
47a119a37f [SD] Add CUDA A100 tuned model (#773) 2023-01-09 15:22:27 -08:00
Gaurav Shukla
ee56559b9a [SD][web] Add a json file for model configuration
This cleans model_wrappers.py file.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-10 00:05:46 +05:30
Gaurav Shukla
00e594deea [SD][web] Add version number in performance details
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-09 21:32:34 +05:30
George Petterson
6ad9b213b9 Add GCN4
(cherry picked from commit 3be072b3c09c9b38bc2d79ad6e6900eefee49a1c)
2023-01-09 21:09:50 +05:30
PhaneeshB
e4375e8195 Add support for vulkan target env 2023-01-09 21:09:50 +05:30
mariecwhite
487bf8e29b Enable TF32 in Torch if specified (#768) 2023-01-09 06:48:57 -08:00
Prashant Kumar
fea1694e74 Delete the cached objects explicitly. 2023-01-06 23:04:52 +05:30
Prashant Kumar
4102c124a9 Add the shark upscaler model. (#759) 2023-01-05 14:07:20 -08:00
yzhang93
135bad3280 [SD] Update v1.4 tuned model (#758) 2023-01-05 11:04:30 -08:00
Gaurav Shukla
b604f36881 [SD][web] Add flags for global URL and server port
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-05 15:30:30 +05:30
yzhang93
782b449c71 Add script to auto annotate SD models and variants (#751)
* Add script to auto annotate SD models and variants

* Add model config files

* Add script to auto annotate SD models and variants

* Add model config files

* Move config files to shark_tank
2023-01-04 15:53:10 -08:00
jinchen62
017dcab685 Add target triple support for TITAN RTX (#756) 2023-01-04 15:39:00 -08:00
Abhishek Varma
e60b4568c6 [SharkInference] Make SharkInference compile the entire module (#708)
* [SharkInference] Make SharkInference compile the entire module

-- Previously SharkInference was compiling and providing run APIs
   for a harcoded function with function name "forward".
-- This commit makes the compiling functionality generic and now
   any function being defined within the module can be run.
-- It also creates an API to fetch all the function names defined
   within the compiled module.
-- This commit updates both web and command-line execution of Stable
   Diffusion to use new API of  SharkInference.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-03 23:25:23 +05:30
powderluv
4ee3d95a5a Update to build 423
Post pytorch security breach
2023-01-01 12:10:23 -08:00
Graham
f18725bacc replaced <username> with %username% for easy copy/paste (#744) 2022-12-31 21:29:37 -08:00
jinchen62
f6064a2b84 Add a prototype of the model compilation configs for SD (#734) 2022-12-28 15:14:36 -08:00
Quinn Dawkins
2e90cb7b95 Set default warmup count to 0 (#736) 2022-12-28 12:27:43 -06:00
powderluv
2c09d63cd9 Update to build 417 2022-12-27 14:25:20 -08:00
powderluv
cc6fbdb0c3 Add sm_89 and point to nvcuda.dll (#731) 2022-12-26 10:54:38 -08:00
powderluv
ecfdec12f3 Update requirements.txt 2022-12-25 15:39:20 -08:00
Gaurav Shukla
45af40fd14 [SD][web] Add openjourney and dreamlike in SD web UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-26 01:59:36 +05:30
Phaneesh Barwaria
d11cf42501 Add support for dreamlike diffusion (#725)
* Add support for dreamlike diffusion

* model wrapper to support 77 dreamlike

* lint fix
2022-12-26 01:35:17 +05:30
Gaurav Shukla
c3c1e3b055 [SD] Add bucket info in the model_db.json
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 20:38:33 +05:30
Gaurav Shukla
7c5e3b1d99 [SD] Fix flags for cuda devices
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 19:03:02 +05:30
Gaurav Shukla
ed6cec71e7 [SD] Fix clip inference time
Fix clip inference time by adding default warmup_count to 5.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 18:16:53 +05:30
Tobby "GTD-Carthage" Ong
d6bcdd069c - Added missing double linebreak from linting 2022-12-25 12:07:43 +05:30
Tobby "GTD-Carthage" Ong
a26347826d - Revised code to also use get_schedulers function instead 2022-12-25 12:07:43 +05:30
Tobby "GTD-Carthage" Ong
5d1c099b31 [SD] Add Euler Ancestral scheduler as option to WebUI 2022-12-25 12:07:43 +05:30
Gaurav Shukla
220bee1365 [SD][web] Add device support in the SD web UI
1. Now device selection is available through UI.
2. Models reloading will only happen when there will be a change in the
   settings(variant + device).

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 01:45:07 +05:30
PhaneeshB
1261074d95 Add tuned models for av3 and ad 2022-12-24 22:56:15 +05:30
Stanley Winata
136021424c [SD] Change default VMA large heap block size for windows perf. (#715)
Windows perform can boost from 2.67s/image to 2.4523s/image.
While Linux stays the same.
2022-12-24 01:40:58 +07:00
PhaneeshB
fee4ba3746 Add openjourney 2022-12-23 23:34:22 +05:30
Gaurav Shukla
a5b70335d4 [SD][web] Add variant support in the web UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-23 23:18:27 +05:30
Stanley Winata
5cf4976054 [Vulkan][utils] Add GTX Pascal support. (#709) 2022-12-22 15:24:15 -08:00
PhaneeshB
1aa3255061 Add vaebase for av3 and ad 2022-12-23 04:17:17 +05:30
Daniel Garvey
b01f29f10d add support for clear_all (#691) 2022-12-22 11:25:03 -06:00
Boian Petkantchin
2673abca88 Fix concurrency issue in stress_test for CUDA devices 2022-12-22 08:54:19 -08:00
Gaurav Shukla
7eeb7f0715 [SD] Update all the utilities to make web and CLI codebase closer (#707)
At this point, all the utilities of SD web and CLI are exactly same.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-22 02:49:48 -08:00
powderluv
37262a2479 Remove spurious characters 2022-12-21 19:23:54 -08:00
Gaurav Shukla
de6e304959 [SD] Fix the resource location in shark_sd.spec (#706) 2022-12-21 14:41:56 -08:00
Quinn Dawkins
234475bbc7 Add base_vae entries for variant models (#705) 2022-12-21 14:35:08 -08:00
Quinn Dawkins
abbd9f7cfc [SD] Set unet flags for cuda (#704) 2022-12-21 13:22:04 -08:00
Gaurav Shukla
dfd6ba67b3 [SD] Update SD CLI to use model_db.json
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-22 02:13:04 +05:30
yzhang93
1595254eab Modify model annotation tool to walk through ops by shape (#692) 2022-12-21 10:46:30 -08:00
PhaneeshB
6964c5eeba encapsulate relevant methods in one method 2022-12-21 23:56:17 +05:30
PhaneeshB
2befe771b3 Add support for automatic target triple selection for SD 2022-12-21 22:38:06 +05:30
Prashant Kumar
b133a035a4 Add the download progress bar. 2022-12-21 15:47:33 +05:30
Gaurav Shukla
726c062327 [SD] Update spec files
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-21 14:16:04 +05:30
Gaurav Shukla
9083672de3 [SD][web] Tuned models only for stablediffusion/fp16 and rdna3 cards
Currently tuned models are only available for stablediffusion/fp16 and
rdna3 cards.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-21 14:15:39 +05:30
Quinn Dawkins
cdbaf880af [SD] [web] Add model variants to web 2022-12-21 13:42:22 +05:30
Quinn Dawkins
9434981cdc Add random seed generation for seed = -1 in cli (#689) 2022-12-20 17:15:22 -05:00
Phaneesh Barwaria
8b3706f557 Add Anything v3 and AnalogDiffusion variants of SD (#685)
* base support for anythingv3

* add analogdiffusiont

* Update readme

* keep max len 77 till support for 64 added for variants

* lint fix
2022-12-20 13:08:13 -08:00
Gaurav Shukla
0d5173833d [SD] Add a json file for model names information. (#687)
This commit simplifies the code to identify the model name for a
particular set of flags. This is achieved by introducing a json file
that stores the model names information. The models are uploaded in
gcloud with these names.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-20 11:47:31 -08:00
powderluv
bf1178eb79 roll to build 400 2022-12-20 10:34:31 -08:00
yzhang93
abcd3fa94a [SD] Set model max length 64 as default (#681) 2022-12-19 21:13:04 -08:00
Quinn Dawkins
62aa1614b6 [SD] Add --use_base_vae flag to do conversion to pixel space on cpu (#682) 2022-12-19 21:09:39 -08:00
Quinn Dawkins
7027356126 [SD] Fix warmup for max length 64 (#680) 2022-12-19 21:04:44 -05:00
yzhang93
5ebe13a13d Add Unet len 64 tuned model (#679) 2022-12-19 16:24:08 -08:00
Gaurav Shukla
c3bed9a2b7 [SD][web] Add flag to disable the progress bar animation
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-20 02:50:04 +05:30
yzhang93
f865222882 Update VAE 19dec tuned model (#676) 2022-12-19 12:42:28 -08:00
powderluv
e2fe2e4095 Point to 398 2022-12-19 12:08:30 -08:00
powderluv
0532a95f08 Update stable_diffusion_amd.md 2022-12-19 12:04:42 -08:00
Quinn Dawkins
ff536f6015 [SD] Deduplicate initial noise generation (#677) 2022-12-19 14:38:41 -05:00
Gaurav Shukla
097d0f27bb [SD][web] Add 64 max_length support in SD web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-20 00:00:58 +05:30
Prashant Kumar
2257f87edf Update opt_params.py 2022-12-19 23:43:30 +05:30
PhaneeshB
a17800da00 Add 64 len f16 untuned mlir 2022-12-19 22:53:17 +05:30
Prashant Kumar
059c1b3a19 Disable vae --use_tuned version. 2022-12-19 22:45:45 +05:30
Stanley Winata
9a36816d27 [SD][CLI] Add a warmup phase (#670) 2022-12-20 00:14:23 +07:00
Gaurav Shukla
7986b9b20b [SD][WEB] Update VAE model and wrapper
This commit updates VAE model which significantly improves performance
by an order of ~300ms.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-19 22:32:05 +05:30
Gaurav Shukla
b2b3a0a62b [SD] Move initial latent generation out of inference time
The initial random latent generation is not taken into account
for total SD inference time.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-19 22:32:05 +05:30
Prashant Kumar
3173b7d1d9 Update VAE model and wrapper. 2022-12-19 19:54:50 +05:30
Gaurav Shukla
9d716d70d6 [SD][web] Fix performance issues on shark scheduler
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-19 17:44:37 +05:30
Stanley Winata
e1901a8608 [SD][CL] Disable print at every iteration. (#664)
Printing might incur extra time to runtime. Hence, we add a flag to hide it. To disable printing please set this flag `--hide_steps`.

Co-authored-by: Stanley <stanley@MacStudio.lan>
2022-12-19 15:39:57 +07:00
Quinn Dawkins
7d0cbd8d90 [SD][web] Set default tuned unet to v2 (#663) 2022-12-19 11:50:08 +07:00
Quinn Dawkins
59358361f9 [SD] Make clip batch 2 for positive and negative prompts (#662)
Combines the forward passes for each input prompt type into a single batched clip pass.
2022-12-18 23:46:21 -05:00
Quinn Dawkins
7fea2d3b68 [SD] update default large heap size for web as well (#661) 2022-12-18 21:50:26 -05:00
Quinn Dawkins
b6d3ff26bd [SD] Change default VMA large heap block size (#660) 2022-12-18 21:41:46 -05:00
Stella Laurenzo
523e63f5c1 Fix NoneType exception if vulkan tuning flags not detected. (#659)
(This goes on to produce compilation errors, but one step at a time)
2022-12-18 16:40:56 -08:00
Stella Laurenzo
10630ab597 Add config stanza for NVIDIA RTX 2080. (#658)
Just happened to have this card on my Windows machine and verified that the SD demo works on it.

```
Average step time: 144.26142692565918ms/it
Clip Inference Avg time (ms) = (205.001 + 44.000) / 2 = 124.501
VAE Inference time (ms): 281.001

Total image generation time: 7.856997728347778sec
```

I'd love to add an API upstream to derive compiler tuning flags from a host device.
2022-12-18 16:40:47 -08:00
Quinn Dawkins
2bc6de650d [SD] Add support for a compiled version of the discrete Euler scheduler (#657)
* Add Shark version of euler scheduler

* Add Shark version of euler scheduler to web ui
2022-12-17 19:25:43 -08:00
powderluv
ffef1681e3 Update stable_diffusion_amd.md 2022-12-17 03:40:08 -08:00
yzhang93
d935006a4a Update Unet tuned model to v2 (#656) 2022-12-16 22:10:15 -08:00
powderluv
660cb5946e Update to 392 release 2022-12-16 16:00:49 -08:00
Gaurav Shukla
10160a066a [SD][WEB] Add vae tuned model in the SD web (#653)
1. Add tuned vae model in the SD web.
2. Use tuned models in case of rdna3 cards.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-16 15:29:48 -08:00
Anush Elangovan
72976a2ece Import env vars first 2022-12-16 15:12:28 -08:00
Phaneesh Barwaria
831f206cd0 Revert "Add target triple selection for multiple cards" (#655)
This reverts commit acb905f0cc.
2022-12-16 15:01:45 -08:00
Gaurav Shukla
72648aa9f2 Revert "[SD][WEB] Deduce vulkan-target-triple in the presence of multiple cards"
This reverts commit 35e623deaf.
2022-12-17 04:28:18 +05:30
Gaurav Shukla
35e623deaf [SD][WEB] Deduce vulkan-target-triple in the presence of multiple cards
1. Get the correct vulkan-target-triple for a specified device in the
   presence of multiple cards.
2. Use tuned unet model for rdna3 cards.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-17 03:04:47 +05:30
Anush Elangovan
6263636738 Fix more lints 2022-12-16 13:26:15 -08:00
Anush Elangovan
535d012ded Fix lint 2022-12-16 13:24:51 -08:00
yzhang93
c73eed2e51 Add VAE winograd tuned model (#647) 2022-12-16 13:01:45 -08:00
Anush Elangovan
30fdc99f37 Set to enable llpc
Use an env var to enable llpc
2022-12-16 12:57:30 -08:00
PhaneeshB
acb905f0cc Add target triple selection for multiple cards 2022-12-17 02:24:37 +05:30
Gaurav Shukla
bba06d0142 [SD][WEB] Avoid passing args to utils APIs
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-17 01:41:33 +05:30
Ean Garvey
a14a47af12 Move most xfails to entries in tank/all_models.csv and temporarily remove multiprocessing and TF gpu support. (#646)
-Adds date variable back to nightly.yml so shark_tank uploads are dated again
-added specification for nightly pytests to not run tests on metal (vulkan is sufficient)
-added some paths/filetypes to be ignored when triggering workflow runs. (no test-models on changes to .md files or anything in the shark/examples/ directory or its subdirectories.
-pytest only picks up tank/test_models.py, so no need to specify which file to run when running pytest from SHARK base directory.
-Cleaned up xfails so that they can be added to models as csv entries. Columns 7-9 in all_models.csv trigger xfails with cpu, cuda, vulkan, respectively, and row 10 can be populated with a reason for the xfails.
-Fixed a few defaults for shark_args and pytest args (defined in conftest.py)
-Fixes --update_tank option in shark_downloader
removes some multiprocessing in pytest / TF+CUDA support because it breaks pytest and false passes, leaving regressions at large.
-Adds xfails for and removes albert torch from gen_sharktank list (tank/torch_model_list.csv).
-Cleans up xfails for cpu, cuda, vulkan (removing old ones)
2022-12-16 12:56:32 +05:30
Phaneesh Barwaria
73457336bc add flag for toggling vulkan validation layers (#624)
* add vulkan_validation_layers flag

* categorize SD flags

* stringify true and false for flag
2022-12-15 20:40:59 -06:00
Ean Garvey
a14c53ad31 Remove albert-base-v2 since it fails torch_mlir.compile() (#644) 2022-12-15 16:05:19 -06:00
Gaurav Shukla
e7e763551a [WEB][SD] Make unet tuned model default for rdna3 devices (#642) 2022-12-15 12:02:03 -08:00
nirvedhmeshram
2928179331 Add more NVIDIA targets (#640) 2022-12-15 11:24:38 -06:00
Stanley Winata
24a16a4cfe [Stable Diffusion] Disable binding fusion to work with moltenVK on mac. (#639)
Co-authored-by: Stanley <stanley@MacStudio.lan>
2022-12-16 00:22:49 +07:00
Phaneesh Barwaria
6aed4423b2 add vulkan lib path (#638) 2022-12-15 19:48:29 +07:00
yzhang93
6508e3fcc9 Update tuned model SD v2.1base (#634) 2022-12-14 16:02:35 -05:00
Gaurav Shukla
a15cb140ae [WEB] Display the 512x512 image size
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-14 22:43:03 +05:30
Prashant Kumar
898bc9e009 Add the stable diffusion v2.1 version. 2022-12-14 20:19:41 +05:30
Gaurav Shukla
e67ea31ee2 [SHARK][SD] Add --local_tank_cache flag in the stable diffusion
This flag can be used to set local shark_tank cache directory.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-14 20:00:25 +05:30
Gaurav Shukla
986c126a5c [SHARK][SD] Add support for negative prompts
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-14 18:20:09 +05:30
Gaurav Shukla
0eee7616b9 [WEB] Launch only one SD version at a time
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-14 17:30:24 +05:30
powderluv
5ddce749b8 lint fix 2022-12-13 22:02:32 -08:00
powderluv
d946cffabc Revert "Move most xfails to entries in tank/all_models.csv and temporarily remove multiprocessing and TF gpu support. (#602)" (#622)
This reverts commit fe618811ee.
2022-12-13 21:49:46 -08:00
Ean Garvey
fe618811ee Move most xfails to entries in tank/all_models.csv and temporarily remove multiprocessing and TF gpu support. (#602)
* Move most xfails to entries in tank/all_models.csv

* enable usage of pytest without specifying tank/test_models.py

* add dict_configs.py to gitignore.

* Pin versions for runtimes and torch-mlir for setup.
2022-12-13 18:11:17 -08:00
powderluv
09c45bfb80 clean up cache printf 2022-12-13 14:11:14 -08:00
Boian Petkantchin
e9e9ccd379 Add stress test 2022-12-13 13:21:51 -08:00
Boian Petkantchin
a9b27c78a3 Return dynamic model if specified when downloading from the tank 2022-12-13 13:21:51 -08:00
Boian Petkantchin
bc17c29b2e In get_iree_runtime_config get the specific device instead of the default 2022-12-13 13:21:51 -08:00
Boian Petkantchin
aaf60bdee6 Simplify iree_device_map 2022-12-13 13:21:51 -08:00
Gaurav Shukla
d913453e57 [WEB] Update models to 8dec and also default values (#620)
1. Update the models to 8 dec.
2. precision is default to `fp16` in CLI.
3. version is default to `v2.1base` in CLI as well as web.
4. The default scheduler is set to `EulerDiscrete` now.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-13 13:08:33 -08:00
powderluv
08e373aef4 Update stable_diffusion_amd.md 2022-12-13 11:47:29 -08:00
Prashant Kumar
4cb50a3d06 Update the models to 8th Dec version. 2022-12-14 00:01:46 +05:30
Gaurav Shukla
b03038222d [SHARK] Update dependencies
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-13 22:12:00 +05:30
Gaurav Shukla
5f5e0766dd [WEB] Add SD2.1 web support
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-13 21:36:01 +05:30
124 changed files with 7062 additions and 3707 deletions

View File

@@ -10,14 +10,14 @@ on:
jobs:
windows-build:
runs-on: windows-latest
runs-on: 7950X
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
@@ -50,8 +50,12 @@ jobs:
shell: powershell
run: |
./setup_venv.ps1
pyinstaller web/shark_sd.spec
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
# GHA windows VM OOMs so disable for now
@@ -108,6 +112,7 @@ jobs:
- name: Install dependencies
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
@@ -131,14 +136,14 @@ jobs:
pip install ./wheelhouse/nodai*
# Validate the Models
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" tank/test_models.py |
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
if !(grep -Fxq " failed" pytest_results.txt)
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/latest/
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
rm -rf ./wheelhouse/nodai*
@@ -154,6 +159,6 @@ jobs:
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pytest --ci --ci_sha=${SHORT_SHA} tank/test_models.py |
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt

View File

@@ -6,8 +6,14 @@ name: Validate Models on Shark Runtime
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
@@ -23,7 +29,7 @@ jobs:
strategy:
fail-fast: true
matrix:
os: [icelake, a100, MacStudio, ubuntu-latest]
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
suite: [cpu,cuda,vulkan]
python-version: ["3.10"]
include:
@@ -46,13 +52,19 @@ jobs:
suite: cuda
- os: a100
suite: cpu
- os: 7950x
suite: cpu
- os: 7950x
suite: cuda
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'
- name: Set Environment Variables
if: matrix.os != '7950x'
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
@@ -72,6 +84,9 @@ jobs:
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- uses: actions/checkout@v2
if: matrix.os == '7950x'
- name: Install dependencies
if: matrix.suite == 'lint'
@@ -94,9 +109,9 @@ jobs:
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cpu --update_tank
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -106,29 +121,42 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cuda --update_tank
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
echo "VULKAN SDK PATH wo setup: $VULKAN_SDK"
cd /Users/anush/VulkanSDK/1.3.224.1/
source setup-env.sh
cd $GITHUB_WORKSPACE
echo "VULKAN SDK PATH with setup: $VULKAN_SDK"
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan --update_tank
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os != 'MacStudio'
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest --benchmark -k vulkan -s
type bench_results.csv
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
./shark.venv/Scripts/activate
python build_tools/stable_diffusion_testing.py --device=vulkan

16
.gitignore vendored
View File

@@ -159,16 +159,26 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# vscode related
.vscode
# Shark related artefacts
*venv/
shark_tmp/
*.vmfb
.use-iree
tank/dict_configs.py
# ORT related artefacts
cache_models/
onnx_models/
#web logging
web/logs/
web/stored_results/stable_diffusion/
# Generated images
generated_imgs/
# Custom model related artefacts
variants.json
models/
# models folder
apps/stable_diffusion/web/models/

View File

@@ -1,12 +1,47 @@
# SHARK
High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerators and Heterogeneous Clusters
High Performance Machine Learning Distribution
[![Nightly Release](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
[![Validate torch-models on Shark Runtime](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
## Installation (Windows, Linux and macOS)
<details>
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download this specific driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mril-iree). Latest drivers may not work.
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
#### Linux Drivers
* MESA / RADV drivers wont work with FP16. Please use the latest AMGPU-PRO drivers (non-pro OSS drivers also wont work) or the latest NVidia Linux Drivers.
Other users please ensure you have your latest vendor drivers and Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home) and if you are using vulkan check `vulkaninfo` works in a terminal window
</details>
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
Install Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Download the latest .exe https://github.com/nod-ai/SHARK/releases.
Double click the .exe and you should have the [UI]( http://localhost:8080/?__theme=dark) in the browser.
If you have custom models (ckpt, safetensors) put in a `models/` directory where the .exe is.
Enjoy.
Some known AMD Driver quirks and fixes with cursors are documented [here](https://github.com/nod-ai/SHARK/blob/main/apps/stable_diffusion/stable_diffusion_amd.md ).
<details>
<summary>Advanced Installation (Only for developers)</summary>
## Advanced Installation (Windows, Linux and macOS) for developers
## Check out the code
@@ -45,12 +80,12 @@ source shark.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\Users\nod\SHARK> cd web
(shark.venv) PS C:\Users\nod\SHARK\web> python index.py
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
```
#### Linux Users
#### Linux / macOS Users
```shell
(shark.venv) > cd web
(shark.venv) > cd apps/stable_diffusion/web
(shark.venv) > python index.py
```
@@ -63,39 +98,28 @@ source shark.venv/bin/activate
### Run Stable Diffusion on your device - Commandline
#### Install your hardware drivers
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mril-iree)
* [macOS Users] Download and install the latest Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home)
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
Other users please ensure you have your latest vendor drivers and Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home) and if you are using vulkan check `vulkaninfo` works in a terminal window
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\shark\examples\shark_inference\stable_diffusion\main.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
```shell
python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
python3.10 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
</details>
The output on a 6900XT would like:
The output on a 7900XTX would like:
```shell
44it [00:08, 5.14it/s]i = 44 t = 120 (191ms)
45it [00:08, 5.15it/s]i = 45 t = 100 (191ms)
46it [00:08, 5.16it/s]i = 46 t = 80 (191ms)
47it [00:09, 5.16it/s]i = 47 t = 60 (193ms)
48it [00:09, 5.15it/s]i = 48 t = 40 (195ms)
49it [00:09, 5.12it/s]i = 49 t = 20 (196ms)
50it [00:09, 5.14it/s]
Average step time: 192.8154182434082ms/it
Total image generation runtime (s): 10.390909433364868
(shark.venv) PS C:\g\shark>
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
```
Here are some samples generated:
@@ -105,9 +129,6 @@ Here are some samples generated:
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.

View File

@@ -0,0 +1,87 @@
Compile / Run Instructions:
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=mhlo for tf models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
```
Run / Benchmark Command (FP32 - NCHW):
(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.)
```shell
## Vulkan AMD:
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CUDA:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CPU:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
Run via vulkan_gui for RGP Profiling:
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
```shell
./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
</details>
<details>
<summary>Debug Commands</summary>
## Debug commands and other advanced usage follows.
```shell
python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
```
## dump all dispatch .spv and isa using amdllpc
```shell
python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
```
## Compile and save the .vmfb (using vulkan fp16 as an example):
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
```
## Capture an RGP trace
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
```
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
```shell
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
```
## Run the unet module with iree-benchmark-module (same config as above):
```shell
##if you want to use .npz inputs:
unzip ~/.local/shark_tank/<your unet>/inputs.npz
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy
```
</details>

View File

@@ -0,0 +1 @@
from apps.stable_diffusion.scripts.txt2img import txt2img_inf

View File

@@ -0,0 +1,240 @@
import logging
import os
from models.stable_diffusion.main import stable_diff_inf
from models.stable_diffusion.utils import get_available_devices
from dotenv import load_dotenv
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram import BotCommand
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
from io import BytesIO
import random
log = logging.getLogger("TG.Bot")
logging.basicConfig()
log.warning("Start")
load_dotenv()
os.environ["AMD_ENABLE_LLPC"] = "0"
TG_TOKEN = os.getenv("TG_TOKEN")
SELECTED_MODEL = "stablediffusion"
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
STEPS = 30
NEGATIVE_PROMPT = (
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
" legs, fused fingers, too many fingers"
)
GUIDANCE_SCALE = 6
available_devices = get_available_devices()
models_list = [
"stablediffusion",
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]
sheds_list = [
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
]
def image_to_bytes(image):
bio = BytesIO()
bio.name = "image.jpeg"
image.save(bio, "JPEG")
bio.seek(0)
return bio
def get_try_again_markup():
keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]]
reply_markup = InlineKeyboardMarkup(keyboard)
return reply_markup
def generate_image(prompt):
seed = random.randint(1, 10000)
log.warning(SELECTED_MODEL)
log.warning(STEPS)
image, text = stable_diff_inf(
prompt=prompt,
negative_prompt=NEGATIVE_PROMPT,
steps=STEPS,
guidance_scale=GUIDANCE_SCALE,
seed=seed,
scheduler_key=SELECTED_SCHEDULER,
variant=SELECTED_MODEL,
device_key=available_devices[0],
)
return image, seed
async def generate_and_send_photo(
update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
progress_msg = await update.message.reply_text(
"Generating image...", reply_to_message_id=update.message.message_id
)
im, seed = generate_image(prompt=update.message.text)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{update.message.text}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=update.message.message_id,
)
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
query = update.callback_query
if query.data in models_list:
global SELECTED_MODEL
SELECTED_MODEL = query.data
await query.answer()
await query.edit_message_text(text=f"Selected model: {query.data}")
return
if query.data in sheds_list:
global SELECTED_SCHEDULER
SELECTED_SCHEDULER = query.data
await query.answer()
await query.edit_message_text(text=f"Selected scheduler: {query.data}")
return
replied_message = query.message.reply_to_message
await query.answer()
progress_msg = await query.message.reply_text(
"Generating image...", reply_to_message_id=replied_message.message_id
)
if query.data == "TRYAGAIN":
prompt = replied_message.text
im, seed = generate_image(prompt)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{prompt}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=replied_message.message_id,
)
async def select_model_handler(update, context):
text = "Select model"
keyboard = []
for model in models_list:
keyboard.append(
[
InlineKeyboardButton(text=model, callback_data=model),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def select_scheduler_handler(update, context):
text = "Select schedule"
keyboard = []
for shed in sheds_list:
keyboard.append(
[
InlineKeyboardButton(text=shed, callback_data=shed),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def set_steps_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_steps ")[1]
global STEPS
STEPS = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_steps 30"
)
await update.message.reply_text(input_args)
async def set_negative_prompt_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_negative_prompt ")[1]
global NEGATIVE_PROMPT
NEGATIVE_PROMPT = input_args
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_negative_prompt ugly, bad art, mutated"
)
await update.message.reply_text(input_args)
async def set_guidance_scale_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_guidance_scale ")[1]
global GUIDANCE_SCALE
GUIDANCE_SCALE = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_guidance_scale 7"
)
await update.message.reply_text(input_args)
async def setup_bot_commands(application: Application) -> None:
await application.bot.set_my_commands(
[
BotCommand("select_model", "to select model"),
BotCommand("select_scheduler", "to select scheduler"),
BotCommand("set_steps", "to set steps"),
BotCommand("set_guidance_scale", "to set guidance scale"),
BotCommand("set_negative_prompt", "to set negative prompt"),
]
)
app = (
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
)
app.add_handler(CommandHandler("select_model", select_model_handler))
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
app.add_handler(CommandHandler("set_steps", set_steps_handler))
app.add_handler(
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
)
app.add_handler(
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
)
app.add_handler(
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
)
app.add_handler(CallbackQueryHandler(button))
log.warning("Start bot")
app.run_polling()

View File

@@ -0,0 +1,331 @@
import os
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
import sys
import json
import torch
import re
import time
from pathlib import Path
from PIL import PngImagePlugin
from datetime import datetime as dt
from dataclasses import dataclass
from csv import DictWriter
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
)
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
# This has to come before importing cache objects
if args.clear_all:
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
if os.path.exists(yaml):
os.remove(yaml)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(output_path, "generated_imgs")
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = os.path.basename(args.ckpt_loc)
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
)
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SEED": img_seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
}
with open(csv_path, "a") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
dictwriter_obj.writerow(new_entry)
csv_obj.close()
if args.save_metadata_to_json:
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
txt2img_obj = None
config_obj = None
schedulers = None
# Exposed to UI.
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
global txt2img_obj
global config_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = custom_model
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
)
if config_obj != new_config_obj:
config_obj = new_config_obj
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.use_tuned = True
args.import_mlir = False
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
)
if not txt2img_obj:
sys.exit("text to image pipeline must not return a null value")
txt2img_obj.scheduler = schedulers[scheduler]
start_time = time.time()
txt2img_obj.log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = txt2img_obj.generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
txt2img_obj.log += "\n"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
return generated_imgs, text_output
if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
)
for run in range(args.runs):
if run > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: if using --runs=x txt2img_obj.log will output on each display every iteration infos from the start
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)

View File

@@ -19,13 +19,19 @@ datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'prompts.json', '.' ),
( 'logos/*.png', 'logos' )
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
( 'web/css/*', 'css' ),
( 'web/logos/*', 'logos' )
]
binaries = []
@@ -34,11 +40,11 @@ block_cipher = None
a = Analysis(
['index.py'],
['web/index.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio'],
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
hookspath=[],
hooksconfig={},
runtime_hooks=[],

View File

@@ -19,24 +19,30 @@ datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
#datas += copy_metadata('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'prompts.json', '.' ),
( 'logos/*', 'logos' )
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
]
binaries = []
block_cipher = None
a = Analysis(
['index.py'],
['scripts/txt2img.py'],
pathex=['.'],
binaries=[],
binaries=binaries,
datas=datas,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio'],
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
@@ -51,13 +57,17 @@ pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
exclude_binaries=True,
name='shark_sd',
name='shark_sd_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
@@ -65,13 +75,3 @@ exe = EXE(
codesign_identity=None,
entitlements_file=None,
)
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
strip=False,
upx=True,
upx_exclude=[],
name='shark_sd',
)

View File

@@ -0,0 +1,8 @@
from apps.stable_diffusion.src.utils import (
args,
set_init_device_flags,
prompt_examples,
get_available_devices,
)
from apps.stable_diffusion.src.pipelines import Text2ImagePipeline
from apps.stable_diffusion.src.schedulers import get_schedulers

View File

@@ -0,0 +1,11 @@
from apps.stable_diffusion.src.models.model_wrappers import (
SharkifyStableDiffusionModel,
)
from apps.stable_diffusion.src.models.opt_params import (
get_vae,
get_unet,
get_clip,
get_tokenizer,
get_params,
get_variant_version,
)

View File

@@ -0,0 +1,295 @@
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from collections import defaultdict
import torch
import traceback
import re
import sys
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
base_models,
args,
fetch_or_delete_vmfbs,
preprocessCKPT,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
)
# These shapes are parameter dependent.
def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape = []
for i in range(len(shape)):
if shape[i] == "max_len":
new_shape.append(max_len)
elif shape[i] == "height":
new_shape.append(height)
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
if "batch_size" in shape[i]:
mul_val = int(shape[i].split("*")[0])
new_shape.append(batch_size * mul_val)
else:
new_shape.append(shape[i])
return new_shape
# Get the input info for various models i.e. "unet", "clip", "vae".
def get_input_info(model_info, max_len, width, height, batch_size):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
for inp in model_info[k]:
shape = model_info[k][inp]["shape"]
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, max_len, width, height, batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
# We need a better naming convention for the .vmfbs because despite
# using the custom model variant the .vmfb names remain the same and
# it'll always pick up the compiled .vmfb instead of compiling the
# custom model.
# So, currently, we add `self.model_id` in the `self.model_name` of
# .vmfb file.
# TODO: Have a better way of naming the vmfbs using self.model_name.
model_name = re.sub(r"\W+", "_", self.model_id)
if model_name[0] == "_":
model_name = model_name[1:]
self.model_name = self.model_name + "_" + model_name
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 384):
sys.exit("width should be greater than 384 and multiple of 8")
if not (height % 8 == 0 and height >= 384):
sys.exit("height should be greater than 384 and multiple of 8")
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
)
self.base_vae = base_vae
def forward(self, input):
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0
return x.round()
vae = VaeModel()
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
vae_name = "base_vae" if self.base_vae else "vae"
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=vae_name + self.model_name,
extra_args=get_opt_flags("vae", precision=self.precision),
)
return shark_vae
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(
self, latent, timestep, text_embedding, guidance_scale
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel()
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet = compile_through_fx(
unet,
inputs,
model_name="unet" + self.model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_unet
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
)
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name="clip" + self.model_name,
extra_args=get_opt_flags("clip", precision="fp32"),
)
return shark_clip
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
# configiration.
def compile_all(self, base_model_id):
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
return compiled_clip, compiled_unet, compiled_vae
def __call__(self):
# Step 1:
# -- Fetch all vmfbs for the model, if present, else delete the lot.
vmfbs = fetch_or_delete_vmfbs(
self.model_name, self.base_vae, self.precision
)
if vmfbs[0]:
# -- If all vmfbs are indeed present, we also try and fetch the base
# model configuration for running SD with custom checkpoints.
if self.custom_weights != "":
args.hf_model_id = fetch_and_update_base_model_id(self.custom_weights)
if args.hf_model_id == "":
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
return vmfbs
# Step 2:
# -- If vmfbs weren't found, we try to see if the base model configuration
# for the required SD run is known to us and bypass the retry mechanism.
model_to_run = ""
if self.custom_weights != "":
model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights)
else:
model_to_run = args.hf_model_id
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
if base_model_fetched != "":
print("Compiling all the models with the fetched base model configuration.")
if args.ckpt_loc != "":
args.hf_model_id = base_model_fetched
return self.compile_all(base_model_fetched)
# Step 3:
# -- This is the retry mechanism where the base model's configuration is not
# known to us and figure that out by trial and error.
print("Inferring base model configuration.")
for model_id in base_models:
try:
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id)
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please re-run the command with `--enable_stack_trace` flag and create an issue with detailed log at https://github.com/nod-ai/SHARK/issues"
)

View File

@@ -0,0 +1,89 @@
import sys
from transformers import CLIPTokenizer
from apps.stable_diffusion.src.utils import (
models_db,
args,
get_shark_model,
get_opt_flags,
)
hf_model_variant_map = {
"Linaqruf/anything-v3.0": ["anythingv3", "v2_1base"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"],
"prompthero/openjourney": ["openjourney", "v2_1base"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
}
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
def get_params(bucket_key, model_key, model, is_tuned, precision):
try:
bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
except KeyError:
raise Exception(
f"{bucket_key}/{model_key} is not present in the models database"
)
iree_flags = get_opt_flags(model, precision="fp16")
return bucket, model_name, iree_flags
def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "unet", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
is_base = "/base" if args.use_base_vae else ""
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
variant, version = get_variant_version(args.hf_model_id)
bucket_key = f"{variant}/untuned"
model_key = (
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
)
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "clip", "untuned", "fp32"
)
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
)
return tokenizer

View File

@@ -0,0 +1,3 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)

View File

@@ -0,0 +1,135 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
return all_imgs

View File

@@ -0,0 +1,206 @@
import torch
from transformers import CLIPTokenizer
from PIL import Image
from tqdm.auto import tqdm
import time
from typing import Union
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
get_clip,
get_unet,
get_tokenizer,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
class StableDiffusionPipeline:
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
text_input = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# Get unconditional embeddings as well
uncond_input = self.tokenizer(
neg_prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
if use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
if use_base_vae:
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def produce_img_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = self.scheduler.scale_model_input(latents, t)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
@classmethod
def from_pretrained(
cls,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
import_mlir: bool,
model_id: str,
ckpt_loc: str,
precision: str,
max_length: int,
batch_size: int,
height: int,
width: int,
use_base_vae: bool,
use_tuned: bool,
):
if import_mlir:
# TODO: Delet this when on-the-fly tuning of models work.
use_tuned = False
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
return cls(
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
)

View File

@@ -0,0 +1,4 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)

View File

@@ -0,0 +1,51 @@
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
def get_schedulers(model_id):
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerDiscrete"
] = SharkEulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers

View File

@@ -0,0 +1,143 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_shark_model,
args,
)
import torch
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
):
super().__init__(
num_train_timesteps,
beta_start,
beta_end,
beta_schedule,
trained_betas,
prediction_type,
)
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
}
example_latent = model_input["euler"]["latent"]
example_output = model_input["euler"]["output"]
if args.precision == "fp16":
example_latent = example_latent.half()
example_output = example_output.half()
example_sigma = model_input["euler"]["sigma"]
example_dt = model_input["euler"]["dt"]
class ScalingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma, latent, dt):
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
return latent + derivative * dt
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.import_mlir:
scaling_model = ScalingModel()
self.scaling_model = compile_through_fx(
scaling_model,
(example_latent, example_sigma),
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
self.step_model = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
else:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
return self.scaling_model(
"forward",
(
sample,
sigma,
),
send_to_host=False,
)
def step(self, noise_pred, timestep, latent):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
dt = self.sigmas[step_index + 1] - sigma
return self.step_model(
"forward",
(
noise_pred,
sigma,
latent,
dt,
),
send_to_host=False,
)

View File

@@ -0,0 +1,27 @@
from apps.stable_diffusion.src.utils.profiler import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.utils.resources import (
prompt_examples,
models_db,
base_models,
opt_flags,
resource_path,
)
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.utils import (
get_shark_model,
compile_through_fx,
set_iree_runtime_flags,
map_device_to_name_path,
set_init_device_flags,
get_available_devices,
get_opt_flags,
preprocessCKPT,
fetch_or_delete_vmfbs,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
)

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.utils.stable_args import args
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")
vulkan_device = iree.runtime.get_device(args.device)
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
return vulkan_device
return None
def end_profiling(device):
if device:
return device.end_profiling()

View File

@@ -0,0 +1,37 @@
import os
import json
import sys
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def get_json_file(path):
json_var = []
loc_json = resource_path(path)
if os.path.exists(loc_json):
with open(loc_json, encoding="utf-8") as fopen:
json_var = json.load(fopen)
if not json_var:
print(f"Unable to fetch {path}")
return json_var
# TODO: This shouldn't be called from here, every time the file imports
# it will run all the global vars.
prompt_examples = get_json_file("resources/prompts.json")
models_db = get_json_file("resources/model_db.json")
# The base_model contains the input configuration for the different
# models and also helps in providing information for the variants.
base_models = get_json_file("resources/base_model.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")

View File

@@ -0,0 +1,98 @@
{
"stabilityai/stable-diffusion-2-1": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"CompVis/stable-diffusion-v1-4": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
}
}

View File

@@ -0,0 +1,21 @@
[
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]

View File

@@ -0,0 +1,82 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/sd_untuned",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
}
]

View File

@@ -0,0 +1,84 @@
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
"fp32": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
}
}

View File

@@ -0,0 +1,234 @@
import os
import io
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
download_model,
download_public_file,
WORKDIR,
)
from shark.parser import shark_args
from apps.stable_diffusion.src.utils.stable_args import args
def get_device():
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
return device
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
from apps.stable_diffusion.src.models import (
get_params,
get_variant_version,
)
variant, version = get_variant_version(args.hf_model_id)
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{variant}/untuned"
if args.annotation_model == "unet":
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
is_base = "/base" if args.use_base_vae else ""
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
)
return mlir_model, model_name
# Download the tuned config files from shark_tank
def load_winograd_configs():
device = get_device()
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
winograd_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
def load_lower_configs():
from apps.stable_diffusion.src.models import get_variant_version
variant, version = get_variant_version(args.hf_model_id)
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_version = version
if variant in ["anythingv3", "analogdiffusion"]:
args.max_length = 77
config_version = "v1_4"
if args.annotation_model == "vae":
args.max_length = 77
device = get_device()
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
full_gs_url = config_bucket + config_name
lowering_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading lowering config file from ", lowering_config_dir)
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with create_context() as ctx:
winograd_model = model_annotation(
ctx,
input_contents=input_mlir,
config_path=winograd_config_dir,
search_op="conv",
winograd=True,
)
bytecode_stream = io.BytesIO()
winograd_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return bytecode, out_file_path
def dump_after_mlir(input_mlir, model_name, use_winograd):
if use_winograd:
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline='builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))' "
)
else:
dump_after = "iree-preprocessing-pad-linalg-ops"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline='builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))' "
)
device_spec_args = ""
device = get_device()
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
gpu_flags = get_iree_gpu_args()
for flag in gpu_flags:
device_spec_args += flag + " "
elif device == "vulkan":
device_spec_args = (
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
print("Applying tuned configs on", model_name)
run_cmd(
f"iree-compile {input_mlir} "
"--iree-input-type=tm_tensor "
f"--iree-hal-target-backends={iree_target_map(device)} "
f"{device_spec_args}"
f"{preprocess_flag}"
"--iree-stream-resource-index-bits=64 "
"--iree-vm-target-index-bits=64 "
f"--mlir-print-ir-after={dump_after} "
"--compile-to=flow "
f"2>{args.annotation_output}/dump_after_winograd.mlir "
)
# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
# Dump IR after padding/img2col/winograd passes
dump_after_mlir(input_mlir, model_name, use_winograd)
# Annotate the model with lowering configs in the config file
with create_context() as ctx:
tuned_model = model_annotation(
ctx,
input_contents=f"{args.annotation_output}/dump_after_winograd.mlir",
config_path=lowering_config_dir,
search_op="all",
)
# Remove the intermediate mlir and save the final annotated model
os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
bytecode_stream = io.BytesIO()
tuned_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return bytecode, out_file_path
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model, model_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
model_path, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model, output_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
use_winograd = False
if model_from_tank:
mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
else:
# Just use this function to convert bytecode to string
orig_model, model_path = annotate_with_winograd(
mlir_model, "", model_name
)
mlir_model = model_path
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
print(f"Saved the annotated mlir in {output_path}.")
return tuned_model
if __name__ == "__main__":
mlir_model, model_name = load_model_from_tank()
sd_model_annotation(mlir_model, model_name, model_from_tank=True)

View File

@@ -0,0 +1,345 @@
import argparse
from pathlib import Path
def path_expand(s):
return Path(s).expanduser().resolve()
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-p",
"--prompts",
action="append",
default=[],
help="text of which images to be generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=[""],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
default=42,
help="the seed to use.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `run`.",
)
p.add_argument(
"--height",
type=int,
default=512,
help="the height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
help="the width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
)
##############################################################################
### Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
)
p.add_argument(
"--runs",
type=int,
default=1,
help="number of images to be generated with random seeds in single execution",
)
p.add_argument(
"--ckpt_loc",
type=str,
default="",
help="Path to SD's .ckpt file.",
)
p.add_argument(
"--hf_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--enable_stack_trace",
default=False,
action=argparse.BooleanOptionalAction,
help="Enable showing the stack trace when retrying the base model configuration",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree-vulkan-target-triple",
type=str,
default="",
help="Specify target triple for vulkan",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="4147483648",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
)
##############################################################################
### Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save a generation information json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
)
##############################################################################
### Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the pregress bar animation during image generation",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--use_winograd",
default=False,
action=argparse.BooleanOptionalAction,
help="Apply Winograd on selected conv ops.",
)
args, unknown = p.parse_known_args()

View File

@@ -0,0 +1,461 @@
import os
import gc
import json
from pathlib import Path
import numpy as np
from random import randint
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys, functools, operator
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
def get_vmfb_path_name(model_name):
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
return [vmfb_path, extended_name]
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
[vmfb_path, extended_name] = get_vmfb_path_name(model_name)
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), extended_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.parser import shark_args
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
from shark.shark_downloader import download_model
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
)
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into a shark_module.
def compile_through_fx(
model,
inputs,
model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
extra_args=[],
):
from shark.parser import shark_args
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
)
if use_tuned:
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
mlir_module = sd_model_annotation(mlir_module, model_name)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
del mlir_module
gc.collect()
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--device_allocator=caching",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def set_init_device_flags():
if "vulkan" in args.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "cpu" in args.device:
args.device = "cpu"
# set max_length based on availability.
if args.hf_model_id in [
"Linaqruf/anything-v3.0",
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]:
args.max_length = 77
elif args.hf_model_id == "prompthero/openjourney":
args.max_length = 64
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
if (
args.hf_model_id == "prompthero/openjourney"
or args.ckpt_loc != ""
or args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif (
"vulkan" in args.device
and "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {args.hf_model_id}/fp16/{args.device}.")
else:
print("Tuned models are currently not supported for this setting.")
# set import_mlir to True for unuploaded models.
if args.ckpt_loc != "":
args.import_mlir = True
elif args.hf_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.import_mlir = True
elif args.height != 512 or args.width != 512 or args.batch_size != 1:
args.import_mlir = True
# Utility to get list of devices available.
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{device['name']} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
return available_devices
def disk_space_check(path, lim=20):
from shutil import disk_usage
du = disk_usage(path)
free = du.free / (1024 * 1024 * 1024)
if free <= lim:
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
def get_opt_flags(model, precision="fp16"):
iree_flags = []
is_tuned = "tuned" if args.use_tuned else "untuned"
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
"default_compilation_flags"
]
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
if (
device
not in opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
]
):
device = "default_device"
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
return iree_flags
def get_path_to_diffusers_checkpoint(custom_weights):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers
def preprocessCKPT(custom_weights):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
return
else:
print(
"Diffusers' checkpoint will be identified here : ",
path_to_diffusers,
)
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
extract_ema = False
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
)
pipe.save_pretrained(path_to_diffusers)
print("Loading complete")
def load_vmfb(vmfb_path, model, precision):
model = "vae" if "base_vae" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module = SharkInference(mlir_module=None, device=args.device)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
# This utility returns vmfbs of Clip, Unet and Vae, in case all three of them
# are present; deletes them otherwise.
def fetch_or_delete_vmfbs(basic_model_name, use_base_vae, precision="fp32"):
model_name = ["clip", "unet", "base_vae" if use_base_vae else "vae"]
vmfb_path = [
get_vmfb_path_name(model + basic_model_name)[0] for model in model_name
]
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = functools.reduce(operator.__and__, vmfb_present)
compiled_models = [None] * 3
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(len(vmfb_path)):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
else:
for i in range(len(vmfb_path)):
compiled_models[i] = load_vmfb(
vmfb_path[i], model_name[i], precision
)
return compiled_models
# `fetch_and_update_base_model_id` is a resource utility function which
# helps maintaining mapping of the model to run with its base model.
# If `base_model` is "", then this function tries to fetch the base model
# info for the `model_to_run`.
def fetch_and_update_base_model_id(model_to_run, base_model=""):
variants_path = os.path.join(os.getcwd(), "variants.json")
data = {model_to_run: base_model}
json_data = {}
if os.path.exists(variants_path):
with open(variants_path, "r", encoding="utf-8") as jsonFile:
json_data = json.load(jsonFile)
# Return with base_model's info if base_model is "".
if base_model == "":
if model_to_run in json_data:
base_model = json_data[model_to_run]
return base_model
elif base_model == "":
return base_model
# Update JSON data to contain an entry mapping model_to_run with base_model.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
json.dump(json_data, jsonFile)
# Generate and return a new seed if the provided one is not in the supported range (including -1)
def sanitize_seed(seed):
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed

View File

@@ -0,0 +1,70 @@
# Stable Diffusion optimized for AMD RDNA2/RDNA3 GPUs
Before you start, please be aware that this is beta software that relies on a special AMD driver. Like all StableDiffusion GUIs published so far, you need some technical expertise to set it up. We apologize in advance if you bump into issues. If that happens, please don't hesitate to ask our Discord community for help! Please be assured that we (Nod and AMD) are working hard to improve the user experience in coming months.
If it works well for you, please "star" the following GitHub projects... this is one of the best ways to help and spread the word!
* https://github.com/nod-ai/SHARK
* https://github.com/iree-org/iree
## Install this specific AMD Drivers (AMD latest may not have all the fixes).
### AMD KB Drivers for RDNA2 and RDNA3:
*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*
First, for RDNA2 users, download this special driver in a folder of your choice. We recommend you keep the installation files around, since you may need to re-install it later, if Windows Update decides to overwrite it:
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree
For RDNA3, the latest driver 23.1.2 supports MLIR/IREE as well: https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-1-2-kb
KNOWN ISSUES with this special AMD driver:
* `Windows Update` may (depending how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work, then a few days later, it slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix this problem, please check the installed driver version, and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
* Some people using this special driver experience mouse pointer accuracy issues, especially if using a larger-than-default mouse pointer. The clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".
## Installation
Download the latest Windows SHARK SD binary [492 here](https://github.com/nod-ai/SHARK/releases/download/20230203.492/shark_sd_20230203_492.exe) in a folder of your choice. If you want nighly builds, you can look for them on the GitHub releases page.
Notes:
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR which can be outdated if you run a new EXE from the same folder. You can use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you:
* clear all the local artifacts with `--clear_all` OR
* clear the Vulkan shader cache: For Windows users this can be done by clearing the contents of `C:\Users\%username%\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
* clear the `huggingface` cache. In Windows, this is `C:\Users\%username%\.cache\huggingface`.
## Running
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE to start the web browser)
* The first run may take about 10-15 minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* If successful, you will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/?__theme=dark.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment. The application should stop.
* Please make sure to do the above step before you attempt to update the EXE to a new version.
# Results
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
Here are some samples generated:
![tajmahal, snow, sunflowers, oil on canvas_0](https://user-images.githubusercontent.com/74956/204934186-141f7e43-6eb2-4e89-a99c-4704d20444b3.jpg)
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
The output on a 7900XTX would like:
```shell
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
```
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.

View File

@@ -0,0 +1,15 @@
You need to pre-create your bot (https://core.telegram.org/bots#how-do-i-create-a-bot)
Then create in the directory web file .env
In it the record:
TG_TOKEN="your_token"
specifying your bot's token from previous step.
Then run telegram_bot.py with the same parameters that you use when running index.py, for example:
python telegram_bot.py --max_length=77 --vulkan_large_heap_block_size=0 --use_base_vae --local_tank_cache h:\shark\TEMP
Bot commands:
/select_model
/select_scheduler
/set_steps "integer number of steps"
/set_guidance_scale "integer number"
/set_negative_prompt "negative text"
Any other text triggers the creation of an image based on it.

View File

@@ -0,0 +1,209 @@
/* Overwrite the Gradio default theme with their .dark theme declarations */
:root {
--color-focus-primary: var(--color-grey-700);
--color-focus-secondary: var(--color-grey-600);
--color-focus-ring: rgb(55 65 81);
--color-background-primary: var(--color-grey-950);
--color-background-secondary: var(--color-grey-900);
--color-background-tertiary: var(--color-grey-800);
--color-text-body: var(--color-grey-100);
--color-text-label: var(--color-grey-200);
--color-text-placeholder: var(--color-grey);
--color-text-subdued: var(--color-grey-400);
--color-text-link-base: var(--color-blue-500);
--color-text-link-hover: var(--color-blue-400);
--color-text-link-visited: var(--color-blue-600);
--color-text-link-active: var(--color-blue-500);
--color-text-code-background: var(--color-grey-800);
--color-text-code-border: color.border-primary;
--color-border-primary: var(--color-grey-700);
--color-border-secondary: var(--color-grey-600);
--color-border-highlight: var(--color-accent-base);
--color-accent-base: var(--color-orange-500);
--color-accent-light: var(--color-orange-300);
--color-accent-dark: var(--color-orange-700);
--color-functional-error-base: var(--color-red-400);
--color-functional-error-subdued: var(--color-red-300);
--color-functional-error-background: var(--color-background-primary);
--color-functional-info-base: var(--color-yellow);
--color-functional-info-subdued: var(--color-yellow-300);
--color-functional-success-base: var(--color-green);
--color-functional-success-subdued: var(--color-green-300);
--shadow-spread: 2px;
--api-background: linear-gradient(to bottom, rgba(255, 216, 180, .05), transparent);
--api-pill-background: var(--color-orange-400);
--api-pill-border: var(--color-orange-600);
--api-pill-text: var(--color-orange-900);
--block-border-color: var(--color-border-primary);
--block-background: var(--color-background-tertiary);
--uploadable-border-color-hover: var(--color-border-primary);
--uploadable-border-color-loaded: var(--color-functional-success);
--uploadable-text-color: var(--color-text-subdued);
--block_label-border-color: var(--color-border-primary);
--block_label-icon-color: var(--color-text-label);
--block_label-shadow: var(--shadow-drop);
--block_label-background: var(--color-background-secondary);
--icon_button-icon-color-base: var(--color-text-label);
--icon_button-icon-color-hover: var(--color-text-label);
--icon_button-background-base: var(--color-background-primary);
--icon_button-background-hover: var(--color-background-primary);
--icon_button-border-color-base: var(--color-background-primary);
--icon_button-border-color-hover: var(--color-border-secondary);
--input-text-color: var(--color-text-body);
--input-border-color-base: var(--color-border-primary);
--input-border-color-hover: var(--color-border-primary);
--input-border-color-focus: var(--color-border-primary);
--input-background-base: var(--color-background-tertiary);
--input-background-hover: var(--color-background-tertiary);
--input-background-focus: var(--color-background-tertiary);
--input-shadow: var(--shadow-inset);
--checkbox-border-color-base: var(--color-border-primary);
--checkbox-border-color-hover: var(--color-focus-primary);
--checkbox-border-color-focus: var(--color-blue-500);
--checkbox-background-base: var(--color-background-primary);
--checkbox-background-hover: var(--color-background-primary);
--checkbox-background-focus: var(--color-background-primary);
--checkbox-background-selected: var(--color-blue-600);
--checkbox-label-border-color-base: var(--color-border-primary);
--checkbox-label-border-color-hover: var(--color-border-primary);
--checkbox-label-border-color-focus: var(--color-border-secondary);
--checkbox-label-background-base: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--checkbox-label-background-hover: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--checkbox-label-background-focus: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--form-seperator-color: var(--color-border-primary);
--button-primary-border-color-base: var(--color-orange-600);
--button-primary-border-color-hover: var(--color-orange-600);
--button-primary-border-color-focus: var(--color-orange-600);
--button-primary-text-color-base: white;
--button-primary-text-color-hover: white;
--button-primary-text-color-focus: white;
--button-primary-background-base: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-700));
--button-primary-background-hover: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
--button-primary-background-focus: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
--button-secondary-border-color-base: var(--color-grey-600);
--button-secondary-border-color-hover: var(--color-grey-600);
--button-secondary-border-color-focus: var(--color-grey-600);
--button-secondary-text-color-base: white;
--button-secondary-text-color-hover: white;
--button-secondary-text-color-focus: white;
--button-secondary-background-base: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-700));
--button-secondary-background-hover: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
--button-secondary-background-focus: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
--button-cancel-border-color-base: var(--color-red-600);
--button-cancel-border-color-hover: var(--color-red-600);
--button-cancel-border-color-focus: var(--color-red-600);
--button-cancel-text-color-base: white;
--button-cancel-text-color-hover: white;
--button-cancel-text-color-focus: white;
--button-cancel-background-base: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-700));
--button-cancel-background-focus: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
--button-cancel-background-hover: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
--button-plain-border-color-base: var(--color-grey-600);
--button-plain-border-color-hover: var(--color-grey-500);
--button-plain-border-color-focus: var(--color-grey-500);
--button-plain-text-color-base: var(--color-text-body);
--button-plain-text-color-hover: var(--color-text-body);
--button-plain-text-color-focus: var(--color-text-body);
--button-plain-background-base: var(--color-grey-700);
--button-plain-background-hover: var(--color-grey-700);
--button-plain-background-focus: var(--color-grey-700);
--gallery-label-background-base: var(--color-grey-50);
--gallery-label-background-hover: var(--color-grey-50);
--gallery-label-border-color-base: var(--color-border-primary);
--gallery-label-border-color-hover: var(--color-border-primary);
--gallery-thumb-background-base: var(--color-grey-900);
--gallery-thumb-background-hover: var(--color-grey-900);
--gallery-thumb-border-color-base: var(--color-border-primary);
--gallery-thumb-border-color-hover: var(--color-accent-base);
--gallery-thumb-border-color-focus: var(--color-blue-500);
--gallery-thumb-border-color-selected: var(--color-accent-base);
--chatbot-border-border-color-base: transparent;
--chatbot-border-border-color-latest: transparent;
--chatbot-user-background-base: ;
--chatbot-user-background-latest: ;
--chatbot-user-text-color-base: white;
--chatbot-user-text-color-latest: white;
--chatbot-bot-background-base: ;
--chatbot-bot-background-latest: ;
--chatbot-bot-text-color-base: white;
--chatbot-bot-text-color-latest: white;
--label-gradient-from: var(--color-orange-400);
--label-gradient-to: var(--color-orange-600);
--table-odd-background: var(--color-grey-900);
--table-even-background: var(--color-grey-950);
--table-background-edit: transparent;
--dataset-gallery-background-base: var(--color-background-primary);
--dataset-gallery-background-hover: var(--color-grey-800);
--dataset-dataframe-border-base: var(--color-border-primary);
--dataset-dataframe-border-hover: var(--color-border-secondary);
--dataset-table-background-base: transparent;
--dataset-table-background-hover: var(--color-grey-700);
--dataset-table-border-base: var(--color-grey-800);
--dataset-table-border-hover: var(--color-grey-800);
}
/* SHARK theme customization */
.gradio-container {
background-color: var(--color-background-primary);
}
.container {
background-color: black !important;
padding-top: 20px !important;
}
#ui_title {
padding: 10px !important;
}
#top_logo {
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#demo_title {
background-color: var(--color-background-primary);
border-radius: 0 !important;
border: 0;
padding-top: 15px;
padding-bottom: 0px;
width: 350px !important;
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea {
background-color: var(--color-background-primary) !important;
}
#prompt_examples {
margin: 0 !important;
}
#prompt_examples svg {
display: none !important;
}
#ui_body {
background-color: var(--color-background-secondary) !important;
padding: 10px !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}

View File

@@ -0,0 +1,264 @@
import os
import sys
from pathlib import Path
import glob
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
import gradio as gr
from PIL import Image
from apps.stable_diffusion.src import (
prompt_examples,
args,
get_available_devices,
)
from apps.stable_diffusion.scripts import txt2img_inf
nodlogo_loc = resource_path("logos/nod-logo.png")
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
demo_css = resource_path("css/sd_dark_theme.css")
with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
logo2 = Image.open(sdlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
with gr.Column(scale=5, elem_id="demo_title_outer"):
gr.Image(
value=logo2,
show_label=False,
interactive=False,
elem_id="demo_title",
).style(width=150, height=100)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
ckpt_path = (
Path(args.ckpt_dir)
if args.ckpt_dir
else Path(Path.cwd(), "models")
)
ckpt_path.mkdir(parents=True, exist_ok=True)
types = (
"*.ckpt",
"*.safetensors",
) # the tuple of file types
ckpt_files = ["None"]
for extn in types:
files = glob.glob(os.path.join(ckpt_path, extn))
ckpt_files.extend(files)
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {ckpt_path})",
value="None",
choices=ckpt_files
+ [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
],
)
hf_model_id = gr.Textbox(
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value="cyberpunk forest by Salvador Dali",
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="trees, green",
lines=1,
elem_id="prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
label="Scheduler",
value="SharkEulerDiscrete",
choices=[
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
],
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=True,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=False,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384, 786, value=512, step=8, label="Height"
)
width = gr.Slider(
384, 786, value=512, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=64,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=50, step=1, label="Steps"
)
guidance_scale = gr.Slider(
0,
50,
value=7.5,
step=0.1,
label="CFG Scale",
)
with gr.Row():
batch_count = gr.Slider(
1,
10,
value=1,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=1,
step=1,
label="Batch Size",
interactive=True,
)
with gr.Row():
seed = gr.Number(value=-1, precision=0, label="Seed")
available_devices = get_available_devices()
device = gr.Dropdown(
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => Math.floor(Math.random() * 4294967295)",
)
stable_diffusion = gr.Button("Generate Image")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2], height="auto")
std_output = gr.Textbox(
value="Nothing to show.",
lines=4,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
kwargs = dict(
fn=txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
prompt.submit(**kwargs)
stable_diffusion.click(**kwargs)
shark_web.queue()
shark_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

View File

Before

Width:  |  Height:  |  Size: 33 KiB

After

Width:  |  Height:  |  Size: 33 KiB

View File

Before

Width:  |  Height:  |  Size: 10 KiB

After

Width:  |  Height:  |  Size: 10 KiB

View File

Before

Width:  |  Height:  |  Size: 5.0 KiB

After

Width:  |  Height:  |  Size: 5.0 KiB

View File

@@ -0,0 +1,45 @@
import argparse
from PIL import Image
import numpy as np
import requests
import shutil
import os
import subprocess
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--newfile")
parser.add_argument(
"-g",
"--golden_url",
default="https://storage.googleapis.com/shark_tank/testdata/cyberpunk_fores_42_0_230119_021148.png",
)
def get_image(url, local_filename):
res = requests.get(url, stream=True)
if res.status_code == 200:
with open(local_filename, "wb") as f:
shutil.copyfileobj(res.raw, f)
def compare_images(new_filename, golden_filename):
new = np.array(Image.open(new_filename)) / 255.0
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.1:
subprocess.run(
["gsutil", "cp", new_filename, "gs://shark_tank/testdata/builder/"]
)
raise SystemExit("new and golden not close")
else:
print("SUCCESS")
if __name__ == "__main__":
args = parser.parse_args()
tempfile_name = os.path.join(os.getcwd(), "golden.png")
get_image(args.golden_url, tempfile_name)
compare_images(args.newfile, tempfile_name)

View File

@@ -1,5 +1,5 @@
#!/bin/bash
IMPORTER=1 ./setup_venv.sh
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python generate_sharktank.py --upload=False --ci_tank_dir=True
python generate_sharktank.py

View File

@@ -0,0 +1,78 @@
import os
import subprocess
from apps.stable_diffusion.src.utils.resources import (
get_json_file,
)
from shark.shark_downloader import download_public_file
from image_comparison import compare_images
import argparse
from glob import glob
import shutil
model_config_dicts = get_json_file(
os.path.join(
os.getcwd(),
"apps/stable_diffusion/src/utils/resources/model_config.json",
)
)
def test_loop(device="vulkan", beta=False, extra_flags=[]):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
os.mkdir("./test_images")
os.mkdir("./test_images/golden")
hf_model_names = model_config_dicts[0].values()
tuned_options = ["--no-use_tuned", "use_tuned"]
if beta:
extra_flags.append("--beta_models=True")
for model_name in hf_model_names:
for use_tune in tuned_options:
command = [
"python",
"apps/stable_diffusion/scripts/txt2img.py",
"--device=" + device,
"--prompt=cyberpunk forest by Salvador Dali",
"--output_dir="
+ os.path.join(os.getcwd(), "test_images", model_name),
"--hf_model_id=" + model_name,
use_tune,
]
command += extra_flags
generated_image = not subprocess.call(
command, stdout=subprocess.DEVNULL
)
if generated_image:
print(" ".join(command))
print("Successfully generated image")
os.makedirs(
"./test_images/golden/" + model_name, exist_ok=True
)
download_public_file(
"gs://shark_tank/testdata/golden/" + model_name,
"./test_images/golden/" + model_name,
)
test_file_path = os.path.join(
os.getcwd(), "test_images", model_name, "generated_imgs"
)
test_file = glob(test_file_path + "/*.png")[0]
golden_path = "./test_images/golden/" + model_name + "/*.png"
golden_file = glob(golden_path)[0]
compare_images(test_file, golden_file)
else:
print(" ".join(command))
print("failed to generate image for this configuration")
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument(
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
)
if __name__ == "__main__":
args = parser.parse_args()
print(args)
test_loop(args.device, args.beta, [])

27
dataset/README.md Normal file
View File

@@ -0,0 +1,27 @@
# Dataset annotation tool
SHARK annotator for adding or modifying prompts of dataset images
## Set up
Activate SHARK Python virtual environment and install additional packages
```shell
source ../shark.venv/bin/activate
pip install -r requirements.txt
```
## Run annotator
```shell
python annotation_tool.py
```
<img width="1280" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214521137-7ef6ae10-7cd8-46e6-b270-b6c0445157f1.png">
* Select a dataset from `Dataset` dropdown list
* Select an image from `Image` dropdown list
* Image and the existing prompt will be loaded
* Select a prompt from `Prompt` dropdown list to modify or "Add new" to add a prompt
* Click `Save` to save changes, click `Delete` to delete prompt
* Click `Back` or `Next` to switch image, you could also select other images from `Image`
* Click `Finish` when finishing annotation or before switching dataset

247
dataset/annotation_tool.py Normal file
View File

@@ -0,0 +1,247 @@
import gradio as gr
import json
import jsonlines
import os
from args import args
from pathlib import Path
from PIL import Image
from utils import get_datasets
shark_root = Path(__file__).parent.parent
demo_css = shark_root.joinpath("web/demo.css").resolve()
nodlogo_loc = shark_root.joinpath(
"web/models/stable_diffusion/logos/nod-logo.png"
)
with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
prompt_data = dict()
with gr.Row(elem_id="ui_body"):
# TODO: add multiselect dataset, there is a gradio version conflict
dataset = gr.Dropdown(label="Dataset", choices=datasets)
image_name = gr.Dropdown(label="Image", choices=[])
with gr.Row(elem_id="ui_body"):
# TODO: add ability to search image by typing
with gr.Column(scale=1, min_width=600):
image = gr.Image(type="filepath").style(height=512)
with gr.Column(scale=1, min_width=600):
prompts = gr.Dropdown(
label="Prompts",
choices=[],
)
prompt = gr.Textbox(
label="Editor",
lines=3,
)
with gr.Row():
save = gr.Button("Save")
delete = gr.Button("Delete")
with gr.Row():
back_image = gr.Button("Back")
next_image = gr.Button("Next")
finish = gr.Button("Finish")
def filter_datasets(dataset):
if dataset is None:
return gr.Dropdown.update(value=None, choices=[])
# create the dataset dir if doesn't exist and download prompt file
dataset_path = str(shark_root) + "/dataset/" + dataset
if not os.path.exists(dataset_path):
os.mkdir(dataset_path)
# read prompt jsonlines file
prompt_data.clear()
if dataset in ds_w_prompts:
prompt_gs_path = args.gs_url + "/" + dataset + "/metadata.jsonl"
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
for line in reader.iter(type=dict, skip_invalid=True):
prompt_data[line["file_name"]] = (
[line["text"]]
if type(line["text"]) is str
else line["text"]
)
return gr.Dropdown.update(choices=images[dataset])
dataset.change(fn=filter_datasets, inputs=dataset, outputs=image_name)
def display_image(dataset, image_name):
if dataset is None or image_name is None:
return gr.Image.update(value=None), gr.Dropdown.update(value=None)
# download and load the image
img_gs_path = args.gs_url + "/" + dataset + "/" + image_name
img_sub_path = "/".join(image_name.split("/")[:-1])
img_dst_path = (
str(shark_root) + "/dataset/" + dataset + "/" + img_sub_path + "/"
)
if not os.path.exists(img_dst_path):
os.mkdir(img_dst_path)
os.system(f'gsutil cp "{img_gs_path}" "{img_dst_path}"')
img = Image.open(img_dst_path + image_name.split("/")[-1])
if image_name not in prompt_data.keys():
prompt_data[image_name] = []
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Image.update(value=img), gr.Dropdown.update(
choices=prompt_choices
)
image_name.change(
fn=display_image,
inputs=[dataset, image_name],
outputs=[image, prompts],
)
def edit_prompt(prompts):
if prompts == "Add new":
return gr.Textbox.update(value=None)
return gr.Textbox.update(value=prompts)
prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt)
def save_prompt(dataset, image_name, prompts, prompt):
if (
dataset is None
or image_name is None
or prompts is None
or prompt is None
):
return
if prompts == "Add new":
prompt_data[image_name].append(prompt)
else:
idx = prompt_data[image_name].index(prompts)
prompt_data[image_name][idx] = prompt
prompt_path = (
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
)
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
if not value:
continue
v = value if len(value) > 1 else value[0]
f.write(json.dumps({"file_name": key, "text": v}))
f.write("\n")
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Dropdown.update(choices=prompt_choices, value=None)
save.click(
fn=save_prompt,
inputs=[dataset, image_name, prompts, prompt],
outputs=prompts,
)
def delete_prompt(dataset, image_name, prompts):
if dataset is None or image_name is None or prompts is None:
return
if prompts == "Add new":
return
prompt_data[image_name].remove(prompts)
prompt_path = (
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
)
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
if not value:
continue
v = value if len(value) > 1 else value[0]
f.write(json.dumps({"file_name": key, "text": v}))
f.write("\n")
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Dropdown.update(choices=prompt_choices, value=None)
delete.click(
fn=delete_prompt,
inputs=[dataset, image_name, prompts],
outputs=prompts,
)
def get_back_image(dataset, image_name):
if dataset is None or image_name is None:
return
# remove local image
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
os.system(f'rm "{img_path}"')
# get the index for the back image
idx = images[dataset].index(image_name)
if idx == 0:
return gr.Dropdown.update(value=None)
return gr.Dropdown.update(value=images[dataset][idx - 1])
back_image.click(
fn=get_back_image, inputs=[dataset, image_name], outputs=image_name
)
def get_next_image(dataset, image_name):
if dataset is None or image_name is None:
return
# remove local image
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
os.system(f'rm "{img_path}"')
# get the index for the next image
idx = images[dataset].index(image_name)
if idx == len(images[dataset]) - 1:
return gr.Dropdown.update(value=None)
return gr.Dropdown.update(value=images[dataset][idx + 1])
next_image.click(
fn=get_next_image, inputs=[dataset, image_name], outputs=image_name
)
def finish_annotation(dataset):
if dataset is None:
return
# upload prompt and remove local data
dataset_path = str(shark_root) + "/dataset/" + dataset
dataset_gs_path = args.gs_url + "/" + dataset + "/"
os.system(
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
)
os.system(f'rm -rf "{dataset_path}"')
return gr.Dropdown.update(value=None)
finish.click(fn=finish_annotation, inputs=dataset, outputs=dataset)
if __name__ == "__main__":
shark_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

34
dataset/args.py Normal file
View File

@@ -0,0 +1,34 @@
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Dataset Annotator flags
##############################################################################
p.add_argument(
"--gs_url",
type=str,
required=True,
help="URL to datasets in GS bucket",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
##############################################################################
args = p.parse_args()

3
dataset/requirements.txt Normal file
View File

@@ -0,0 +1,3 @@
# SHARK Annotator
gradio==3.15.0
jsonlines

29
dataset/utils.py Normal file
View File

@@ -0,0 +1,29 @@
from google.cloud import storage
def get_datasets(gs_url):
datasets = set()
images = dict()
ds_w_prompts = []
storage_client = storage.Client()
bucket_name = gs_url.split("/")[2]
source_blob_name = "/".join(gs_url.split("/")[3:])
blobs = storage_client.list_blobs(bucket_name, prefix=source_blob_name)
for blob in blobs:
dataset_name = blob.name.split("/")[1]
if dataset_name == "":
continue
datasets.add(dataset_name)
if dataset_name not in images.keys():
images[dataset_name] = []
# check if image or jsonl
file_sub_path = "/".join(blob.name.split("/")[2:])
if "/" in file_sub_path:
images[dataset_name] += [file_sub_path]
elif "metadata.jsonl" in file_sub_path:
ds_w_prompts.append(dataset_name)
return list(datasets), images, ds_w_prompts

View File

@@ -2,33 +2,26 @@
"""SHARK Tank"""
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
# will generate local shark tank folder like this:
# HOME
# /.local
# /shark_tank
# /albert_lite_base
# /...model_name...
# /SHARK
# /gen_shark_tank
# /albert_lite_base
# /...model_name...
#
import os
import csv
import argparse
from shark.shark_importer import SharkImporter
from shark.parser import shark_args
import tensorflow as tf
import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
from apps.stable_diffusion.src.models import (
model_wrappers as mw,
)
from apps.stable_diffusion.src.utils.stable_args import (
args,
)
def create_hash(file_name):
@@ -41,9 +34,12 @@ def create_hash(file_name):
def save_torch_model(torch_model_list):
from tank.model_utils import get_hf_model
from tank.model_utils import get_vision_model
from tank.model_utils import get_hf_img_cls_model
from tank.model_utils import (
get_hf_model,
get_vision_model,
get_hf_img_cls_model,
get_fp16_model,
)
with open(torch_model_list) as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")
@@ -59,13 +55,39 @@ def save_torch_model(torch_model_list):
model = None
input = None
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.use_tuned = False
args.local_tank_cache = WORKDIR
precision_values = ["fp16"]
seq_lengths = [64, 77]
for precision_value in precision_values:
args.precision = precision_value
for length in seq_lengths:
model = mw.SharkifyStableDiffusionModel(
model_id=torch_model_name,
custom_weights="",
precision=precision_value,
max_len=length,
width=512,
height=512,
use_base_vae=False,
debug=True,
sharktank_dir=WORKDIR,
generate_vmfb=False,
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
@@ -106,6 +128,17 @@ def save_tf_model(tf_model_list):
get_keras_model,
get_TFhf_model,
)
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
with open(tf_model_list) as csvfile:
tf_reader = csv.reader(csvfile, delimiter=",")
@@ -201,51 +234,48 @@ def is_valid_file(arg):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--torch_model_csv",
type=lambda x: is_valid_file(x),
default="./tank/torch_model_list.csv",
help="""Contains the file with torch_model name and args.
Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
)
parser.add_argument(
"--tf_model_csv",
type=lambda x: is_valid_file(x),
default="./tank/tf_model_list.csv",
help="Contains the file with tf model name and args.",
)
parser.add_argument(
"--tflite_model_csv",
type=lambda x: is_valid_file(x),
default="./tank/tflite/tflite_model_list.csv",
help="Contains the file with tf model name and args.",
)
parser.add_argument(
"--ci_tank_dir",
type=bool,
default=False,
)
parser.add_argument("--upload", type=bool, default=False)
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--torch_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/torch_model_list.csv",
# help="""Contains the file with torch_model name and args.
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
# )
# parser.add_argument(
# "--tf_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tf_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--tflite_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tflite/tflite_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--ci_tank_dir",
# type=bool,
# default=False,
# )
# parser.add_argument("--upload", type=bool, default=False)
args = parser.parse_args()
# old_args = parser.parse_args()
home = str(Path.home())
if args.ci_tank_dir == True:
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
else:
WORKDIR = os.path.join(home, ".local/shark_tank/")
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
torch_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tf_model_list.csv"
)
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
)
if args.torch_model_csv:
save_torch_model(args.torch_model_csv)
if args.tf_model_csv:
save_tf_model(args.tf_model_csv)
if args.tflite_model_csv:
save_tflite_model(args.tflite_model_csv)
if args.upload:
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
print("uploading files to gs://shark_tank/" + git_hash)
os.system(f"gsutil cp -r {WORKDIR}* gs://shark_tank/" + git_hash)
save_torch_model(torch_model_csv)
save_tf_model(tf_model_csv)
save_tflite_model(tflite_model_csv)

View File

@@ -1,3 +1,3 @@
[pytest]
addopts = --verbose -p no:warnings
norecursedirs = inference tank/tflite
norecursedirs = inference tank/tflite examples benchmarks shark

View File

@@ -28,6 +28,7 @@ Pillow
# web dependecies.
gradio
altair
# Testing and support.
#lit

View File

@@ -3,6 +3,8 @@
numpy==1.22.4
torchvision
pytorch-triton
tabulate
tqdm
@@ -13,7 +15,7 @@ iree-tools-tf
# TensorFlow and JAX.
gin-config
tensorflow==2.10
tensorflow==2.10.1
keras==2.10
#tf-models-nightly
#tensorflow-text-nightly
@@ -34,6 +36,7 @@ sacremoses
# web dependecies.
gradio
altair
scipy
#ONNX and ORT for benchmarking

View File

@@ -1,6 +1,5 @@
setuptools
wheel
pyinstaller
# SHARK Runner
tqdm
@@ -11,6 +10,7 @@ google-cloud-storage
# Testing
pytest
pytest-xdist
pytest-forked
Pillow
parameterized
@@ -20,3 +20,10 @@ diffusers
scipy
ftfy
gradio
altair
omegaconf
safetensors
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller

View File

@@ -2,11 +2,12 @@ from setuptools import find_packages
from setuptools import setup
import os
import glob
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.4"
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
backend_deps = []
if "NO_BACKEND" in os.environ.keys():
backend_deps = [
@@ -34,6 +35,7 @@ setup(
],
packages=find_packages(exclude=("examples")),
python_requires=">=3.9",
data_files=glob.glob("apps/stable_diffusion/resources/**"),
install_requires=[
"numpy",
"PyYAML",

View File

@@ -1,3 +1,9 @@
param([string]$arguments)
if ($arguments -eq "--update-src"){
git pull
}
#Write-Host "Installing python"
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow
@@ -35,6 +41,5 @@ pip install --pre torch-mlir torch torchvision --extra-index-url https://downloa
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip install diffusers transformers scipy pillow gradio
Write-Host "Build and installation completed successfully"
Write-Host "Source your venv with ./shark.venv/Scripts/activate"

View File

@@ -123,8 +123,13 @@ fi
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)
TORCH_VERSION=${T_VER:9:17}
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu117
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu117."
else

View File

@@ -1,6 +1,6 @@
import torchdynamo
import torch
import torch_mlir
import torch._dynamo as torchdynamo
from shark.sharkdynamo.utils import make_shark_compiler

View File

@@ -36,7 +36,9 @@
" from torchdynamo.optimizations.backends import create_backend\n",
" from torchdynamo.optimizations.subgraph import SubGraph\n",
"except ModuleNotFoundError:\n",
" print(\"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\")\n",
" print(\n",
" \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
" )\n",
" exit()\n",
"\n",
"# torch-mlir imports for compiling\n",
@@ -97,7 +99,9 @@
"\n",
" for node in fx_g.graph.nodes:\n",
" if node.op == \"output\":\n",
" assert len(node.args) == 1, \"Output node must have a single argument\"\n",
" assert (\n",
" len(node.args) == 1\n",
" ), \"Output node must have a single argument\"\n",
" node_arg = node.args[0]\n",
" if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
" node.args = (node_arg[0],)\n",
@@ -116,8 +120,12 @@
" if len(args) == 1 and isinstance(args[0], list):\n",
" args = args[0]\n",
"\n",
" linalg_module = compile(ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS)\n",
" callable, _ = get_iree_compiled_module(linalg_module, \"cuda\", func_name=\"forward\")\n",
" linalg_module = compile(\n",
" ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
" )\n",
" callable, _ = get_iree_compiled_module(\n",
" linalg_module, \"cuda\", func_name=\"forward\"\n",
" )\n",
"\n",
" def forward(*inputs):\n",
" return callable(*inputs)\n",
@@ -212,6 +220,7 @@
" assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
" return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
"\n",
"\n",
"@torchdynamo.optimize(\"torch_mlir\")\n",
"def toy_example2(*args):\n",
" a, b = args\n",

View File

@@ -128,7 +128,6 @@ def load_mlir(mlir_loc):
def compile_through_fx(model, inputs, mlir_loc=None):
module = load_mlir(mlir_loc)
if module == None:
fx_g = make_fx(

View File

@@ -151,7 +151,6 @@ class DLRM_Net(nn.Module):
and (ln_top is not None)
and (arch_interaction_op is not None)
):
# save arguments
self.output_d = 0
self.arch_interaction_op = arch_interaction_op
@@ -216,7 +215,6 @@ class DLRM_Net(nn.Module):
return ly
def interact_features(self, x, ly):
if self.arch_interaction_op == "dot":
# concatenate dense and sparse features
(batch_size, d) = x.shape

View File

@@ -99,7 +99,6 @@ class SparseArchShark(nn.Module):
)
def forward(self, *batched_inputs):
concatenated_list = []
input_enum, embedding_enum = 0, 0
@@ -121,7 +120,6 @@ class SparseArchShark(nn.Module):
def test_sparse_arch() -> None:
D = 3
eb1_config = EmbeddingBagConfig(
name="t1",
@@ -211,7 +209,6 @@ class DLRMShark(nn.Module):
def forward(
self, dense_features: torch.Tensor, *sparse_features
) -> torch.Tensor:
embedded_dense = self.dense_arch(dense_features)
embedded_sparse = self.sparse_arch(*sparse_features)
concatenated_dense = self.inter_arch(

View File

@@ -1,272 +0,0 @@
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
import torch
from PIL import Image
from diffusers import LMSDiscreteScheduler
from tqdm.auto import tqdm
from shark.shark_inference import SharkInference
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import torch_mlir
import tempfile
import numpy as np
# pip install diffusers
# pip install scipy
############### Parsing args #####################
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
p.add_argument(
"--prompt",
type=str,
default="a photograph of an astronaut riding a horse",
help="the text prompt to use",
)
p.add_argument("--device", type=str, default="cpu", help="the device to use")
p.add_argument("--steps", type=int, default=10, help="the device to use")
p.add_argument("--mlir_loc", type=str, default=None, help="the device to use")
p.add_argument("--vae_loc", type=str, default=None, help="the device to use")
args = p.parse_args()
#####################################################
def load_mlir(mlir_loc):
import os
if mlir_loc == None:
return None
print(f"Trying to load the model from {mlir_loc}.")
with open(os.path.join(mlir_loc)) as f:
mlir_module = f.read()
return mlir_module
def compile_through_fx(model, inputs, mlir_loc=None, extra_args=[]):
module = load_mlir(mlir_loc)
if mlir_loc == None:
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(*inputs)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
module = torch_mlir.compile(
ts_g,
inputs,
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
mlir_model = module
func_name = "forward"
shark_module = SharkInference(
mlir_model,
func_name,
device=args.device,
mlir_dialect="tm_tensor",
)
shark_module.compile(extra_args)
return shark_module
if __name__ == "__main__":
YOUR_TOKEN = "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"
# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="vae",
use_auth_token=YOUR_TOKEN,
)
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="vae",
use_auth_token=YOUR_TOKEN,
)
def forward(self, input):
return self.vae.decode(input, return_dict=False)[0]
vae = VaeModel()
vae_input = torch.rand(1, 4, 64, 64)
shark_vae = compile_through_fx(vae, (vae_input,), args.vae_loc)
# Wrap the unet model to return tuples.
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="unet",
use_auth_token=YOUR_TOKEN,
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, x, y, z):
return self.unet.forward(x, y, z, return_dict=False)[0]
# 3. The UNet model for generating the latents.
unet = UnetModel()
latent_model_input = torch.rand([2, 4, 64, 64])
text_embeddings = torch.rand([2, 77, 768])
shark_unet = compile_through_fx(
unet,
(latent_model_input, torch.tensor([1.0]), text_embeddings),
args.mlir_loc,
["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
)
# torch.jit.script(unet)
scheduler = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
)
prompt = [args.prompt]
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
num_inference_steps = args.steps # Number of denoising steps
guidance_scale = 7.5 # Scale for classifier-free guidance
generator = torch.manual_seed(
42
) # Seed generator to create the inital latent noise
batch_size = len(prompt)
text_input = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = text_encoder(text_input.input_ids)[0]
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
[""] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = text_encoder(uncond_input.input_ids)[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
latents = torch.randn(
(batch_size, unet.in_channels, height // 8, width // 8),
generator=generator,
)
# latents = latents.to(torch_device)
scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.sigmas[0]
# print(latents, latents.shape)
for i, t in tqdm(enumerate(scheduler.timesteps)):
print(f"i = {i} t = {t}")
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
sigma = scheduler.sigmas[i]
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# predict the noise residual
# with torch.no_grad():
# noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
latent_model_input_numpy = latent_model_input.detach().numpy()
text_embeddings_numpy = text_embeddings.detach().numpy()
noise_pred = shark_unet.forward(
(
latent_model_input_numpy,
np.array([t]).astype(np.float32),
text_embeddings_numpy,
)
)
noise_pred = torch.from_numpy(noise_pred)
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
# print("Latents shape : ", latents.shape)
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
latents_numpy = latents.detach().numpy()
image = shark_vae.forward((latents_numpy,))
image = torch.from_numpy(image)
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
pil_images[0].save("astro.jpg")

View File

@@ -1,280 +0,0 @@
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
import torch
from PIL import Image
from diffusers import LMSDiscreteScheduler
from tqdm.auto import tqdm
from shark.shark_inference import SharkInference
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import torch_mlir
import tempfile
import numpy as np
# pip install diffusers
# pip install scipy
############### Parsing args #####################
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
p.add_argument(
"--prompt",
type=str,
default="a photograph of an astronaut riding a horse",
help="the text prompt to use",
)
p.add_argument("--device", type=str, default="cpu", help="the device to use")
p.add_argument("--steps", type=int, default=50, help="the device to use")
p.add_argument("--mlir_loc", type=str, default=None, help="the device to use")
p.add_argument("--vae_loc", type=str, default=None, help="the device to use")
args = p.parse_args()
#####################################################
def fp16_unet():
from shark.shark_downloader import download_model
mlir_model, func_name, inputs, golden_out = download_model(
"stable_diff_f16_18_OCT",
tank_url="gs://shark_tank/prashant_nod",
frontend="torch",
)
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
)
shark_module.compile()
return shark_module
def load_mlir(mlir_loc):
import os
if mlir_loc == None:
return None
print(f"Trying to load the model from {mlir_loc}.")
with open(os.path.join(mlir_loc)) as f:
mlir_module = f.read()
return mlir_module
def compile_through_fx(model, inputs, mlir_loc=None):
module = load_mlir(mlir_loc)
if mlir_loc == None:
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(*inputs)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
module = torch_mlir.compile(
ts_g,
inputs,
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
mlir_model = module
func_name = "forward"
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
)
shark_module.compile()
return shark_module
if __name__ == "__main__":
YOUR_TOKEN = "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"
# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="vae",
use_auth_token=YOUR_TOKEN,
)
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="vae",
use_auth_token=YOUR_TOKEN,
)
def forward(self, input):
return self.vae.decode(input, return_dict=False)[0]
vae = VaeModel()
vae_input = torch.rand(1, 4, 64, 64)
shark_vae = compile_through_fx(vae, (vae_input,), args.vae_loc)
# Wrap the unet model to return tuples.
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="unet",
use_auth_token=YOUR_TOKEN,
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, x, y, z):
return self.unet.forward(x, y, z, return_dict=False)[0]
# # 3. The UNet model for generating the latents.
unet = UnetModel()
shark_unet = fp16_unet()
scheduler = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
)
prompt = [args.prompt]
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
num_inference_steps = args.steps # Number of denoising steps
guidance_scale = 7.5 # Scale for classifier-free guidance
generator = torch.manual_seed(
42
) # Seed generator to create the inital latent noise
batch_size = len(prompt)
text_input = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = text_encoder(text_input.input_ids)[0]
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
[""] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = text_encoder(uncond_input.input_ids)[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
latents = torch.randn(
(batch_size, unet.in_channels, height // 8, width // 8),
generator=generator,
)
# latents = latents.to(torch_device)
scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.sigmas[0]
# print(latents, latents.shape)
for i, t in tqdm(enumerate(scheduler.timesteps)):
print(f"i = {i} t = {t}")
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
sigma = scheduler.sigmas[i]
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# predict the noise residual
# with torch.no_grad():
# noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
latent_model_input_numpy = (
latent_model_input.detach().numpy().astype(np.half)
)
text_embeddings_numpy = (
text_embeddings.detach().numpy().astype(np.half)
)
noise_pred = shark_unet.forward(
(
latent_model_input_numpy,
np.array([t]).astype(np.half),
text_embeddings_numpy,
)
)
noise_pred = torch.from_numpy(noise_pred).to(torch.float32)
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
# print("Latents shape : ", latents.shape)
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
latents_numpy = latents.detach().numpy()
image = shark_vae.forward((latents_numpy,))
image = torch.from_numpy(image)
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
pil_images[0].save("astro.jpg")

View File

@@ -1,313 +0,0 @@
import math
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras_cv.models.generative.stable_diffusion.clip_tokenizer import (
SimpleTokenizer,
)
from keras_cv.models.generative.stable_diffusion.constants import (
_ALPHAS_CUMPROD,
)
from keras_cv.models.generative.stable_diffusion.constants import (
_UNCONDITIONAL_TOKENS,
)
from keras_cv.models.generative.stable_diffusion.decoder import Decoder
from keras_cv.models.generative.stable_diffusion.text_encoder import (
TextEncoder,
)
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_model
from PIL import Image
# pip install "git+https://github.com/keras-team/keras-cv.git"
# pip install tensorflow_dataset
############### Parsing args #####################
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
p.add_argument(
"--prompt",
type=str,
default="a photograph of an astronaut riding a horse",
help="the text prompt to use",
)
p.add_argument("--device", type=str, default="cpu", help="the device to use")
p.add_argument(
"--steps", type=int, default=10, help="the number of steps to use"
)
p.add_argument(
"--save_path",
type=str,
default=None,
help="the file to save the resulting image to. (default to <input prompt>.jpg)",
)
args = p.parse_args()
#####################################################
MAX_PROMPT_LENGTH = 77
class SharkStableDiffusion:
"""Shark implementation of Stable Diffusion based on model from keras_cv.
Stable Diffusion is a powerful image generation model that can be used,
among other things, to generate pictures according to a short text description
(called a "prompt").
Arguments:
device: Device to use with SHARK. Default: cpu
jit_compile: Whether to compile the underlying models to XLA.
This can lead to a significant speedup on some systems. Default: False.
References:
- [About Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement)
- [Original implementation](https://github.com/CompVis/stable-diffusion)
"""
def __init__(self, device="cpu", jit_compile=True):
self.img_height = 512
self.img_width = 512
self.tokenizer = SimpleTokenizer()
# Create models
self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
mlir_model, func_name, inputs, golden_out = download_model(
"stable_diff", tank_url="gs://shark_tank/quinn", frontend="tf"
)
shark_module = SharkInference(
mlir_model, func_name, device=device, mlir_dialect="mhlo"
)
shark_module.compile()
self.diffusion_model = shark_module
self.decoder = Decoder(self.img_height, self.img_width)
if jit_compile:
self.text_encoder.compile(jit_compile=True)
self.decoder.compile(jit_compile=True)
print(
"By using this model checkpoint, you acknowledge that its usage is "
"subject to the terms of the CreativeML Open RAIL-M license at "
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE"
)
# Load weights
text_encoder_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
)
decoder_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
)
self.text_encoder.load_weights(text_encoder_weights_fpath)
self.decoder.load_weights(decoder_weights_fpath)
def text_to_image(
self,
prompt,
batch_size=1,
num_steps=25,
unconditional_guidance_scale=7.5,
seed=None,
):
encoded_text = self.encode_text(prompt)
return self.generate_image(
encoded_text,
batch_size=batch_size,
num_steps=num_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
seed=seed,
)
def encode_text(self, prompt):
"""Encodes a prompt into a latent text encoding.
The encoding produced by this method should be used as the
`encoded_text` parameter of `StableDiffusion.generate_image`. Encoding
text separately from generating an image can be used to arbitrarily
modify the text encoding priot to image generation, e.g. for walking
between two prompts.
Args:
prompt: a string to encode, must be 77 tokens or shorter.
Example:
```python
from keras_cv.models import StableDiffusion
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
encoded_text = model.encode_text("Tacos at dawn")
img = model.generate_image(encoded_text)
```
"""
# Tokenize prompt (i.e. starting context)
inputs = self.tokenizer.encode(prompt)
if len(inputs) > MAX_PROMPT_LENGTH:
raise ValueError(
f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)"
)
phrase = inputs + [49407] * (MAX_PROMPT_LENGTH - len(inputs))
phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
context = self.text_encoder.predict_on_batch(
[phrase, self._get_pos_ids()]
)
return context
def generate_image(
self,
encoded_text,
batch_size=1,
num_steps=25,
unconditional_guidance_scale=7.5,
diffusion_noise=None,
seed=None,
):
"""Generates an image based on encoded text.
The encoding passed to this method should be derived from
`StableDiffusion.encode_text`.
Args:
encoded_text: Tensor of shape (`batch_size`, 77, 768), or a Tensor
of shape (77, 768). When the batch axis is omitted, the same encoded
text will be used to produce every generated image.
batch_size: number of images to generate. Default: 1.
num_steps: number of diffusion steps (controls image quality).
Default: 25.
unconditional_guidance_scale: float controling how closely the image
should adhere to the prompt. Larger values result in more
closely adhering to the prompt, but will make the image noisier.
Default: 7.5.
diffusion_noise: Tensor of shape (`batch_size`, img_height // 8,
img_width // 8, 4), or a Tensor of shape (img_height // 8,
img_width // 8, 4). Optional custom noise to seed the diffusion
process. When the batch axis is omitted, the same noise will be
used to seed diffusion for every generated image.
seed: integer which is used to seed the random generation of
diffusion noise, only to be specified if `diffusion_noise` is
None.
Example:
```python
from keras_cv.models import StableDiffusion
batch_size = 8
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
e_tacos = model.encode_text("Tacos at dawn")
e_watermelons = model.encode_text("Watermelons at dusk")
e_interpolated = tf.linspace(e_tacos, e_watermelons, batch_size)
images = model.generate_image(e_interpolated, batch_size=batch_size)
```
"""
if diffusion_noise is not None and seed is not None:
raise ValueError(
"`diffusion_noise` and `seed` should not both be passed to "
"`generate_image`. `seed` is only used to generate diffusion "
"noise when it's not already user-specified."
)
encoded_text = tf.squeeze(encoded_text)
if encoded_text.shape.rank == 2:
encoded_text = tf.repeat(
tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
)
context = encoded_text
unconditional_context = tf.repeat(
self._get_unconditional_context(), batch_size, axis=0
)
context = tf.concat([context, unconditional_context], 0)
if diffusion_noise is not None:
diffusion_noise = tf.squeeze(diffusion_noise)
if diffusion_noise.shape.rank == 3:
diffusion_noise = tf.repeat(
tf.expand_dims(diffusion_noise, axis=0), batch_size, axis=0
)
latent = diffusion_noise
else:
latent = self._get_initial_diffusion_noise(batch_size, seed)
# Iterative reverse diffusion stage
timesteps = tf.range(1, 1000, 1000 // num_steps)
alphas, alphas_prev = self._get_initial_alphas(timesteps)
progbar = keras.utils.Progbar(len(timesteps))
iteration = 0
for index, timestep in list(enumerate(timesteps))[::-1]:
latent_prev = latent # Set aside the previous latent vector
t_emb = self._get_timestep_embedding(timestep, batch_size)
# Prepare the latent and unconditional latent to be run with a single forward call
latent = tf.concat([latent, latent], 0)
t_emb = tf.concat([t_emb, t_emb], 0)
latent_numpy = self.diffusion_model.forward(
[latent.numpy(), t_emb.numpy(), context.numpy()]
)
latent = tf.convert_to_tensor(latent_numpy, dtype=tf.float32)
latent, unconditional_latent = tf.split(latent, 2)
latent = unconditional_latent + unconditional_guidance_scale * (
latent - unconditional_latent
)
a_t, a_prev = alphas[index], alphas_prev[index]
pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(
a_t
)
latent = (
latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
)
iteration += 1
progbar.update(iteration)
# Decoding stage
decoded = self.decoder.predict_on_batch(latent)
decoded = ((decoded + 1) / 2) * 255
return np.clip(decoded, 0, 255).astype("uint8")
def _get_unconditional_context(self):
unconditional_tokens = tf.convert_to_tensor(
[_UNCONDITIONAL_TOKENS], dtype=tf.int32
)
unconditional_context = self.text_encoder.predict_on_batch(
[unconditional_tokens, self._get_pos_ids()]
)
return unconditional_context
def _get_timestep_embedding(
self, timestep, batch_size, dim=320, max_period=10000
):
half = dim // 2
freqs = tf.math.exp(
-math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
)
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
embedding = tf.reshape(embedding, [1, -1])
return tf.repeat(embedding, batch_size, axis=0)
def _get_initial_alphas(self, timesteps):
alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
alphas_prev = [1.0] + alphas[:-1]
return alphas, alphas_prev
def _get_initial_diffusion_noise(self, batch_size, seed):
return tf.random.normal(
(batch_size, self.img_height // 8, self.img_width // 8, 4),
seed=seed,
)
@staticmethod
def _get_pos_ids():
return tf.convert_to_tensor(
[list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32
)
if __name__ == "__main__":
SD = SharkStableDiffusion(device=args.device)
images = SD.text_to_image(args.prompt, num_steps=args.steps)
pil_images = [Image.fromarray(image) for image in images]
save_fname = args.prompt + ".jpg"
if args.save_path is not None:
save_fname = args.save_path
pil_images[0].save(save_fname)

View File

@@ -1,2 +0,0 @@
*.vmfb
*.jpg

View File

@@ -1,44 +0,0 @@
# STABLE DIFFUSION
## Installation
Follow setup instructions in the main [README.md](https://github.com/nod-ai/SHARK#readme) for regular usage.
## Debug commands and other advanced usage follows.
```shell
python main.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
```
## dump all dispatch .spv and isa using amdllpc
```shell
python main.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
```
## Compile and save the .vmfb (using vulkan fp16 as an example):
```shell
python shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
```
## Capture an RGP trace
```shell
python shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
```
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
```shell
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --device=vulkan --function_input=1x4x64x64xf16
```
## Run the unet module with iree-benchmark-module (same config as above):
```shell
##if you want to use .npz inputs:
unzip ~/.local/shark_tank/<your unet>/inputs.npz
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --function_input=@arr_0.npy --function_input=1xf16 --function_input=@arr_2.npy --function_input=@arr_3.npy --function_input=@arr_4.npy
```

View File

@@ -1,25 +0,0 @@
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPModel
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(
text=["a photo of a cat", "a photo of a dog"],
images=image,
return_tensors="pt",
padding=True,
)
outputs = model(**inputs)
logits_per_image = (
outputs.logits_per_image
) # this is the image-text similarity score
probs = logits_per_image.softmax(
dim=1
) # we can take the softmax to get the label probabilities

View File

@@ -1,188 +0,0 @@
from transformers import CLIPTextModel, CLIPTokenizer
import torch
from PIL import Image
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from tqdm.auto import tqdm
import numpy as np
from stable_args import args
from utils import get_shark_model, set_iree_runtime_flags
from opt_params import get_unet, get_vae, get_clip
import time
from model_wrappers import get_vae_mlir
from shark.iree_utils.compile_utils import dump_isas
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")
vulkan_device = iree.runtime.get_device(args.device)
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
return vulkan_device
return None
def end_profiling(device):
if device:
return device.end_profiling()
if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
prompt = args.prompts
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
if args.version == "v2":
height = 768
width = 768
num_inference_steps = args.steps # Number of denoising steps
# Scale for classifier-free guidance
guidance_scale = torch.tensor(args.guidance_scale).to(torch.float32)
generator = torch.manual_seed(
args.seed
) # Seed generator to create the inital latent noise
batch_size = len(prompt)
set_iree_runtime_flags()
unet = get_unet()
vae = get_vae()
clip = get_clip()
if args.dump_isa:
dump_isas(args.dispatch_benchmarks_dir)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
scheduler = DPMSolverMultistepScheduler.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="scheduler",
)
if args.version == "v2":
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="tokenizer"
)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
"stabilityai/stable-diffusion-2",
subfolder="scheduler",
)
if args.version == "v2.1base":
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
)
scheduler = EulerDiscreteScheduler.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
subfolder="scheduler",
)
start = time.time()
text_input = tokenizer(
prompt,
padding="max_length",
max_length=args.max_length,
truncation=True,
return_tensors="pt",
)
clip_inf_start = time.time()
text_embeddings = clip.forward((text_input.input_ids,))
clip_inf_end = time.time()
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
[""] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_clip_inf_start = time.time()
uncond_embeddings = clip.forward((uncond_input.input_ids,))
uncond_clip_inf_end = time.time()
uncond_embeddings = torch.from_numpy(uncond_embeddings).to(dtype)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
latents = torch.randn(
(batch_size, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
scheduler.set_timesteps(num_inference_steps)
scheduler.is_scale_input_called = True
latents = latents * scheduler.init_noise_sigma
text_embeddings_numpy = text_embeddings.detach().numpy()
avg_ms = 0
for i, t in tqdm(enumerate(scheduler.timesteps)):
step_start = time.time()
print(f"i = {i} t = {t}", end="")
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = scheduler.scale_model_input(latents, t)
latents_numpy = latent_model_input.detach().numpy()
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = unet.forward(
(
latents_numpy,
timestep,
text_embeddings_numpy,
guidance_scale,
)
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
step_time = time.time() - step_start
avg_ms += step_time
step_ms = int((step_time) * 1000)
print(f" ({step_ms}ms)")
latents = scheduler.step(noise_pred, t, latents).prev_sample
avg_ms = 1000 * avg_ms / args.steps
print(f"Average step time: {avg_ms}ms/it")
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
# latents = latents.
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
image = vae.forward((latents_numpy,))
vae_end = time.time()
end_profiling(profile_device)
image = torch.from_numpy(image)
image = image.detach().cpu().permute(0, 2, 3, 1) * 255.0
images = image.numpy().round().astype("uint8")
total_end = time.time()
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
uncond_clip_inf_time = (uncond_clip_inf_end - uncond_clip_inf_start) * 1000
avg_clip_inf = (clip_inf_time + uncond_clip_inf_time) / 2
vae_inf_time = (vae_end - vae_start) * 1000
print(
f"Clip Inference Avg time (ms) = ({clip_inf_time:.3f} + {uncond_clip_inf_time:.3f}) / 2 = {avg_clip_inf:.3f}"
)
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
print(f"Total image generation runtime (s): {total_end - start:.4f}")
pil_images = [Image.fromarray(image) for image in images]
for i in range(batch_size):
pil_images[i].save(f"{args.prompts[i]}_{i}.jpg")

View File

@@ -1,173 +0,0 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from transformers import CLIPTextModel
from utils import compile_through_fx
from stable_args import args
import torch
BATCH_SIZE = len(args.prompts)
model_config = {
"v2": "stabilityai/stable-diffusion-2",
"v1.4": "CompVis/stable-diffusion-v1-4",
}
model_input = {
"v2": {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 96, 96),),
"unet": (
torch.randn(1, 4, 96, 96), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 1024), # embedding
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v1.4": {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
torch.randn(1, 4, 64, 64),
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 768),
torch.tensor(1).to(torch.float32),
),
},
}
# revision param for from_pretrained defaults to "main" => fp32
model_revision = "fp16" if args.precision == "fp16" else "main"
def get_clip_mlir(model_name="clip_text", extra_args=[]):
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
if args.version == "v2":
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
)
class CLIPText(torch.nn.Module):
def __init__(self):
super().__init__()
self.text_encoder = text_encoder
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
model_input[args.version]["clip"],
model_name=model_name,
extra_args=extra_args,
)
return shark_clip
def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version],
subfolder="vae",
revision=model_revision,
)
def forward(self, input):
x = self.vae.decode(input, return_dict=False)[0]
return (x / 2 + 0.5).clamp(0, 1)
vae = VaeModel()
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
shark_vae = compile_through_fx(
vae,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_vae
def get_vae_encode_mlir(model_name="vae_encode", extra_args=[]):
class VaeEncodeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version],
subfolder="vae",
revision="fp16",
)
def forward(self, x):
input = 2 * (x - 0.5)
return self.vae.encode(input, return_dict=False)[0]
vae = VaeEncodeModel()
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input[args.version]["vae"]]
)
shark_vae = compile_through_fx(
vae,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_vae
def get_unet_mlir(model_name="unet", extra_args=[]):
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_config[args.version],
subfolder="unet",
revision=model_revision,
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, latent, timestep, text_embedding, guidance_scale):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel()
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
else:
inputs = model_input[args.version]["unet"]
shark_unet = compile_through_fx(
unet,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_unet

View File

@@ -1,153 +0,0 @@
import sys
from model_wrappers import (
get_vae_mlir,
get_vae_encode_mlir,
get_unet_mlir,
get_clip_mlir,
)
from stable_args import args
from utils import get_shark_model
BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
sys.exit("Only batch size 1 is supported.")
def get_unet():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Tuned model is present for `fp16` precision.
if args.precision == "fp16":
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
model_name = "unet_1dec_fp16_tuned"
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_1dec_fp16"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
# Tuned model is not present for `fp32` case.
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_1dec_fp32"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "int8":
bucket = "gs://shark_tank/prashant_nod"
model_name = "unet_int8"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
]
sys.exit("int8 model is currently in maintenance.")
# # TODO: Pass iree_flags to the exported model.
# if args.import_mlir:
# sys.exit(
# "--import_mlir is not supported for the int8 model, try --no-import_mlir flag."
# )
# return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
if args.precision in ["fp16", "int8"]:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_1dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_1dec_fp32"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae_encode():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
if args.precision in ["fp16", "int8"]:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_encode_1dec_fp16"
if args.version == "v2":
model_name = "vae2_encode_29nov_fp16"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
]
if args.import_mlir:
return get_vae_encode_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_encode_1dec_fp32"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
bucket = "gs://shark_tank/stable_diffusion"
model_name = "clip_1dec_fp32"
if args.version == "v2.1base":
model_name = "clip2base_8dec_fp32"
iree_flags += [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops",
]
if args.import_mlir:
return get_clip_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)

View File

@@ -1,44 +0,0 @@
Compile / Run Instructions:
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=mhlo for tf models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
```
Run / Benchmark Command (FP32 - NCHW):
(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.)
```shell
## Vulkan AMD:
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --device=vulkan --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
## CUDA:
iree-benchmark-module --module_file=/path/to/vmfb --entry_function=forward --device=cuda --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
## CPU:
iree-benchmark-module --module_file=/path/to/vmfb --entry_function=forward --device=local-task --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
```
Run via vulkan_gui for RGP Profiling:
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
```shell
./build/vulkan_gui/iree-vulkan-gui --module_file=/path/to/unet.vmfb --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
```

View File

@@ -1,111 +0,0 @@
# Stable Diffusion optimized for AMD RDNA2/RDNA3 GPUs
## Install the latest AMD Drivers
### RDNA2 Drivers:
*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree
Note that if you previously tried Stable Diffusion with a different driver, it may be necessary to clear vulkan cache after changing drivers.
For Windows users this can be done by clearing the contents of `C:\Users\<username>\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
## Installation
Download the latest Windows `shark_sd.exe` [here](https://storage.googleapis.com/shark-public/anush/windows/shark_sd_05122022v3.exe). Accept if Windows warns of an unsigned .exe.
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
Here are some samples generated:
![tajmahal, snow, sunflowers, oil on canvas_0](https://user-images.githubusercontent.com/74956/204934186-141f7e43-6eb2-4e89-a99c-4704d20444b3.jpg)
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
<details>
<summary>Advanced Installation </summary>
## Setup your Python VirtualEnvironment and Dependencies
### Windows 10/11 Users
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
* Install Git for Windows from [here](https://git-scm.com/download/win)
#### Allow the install script to run in Powershell
```powershell
set-executionpolicy remotesigned
```
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
```powershell
git clone https://github.com/nod-ai/SHARK.git
cd SHARK
./setup_venv.ps1 #You can re-run this script to get the latest version
```
### Linux
```shell
git clone https://github.com/nod-ai/SHARK.git
cd SHARK
./setup_venv.sh
source shark.venv/bin/activate
```
### Run Stable Diffusion on your device - WebUI
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\Users\nod\SHARK> cd web
(shark.venv) PS C:\Users\nod\SHARK\web> python index.py
```
#### Linux Users
```shell
(shark.venv) > cd web
(shark.venv) > python index.py
```
### Run Stable Diffusion on your device - Commandline
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\shark\examples\shark_inference\stable_diffusion\main.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux
```shell
python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
The output on a 6900XT would like:
```shell
44it [00:08, 5.14it/s]i = 44 t = 120 (191ms)
45it [00:08, 5.15it/s]i = 45 t = 100 (191ms)
46it [00:08, 5.16it/s]i = 46 t = 80 (191ms)
47it [00:09, 5.16it/s]i = 47 t = 60 (193ms)
48it [00:09, 5.15it/s]i = 48 t = 40 (195ms)
49it [00:09, 5.12it/s]i = 49 t = 20 (196ms)
50it [00:09, 5.14it/s]
Average step time: 192.8154182434082ms/it
Total image generation runtime (s): 10.390909433364868
(shark.venv) PS C:\g\shark>
```
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
</details>
<details>
<summary>Discord link</summary>
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
</details>

View File

@@ -1,83 +0,0 @@
import os
import torch
from shark.shark_inference import SharkInference
from stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), extended_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_model
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
)
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into a shark_module.
def compile_through_fx(model, inputs, model_name, extra_args=[]):
mlir_module, func_name = import_with_fx(model, inputs)
shark_module = SharkInference(
mlir_module,
func_name,
device=args.device,
mlir_dialect="linalg",
)
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
]
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if "vulkan" in args.device:
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
return

View File

@@ -0,0 +1,21 @@
import requests
from PIL import Image
from io import BytesIO
from pipeline_shark_stable_diffusion_upscale import (
SharkStableDiffusionUpscalePipeline,
)
import torch
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipeline = SharkStableDiffusionUpscalePipeline(model_id)
# let's download an image
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
response = requests.get(url)
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
low_res_img = low_res_img.resize((128, 128))
prompt = "a white cat"
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image.save("upsampled_cat.png")

View File

@@ -0,0 +1,98 @@
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from utils import compile_through_fx
import torch
model_id = "stabilityai/stable-diffusion-x4-upscaler"
model_input = {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 128, 128),),
"unet": (
torch.randn(2, 7, 128, 128), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 1024), # embedding
torch.randn(2).to(torch.int64), # noise_level
),
}
def get_clip_mlir(model_name="clip_text", extra_args=[]):
text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
)
class CLIPText(torch.nn.Module):
def __init__(self):
super().__init__()
self.text_encoder = text_encoder
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
model_input["clip"],
model_name=model_name,
extra_args=extra_args,
)
return shark_clip
def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
)
def forward(self, input):
x = self.vae.decode(input, return_dict=False)[0]
return x
vae = VaeModel()
shark_vae = compile_through_fx(
vae,
model_input["vae"],
model_name=model_name,
extra_args=extra_args,
)
return shark_vae
def get_unet_mlir(model_name="unet", extra_args=[]):
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, latent, timestep, text_embedding, noise_level):
unet_out = self.unet.forward(
latent,
timestep,
text_embedding,
noise_level,
return_dict=False,
)[0]
return unet_out
unet = UnetModel()
f16_input_mask = (True, True, True, False)
shark_unet = compile_through_fx(
unet,
model_input["unet"],
model_name=model_name,
is_f16=True,
f16_input_mask=f16_input_mask,
extra_args=extra_args,
)
return shark_unet

View File

@@ -0,0 +1,48 @@
import sys
from model_wrappers import (
get_vae_mlir,
get_unet_mlir,
get_clip_mlir,
)
from upscaler_args import args
from utils import get_shark_model
BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
sys.exit("Only batch size 1 is supported.")
unet_flag = [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
vae_flag = [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
clip_flag = [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
bucket = "gs://shark_tank/stable_diffusion/"
def get_unet():
model_name = "upscaler_unet"
if args.import_mlir:
return get_unet_mlir(model_name, unet_flag)
return get_shark_model(bucket, model_name, unet_flag)
def get_vae():
model_name = "upscaler_vae"
if args.import_mlir:
return get_vae_mlir(model_name, vae_flag)
return get_shark_model(bucket, model_name, vae_flag)
def get_clip():
model_name = "upscaler_clip"
if args.import_mlir:
return get_clip_mlir(model_name, clip_flag)
return get_shark_model(bucket, model_name, clip_flag)

View File

@@ -0,0 +1,489 @@
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import PIL
from PIL import Image
from diffusers.utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import (
DDIMScheduler,
DDPMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers import logging
from diffusers.pipeline_utils import ImagePipelineOutput
from opt_params import get_unet, get_vae, get_clip
from tqdm.auto import tqdm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
def shark_run_wrapper(model, *args):
np_inputs = tuple([x.detach().numpy() for x in args])
outputs = model("forward", np_inputs)
return torch.from_numpy(outputs)
class SharkStableDiffusionUpscalePipeline:
def __init__(
self,
model_id,
):
self.tokenizer = CLIPTokenizer.from_pretrained(
model_id, subfolder="tokenizer"
)
self.low_res_scheduler = DDPMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
self.scheduler = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
self.vae = get_vae()
self.unet = get_unet()
self.text_encoder = get_clip()
self.max_noise_level = (350,)
self._execution_device = "cpu"
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[
-1
] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
# if (
# hasattr(self.text_encoder.config, "use_attention_mask")
# and self.text_encoder.config.use_attention_mask
# ):
# attention_mask = text_inputs.attention_mask.to(device)
# else:
# attention_mask = None
text_embeddings = shark_run_wrapper(
self.text_encoder, text_input_ids.to(device)
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# if (
# hasattr(self.text_encoder.config, "use_attention_mask")
# and self.text_encoder.config.use_attention_mask
# ):
# attention_mask = uncond_input.attention_mask.to(device)
# else:
# attention_mask = None
uncond_embeddings = shark_run_wrapper(
self.text_encoder,
uncond_input.input_ids.to(device),
)
uncond_embeddings = uncond_embeddings
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(
1, num_images_per_prompt, 1
)
uncond_embeddings = uncond_embeddings.view(
batch_size * num_images_per_prompt, seq_len, -1
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
def decode_latents(self, latents):
latents = 1 / 0.08333 * latents
image = shark_run_wrapper(self.vae, latents)
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def check_inputs(self, prompt, image, noise_level, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
)
# verify batch size of prompt and image are same if image is a list or tensor
if isinstance(image, list) or isinstance(image, torch.Tensor):
if isinstance(prompt, str):
batch_size = 1
else:
batch_size = len(prompt)
if isinstance(image, list):
image_batch_size = len(image)
else:
image_batch_size = image.shape[0]
if batch_size != image_batch_size:
raise ValueError(
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
" Please make sure that passed `prompt` matches the batch size of `image`."
)
@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [
Image.fromarray(image.squeeze(), mode="L") for image in images
]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
if device == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(
shape, generator=generator, device="cpu", dtype=dtype
).to(device)
else:
latents = torch.randn(
shape, generator=generator, device=device, dtype=dtype
)
else:
if latents.shape != shape:
raise ValueError(
f"Unexpected latents shape, got {latents.shape}, expected {shape}"
)
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[
torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]
],
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[
Union[torch.Generator, List[torch.Generator]]
] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[
Callable[[int, int, torch.FloatTensor], None]
] = None,
callback_steps: Optional[int] = 1,
):
# 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
# 4. Preprocess image
image = preprocess(image)
image = image.to(dtype=text_embeddings.dtype, device=device)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Add noise to image
noise_level = torch.tensor(
[noise_level], dtype=torch.long, device=device
)
if device == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(
image.shape,
generator=generator,
device="cpu",
dtype=text_embeddings.dtype,
).to(device)
else:
noise = torch.randn(
image.shape,
generator=generator,
device=device,
dtype=text_embeddings.dtype,
)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
noise_level = torch.cat([noise_level] * image.shape[0])
# 6. Prepare latent variables
height, width = image.shape[2:]
# num_channels_latents = self.vae.config.latent_channels
num_channels_latents = 4
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 7. Check that sizes of image and latents match
num_channels_image = image.shape[1]
# if (
# num_channels_latents + num_channels_image
# != self.unet.config.in_channels
# ):
# raise ValueError(
# f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
# f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
# f" `num_channels_image`: {num_channels_image} "
# f" = {num_channels_latents+num_channels_image}. Please verify the config of"
# " `pipeline.unet` or your `image` input."
# )
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 9. Denoising loop
num_warmup_steps = (
len(timesteps) - num_inference_steps * self.scheduler.order
)
for i, t in tqdm(enumerate(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2)
if do_classifier_free_guidance
else latents
)
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = torch.cat([latent_model_input, image], dim=1)
timestep = torch.tensor([t]).to(torch.float32)
# predict the noise residual
noise_pred = shark_run_wrapper(
self.unet,
latent_model_input.half(),
timestep,
text_embeddings.half(),
noise_level,
)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
# # call the callback, if provided
# if i == len(timesteps) - 1 or (
# (i + 1) > num_warmup_steps
# and (i + 1) % self.scheduler.order == 0
# ):
# progress_bar.update()
# if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
# 10. Post-processing
# make sure the VAE is in float32 mode, as it overflows in float16
# self.vae.to(dtype=torch.float32)
image = self.decode_latents(latents.float())
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@@ -4,15 +4,24 @@ p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Stable Diffusion Params
##############################################################################
p.add_argument(
"--prompts",
nargs="+",
default=["a photograph of an astronaut riding a horse"],
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
)
p.add_argument(
"--device", type=str, default="cpu", help="device to run the model."
"--negative-prompts",
nargs="+",
default=[""],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--steps",
type=int,
@@ -20,19 +29,13 @@ p.add_argument(
help="the no. of steps to do the sampling.",
)
p.add_argument(
"--version",
type=str,
default="v1.4",
help="Specify version of stable diffusion model",
)
p.add_argument(
"--seed",
type=int,
default=42,
help="the seed to use.",
)
p.add_argument(
"--guidance_scale",
type=float,
@@ -40,6 +43,18 @@ p.add_argument(
help="the value to be used for guidance scaling.",
)
##############################################################################
### Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
@@ -47,17 +62,6 @@ p.add_argument(
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)
p.add_argument(
"--precision", type=str, default="fp32", help="precision to run the model."
)
p.add_argument(
"--max_length",
type=int,
default=77,
help="max length of the tokenizer output.",
)
p.add_argument(
"--load_vmfb",
default=True,
@@ -72,6 +76,10 @@ p.add_argument(
help="saves the compiled flatbuffer to the local directory",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree-vulkan-target-triple",
type=str,
@@ -86,43 +94,18 @@ p.add_argument(
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
"--use_tuned",
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="4294967296",
default="4147483648",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
p.add_argument(
"--enable_rgp",
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
help="flag for disabling vulkan validation layers when benchmarking",
)
args = p.parse_args()

View File

@@ -0,0 +1,232 @@
import os
import torch
from shark.shark_inference import SharkInference
from upscaler_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), extended_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_model
from shark.parser import shark_args
# Set local shark_tank cache directory.
# shark_args.local_tank_cache = args.local_tank_cache
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
)
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into a shark_module.
def compile_through_fx(
model, inputs, model_name, is_f16=False, f16_input_mask=None, extra_args=[]
):
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def set_init_device_flags():
if "vulkan" in args.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "cpu" in args.device:
args.device = "cpu"
# set max_length based on availability.
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
args.max_length = 77
elif args.variant == "openjourney":
args.max_length = 64
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
if (
args.variant in ["openjourney", "dreamlike"]
or args.precision != "fp16"
or "vulkan" not in args.device
or "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
print("Tuned models are currently not supported for this setting.")
elif args.use_base_vae and args.variant != "stablediffusion":
args.use_tuned = False
print("Tuned models are currently not supported for this setting.")
if args.use_tuned:
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
# Utility to get list of devices available.
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{driver_name}://{i} => {device['name']}")
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
return available_devices

View File

@@ -1,7 +1,7 @@
import torch
from torch.nn.utils import _stateless
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_runner import SharkTrainer
from shark.shark_trainer import SharkTrainer
class MiniLMSequenceClassification(torch.nn.Module):
@@ -42,6 +42,7 @@ def forward(params, buffers, args):
return params, buffers
shark_module = SharkTrainer(mod, inp, custom_inference_fn=forward)
shark_module = SharkTrainer(mod, inp)
shark_module.compile(forward)
print(shark_module.forward())
print(shark_module.train())

View File

@@ -169,6 +169,7 @@ imagenet_style_templates_small = [
"a large painting in the style of {}",
]
# Setup the dataset
class TextualInversionDataset(Dataset):
def __init__(
@@ -184,7 +185,6 @@ class TextualInversionDataset(Dataset):
placeholder_token="*",
center_crop=False,
):
self.data_root = data_root
self.tokenizer = tokenizer
self.learnable_property = learnable_property
@@ -244,7 +244,10 @@ class TextualInversionDataset(Dataset):
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = (
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)

View File

@@ -21,7 +21,6 @@ import torch
from iree.runtime import DeviceArray
from torch_mlir._mlir_libs._mlir.ir import Module
from torch_mlir.compiler_utils import (
get_module_name_for_debug_dump,
run_pipeline_with_repro_report,
)
from torch_mlir.eager_mode.torch_mlir_eager_backend import (
@@ -64,14 +63,13 @@ class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
)
def compile(self, imported_module: Module):
fn_name = get_module_name_for_debug_dump(imported_module)
run_pipeline_with_repro_report(
imported_module,
"torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline",
"EagerMode",
)
callable, _ = get_iree_compiled_module(
imported_module, self.raw_device_str, func_name=fn_name
imported_module, self.raw_device_str
)
return callable

View File

@@ -33,61 +33,17 @@ def run_cmd(cmd):
)
result_str = result.stdout.decode()
return result_str
except Exception:
sys.exit("Exiting program due to error running:", cmd)
except subprocess.CalledProcessError as e:
print(e.output)
sys.exit(f"Exiting program due to error running {cmd}")
def iree_device_map(device):
from iree.runtime import get_driver, get_device
def get_all_devices(driver_name):
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list = []
for device_dict in device_list_src:
device_list.append(f"{driver_name}://{device_dict['path']}")
device_list.sort()
return device_list
# only supported for vulkan as of now
if "vulkan://" in device:
device_list = get_all_devices("vulkan")
_, d_index = device.split("://")
matched_index = None
match_with_index = False
if 0 <= len(d_index) <= 2:
try:
d_index = int(d_index)
except:
print(
f"{d_index} is not valid index or uri. Will choose device 0"
)
d_index = 0
match_with_index = True
if len(device_list) > 1:
print("List of available vulkan devices:")
for i, d in enumerate(device_list):
print(f"vulkan://{i} => {d}")
if (match_with_index and d_index == i) or (
not match_with_index and d == device
):
matched_index = i
print(
f"Choosing device vulkan://{matched_index}\nTo choose another device please specify device index or uri accordingly."
)
return get_device(device_list[matched_index])
elif len(device_list) == 1:
print(f"Found one vulkan device: {device_list[0]}. Using this.")
return get_device(device_list[0])
else:
print(
f"No device found! returning device corresponding to driver name: vulkan"
)
return _IREE_DEVICE_MAP["vulkan"]
uri_parts = device.split("://", 2)
if len(uri_parts) == 1:
return _IREE_DEVICE_MAP[uri_parts[0]]
else:
return _IREE_DEVICE_MAP[device]
return f"{_IREE_DEVICE_MAP[uri_parts[0]]}://{uri_parts[1]}"
def get_supported_device_list():
@@ -119,6 +75,7 @@ _IREE_TARGET_MAP = {
"intel-gpu": "opencl-spirv",
}
# Finds whether the required drivers are installed for the given device.
def check_device_drivers(device):
"""Checks necessary drivers present for gpu and vulkan devices"""

View File

@@ -18,6 +18,7 @@ from shark.iree_utils.cpu_utils import get_cpu_count
import numpy as np
import os
import re
import platform
UNIT_TO_SECOND_MAP = {"us": 1e-6, "ms": 0.001, "s": 1}
@@ -62,24 +63,33 @@ def build_benchmark_args(
Outputs: string that execute benchmark-module on target model.
"""
path = benchmark_module.__path__[0]
benchmarker_path = os.path.join(path, "..", "..", "iree-benchmark-module")
benchmark_cl = [benchmarker_path, f"--module_file={input_file}"]
if platform.system() == "Windows":
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module.exe"
)
time_extractor = None
else:
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module"
)
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
# TODO: The function named can be passed as one of the args.
fn_name = "forward"
if training == True:
# TODO: Replace name of train with actual train fn name.
fn_name = "train"
benchmark_cl.append(f"--entry_function={fn_name}")
benchmark_cl.append(f"--function={fn_name}")
benchmark_cl.append(f"--device={iree_device_map(device)}")
mlir_input_types = tensor_to_type_str(input_tensors, mlir_dialect)
for mlir_input in mlir_input_types:
benchmark_cl.append(f"--function_input={mlir_input}")
benchmark_cl.append(f"--input={mlir_input}")
if device == "cpu":
num_cpus = get_cpu_count()
if num_cpus is not None:
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl.append(time_extractor)
# if time_extractor:
# benchmark_cl.append(time_extractor)
return benchmark_cl
@@ -96,16 +106,24 @@ def build_benchmark_args_non_tensor_input(
Outputs: string that execute benchmark-module on target model.
"""
path = benchmark_module.__path__[0]
benchmarker_path = os.path.join(path, "..", "..", "iree-benchmark-module")
benchmark_cl = [benchmarker_path, f"--module_file={input_file}"]
if platform.system() == "Windows":
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module.exe"
)
else:
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module"
)
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
# TODO: The function named can be passed as one of the args.
if function_name:
benchmark_cl.append(f"--entry_function={function_name}")
benchmark_cl.append(f"--function={function_name}")
benchmark_cl.append(f"--device={iree_device_map(device)}")
for input in inputs:
benchmark_cl.append(f"--function_input={input}")
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl.append(time_extractor)
benchmark_cl.append(f"--input={input}")
if platform.system() != "Windows":
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl.append(time_extractor)
return benchmark_cl
@@ -121,8 +139,9 @@ def run_benchmark_module(benchmark_cl):
benchmark_path
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
bench_result = run_cmd(" ".join(benchmark_cl))
regex_split = re.compile("([0-9]+[.]*[0-9]*)([a-zA-Z]+)")
match = regex_split.match(bench_result)
print(bench_result)
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
unit = match.group(2)
return 1.0 / (time * UNIT_TO_SECOND_MAP[unit])
unit = match.group(3)
return 1.0 / (time * 0.001)

View File

@@ -20,23 +20,30 @@ import numpy as np
import os
import re
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
if "://" in device:
device = device.split("://")[0]
if device == "cpu":
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
print(
f"Specific device selection only supported for vulkan now."
f"Proceeding with {device} as device."
)
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
return get_iree_cpu_args()
if device == "cuda":
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device in ["metal", "vulkan"]:
if device_uri[0] in ["metal", "vulkan"]:
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(extra_args=extra_args)
if device == "rocm":
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args()
@@ -63,6 +70,7 @@ def get_iree_common_args():
return [
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
]
@@ -73,7 +81,17 @@ def get_iree_common_args():
def get_model_specific_args():
ms_args = []
if shark_args.enable_conv_transform == True:
ms_args += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"]
ms_args += [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))"
]
if shark_args.enable_img2col_transform == True:
ms_args += [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col))"
]
if shark_args.use_winograd == True:
ms_args += [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-linalg-ext-convert-conv2d-to-winograd))"
]
return ms_args
@@ -136,7 +154,6 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
in_dispatches = True
if all_dispatches or in_dispatches:
for f_ in os.listdir(f"{bench_dir}/{d_}"):
if "benchmark.mlir" in f_:
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
module = dispatch_file.read()
@@ -227,7 +244,6 @@ def compile_module_to_flatbuffer(
module,
device,
frontend,
func_name,
model_config_path,
extra_args,
model_name="None",
@@ -270,15 +286,25 @@ def compile_module_to_flatbuffer(
return flatbuffer_blob
def get_iree_module(flatbuffer_blob, device, func_name):
def get_iree_module(flatbuffer_blob, device, device_idx=None):
# Returns the compiled module and the configs.
config = get_iree_runtime_config(device)
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"]
)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
ModuleCompiled = ctx.modules.module[func_name]
ModuleCompiled = ctx.modules.module
return ModuleCompiled, config
@@ -286,25 +312,22 @@ def get_iree_compiled_module(
module,
device: str,
frontend: str = "torch",
func_name: str = "forward",
model_config_path: str = None,
extra_args: list = [],
device_idx: int = None,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, func_name, model_config_path, extra_args
module, device, frontend, model_config_path, extra_args
)
return get_iree_module(flatbuffer_blob, device, func_name)
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
def load_flatbuffer(
flatbuffer_path: str, device: str, func_name: str = "forward"
):
def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
return get_iree_module(flatbuffer_blob, device, func_name)
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
def export_iree_module_to_vmfb(
@@ -312,20 +335,19 @@ def export_iree_module_to_vmfb(
device: str,
directory: str,
mlir_dialect: str = "linalg",
func_name: str = "forward",
model_config_path: str = None,
module_name: str = None,
extra_args: list = [],
):
# Compiles the module given specs and saves it as .vmfb file.
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, mlir_dialect, func_name, model_config_path, extra_args
module, device, mlir_dialect, model_config_path, extra_args
)
if module_name is None:
device_name = (
device if "://" not in device else "-".join(device.split("://"))
)
module_name = f"{mlir_dialect}_{func_name}_{device_name}"
module_name = f"{mlir_dialect}_{device_name}"
filename = os.path.join(directory, module_name + ".vmfb")
print(f"Saved vmfb in {filename}.")
with open(filename, "wb") as f:
@@ -347,28 +369,39 @@ def export_module_to_mlir_file(module, frontend, directory: str):
return filename
def get_results(compiled_vm, input, config, frontend="torch"):
def get_results(
compiled_vm,
function_name,
input,
config,
frontend="torch",
send_to_host=True,
):
"""Runs a .vmfb file given inputs and config and returns output."""
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm(*device_inputs)
result = compiled_vm[function_name](*device_inputs)
result_tensors = []
if isinstance(result, tuple):
for val in result:
result_tensors.append(np.copy(np.asarray(val, val.dtype)))
if send_to_host:
for val in result:
result_tensors.append(np.asarray(val, val.dtype))
else:
for val in result:
result_tensors.append(val)
return result_tensors
elif isinstance(result, dict):
data = list(result.items())
res = np.array(data, dtype=object)
return np.copy(res)
if send_to_host:
res = np.array(data, dtype=object)
return np.copy(res)
return data
else:
return result.to_host()
if send_to_host and result is not None:
return result.to_host()
return result
def get_iree_runtime_config(device):
device = iree_device_map(device)
if type(device) == ireert.HalDevice:
config = ireert.Config(device=device)
else:
driver_name = device.split("://")[0] if "://" in device else device
config = ireert.Config(driver_name=driver_name)
config = ireert.Config(device=ireert.get_device(device))
return config

View File

@@ -15,6 +15,7 @@
# All the iree_cpu related functionalities go here.
import subprocess
import platform
def get_cpu_count():
@@ -29,25 +30,16 @@ def get_cpu_count():
# Get the default cpu args.
def get_iree_cpu_args():
find_triple_cmd = "uname -s -m"
os_name, proc_name = (
subprocess.run(
find_triple_cmd, shell=True, stdout=subprocess.PIPE, check=True
)
.stdout.decode("utf-8")
.split()
)
uname = platform.uname()
os_name, proc_name = uname.system, uname.machine
if os_name == "Darwin":
find_kernel_version_cmd = "uname -r"
kernel_version = subprocess.run(
find_kernel_version_cmd,
shell=True,
stdout=subprocess.PIPE,
check=True,
).stdout.decode("utf-8")
kernel_version = uname.release
target_triple = f"{proc_name}-apple-darwin{kernel_version}"
elif os_name == "Linux":
target_triple = f"{proc_name}-linux-gnu"
elif os_name == "Windows":
target_triple = "x86_64-pc-windows-msvc"
else:
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)

View File

@@ -18,14 +18,16 @@ import iree.runtime as ireert
import ctypes
from shark.parser import shark_args
# Get the default gpu args given the architecture.
def get_iree_gpu_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
ireert.flags.parse_flags("--cuda_allow_inline_execution")
ireert.flags.parse_flags("--cuda_allow_inline_execution", "--device_allocator=caching")
# TODO: Give the user_interface to pass the sm_arch.
sm_arch = get_cuda_sm_cc()
if (
sm_arch in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86"]
sm_arch
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
) and (shark_args.enable_tf32 == True):
return [
"--iree-hal-cuda-disable-loop-nounroll-wa",
@@ -38,8 +40,17 @@ def get_iree_gpu_args():
# Get the default gpu args given the architecture.
def get_iree_rocm_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
# TODO: find a way to get arch from code.
rocm_arch = "gfx908"
# get arch from rocminfo.
import re
import subprocess
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"rocminfo | grep -i 'gfx'", shell=True, text=True
),
).group(1)
print(f"Found rocm arch {rocm_arch}...")
return [
f"--iree-rocm-target-chip={rocm_arch}",
"--iree-rocm-link-bc=true",
@@ -56,7 +67,7 @@ CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
def get_cuda_sm_cc():
libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
for libname in libnames:
try:
cuda = ctypes.CDLL(libname)

View File

@@ -0,0 +1,462 @@
# Copyright 2020 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
def get_vulkan_target_env(vulkan_target_triple):
arch, product, os = vulkan_target_triple.split("=")[1].split("-")
triple = (arch, product, os)
# get version
version = get_version(triple=triple)
# TODO get revision
revision = 120
# extensions
extensions = get_extensions(triple)
# get vendor
vendor = get_vendor(triple)
# get device type
device_type = get_device_type(triple)
# get capabilities
capabilities = get_vulkan_target_capabilities(triple)
target_env = f"#vk.target_env<{version}, r({revision}), {extensions}, {vendor}:{device_type}, #vk.caps< {capabilities} >>"
return target_env
def get_vulkan_target_env_flag(vulkan_target_triple):
target_env = get_vulkan_target_env(vulkan_target_triple)
target_env_flag = f"--iree-vulkan-target-env={target_env}"
return target_env_flag
def get_version(triple):
arch, product, os = triple
if os in ["android30", "android31"]:
return "v1.1"
if product in ["android30", "android31"]:
return "v1.1"
if arch in ["unknown"]:
return "v1.1"
return "v1.3"
def get_extensions(triple):
def make_ext_list(ext_list):
res = ""
for e in ext_list:
res += e + ", "
res = f"[{res[:-2]}]"
return res
arch, product, os = triple
if arch == "m1":
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_8bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
]
return make_ext_list(ext_list=ext)
if arch == "valhall":
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_8bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_spirv_1_4",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
]
return make_ext_list(ext_list=ext)
if arch == "adreno":
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_spirv_1_4",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
]
if os == "android31":
ext.append("VK_KHR_8bit_storage")
return make_ext_list(ext_list=ext)
if get_vendor(triple) == "SwiftShader":
ext = ["VK_KHR_storage_buffer_storage_class"]
return make_ext_list(ext_list=ext)
if arch == "unknown":
ext = [
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
]
return make_ext_list(ext_list=ext)
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_8bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_spirv_1_4",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
"VK_EXT_subgroup_size_control",
]
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
ext.append("VK_NV_cooperative_matrix")
return make_ext_list(ext_list=ext)
def get_vendor(triple):
arch, product, os = triple
if arch == "unknown":
return "Unknown"
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn4", "rgcn5"]:
return "AMD"
if arch == "valhall":
return "ARM"
if arch == "m1":
return "Apple"
if arch in ["turing", "ampere"]:
return "NVIDIA"
if arch == "ardeno":
return "Qualcomm"
if arch == "cpu":
if product == "swiftshader":
return "SwiftShader"
return "Unknown"
print(f"Vendor for target triple - {triple} not found. Using unknown")
return "Unknown"
def get_device_type(triple):
arch, product, _ = triple
if arch == "unknown":
return "Unknown"
if arch == "cpu":
return "CPU"
if arch in ["turing", "ampere"]:
return "DiscreteGPU"
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]:
if product == "ivega10":
return "IntegratedGPU"
return "DiscreteGPU"
if arch in ["m1", "valhall", "adreno"]:
return "IntegratedGPU"
print(f"Device type for target triple - {triple} not found. Using unknown")
return "Unknown"
# get all the capabilities for the device
# TODO: make a dataclass for capabilites and init using vulkaninfo
def get_vulkan_target_capabilities(triple):
def get_subgroup_val(l):
return int(sum([subgroup_feature[sgf] for sgf in l]))
cap = OrderedDict()
arch, product, os = triple
subgroup_feature = {
"Basic": 1,
"Vote": 2,
"Arithmetic": 4,
"Ballot": 8,
"Shuffle": 16,
"ShuffleRelative": 32,
"Clustered": 64,
"Quad": 128,
"PartitionedNV": 256,
}
cap["maxComputeSharedMemorySize"] = 16384
cap["maxComputeWorkGroupInvocations"] = 128
cap["maxComputeWorkGroupSize"] = [128, 128, 64]
cap["subgroupSize"] = 32
cap["subgroupFeatures"] = ["Basic"]
cap["minSubgroupSize"] = None
cap["maxSubgroupSize"] = None
cap["shaderFloat16"] = False
cap["shaderFloat64"] = False
cap["shaderInt8"] = False
cap["shaderInt16"] = False
cap["shaderInt64"] = False
cap["storageBuffer16BitAccess"] = False
cap["storagePushConstant16"] = False
cap["uniformAndStorageBuffer16BitAccess"] = False
cap["storageBuffer8BitAccess"] = False
cap["storagePushConstant8"] = False
cap["uniformAndStorageBuffer8BitAccess"] = False
cap["variablePointers"] = False
cap["variablePointersStorageBuffer"] = False
cap["coopmatCases"] = None
if arch in ["rdna1", "rdna2", "rdna3"]:
cap["maxComputeSharedMemorySize"] = 65536
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["subgroupSize"] = 64
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 64
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
if arch == "rdna3":
# TODO: Get scope value
cap["coopmatCases"] = [
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
]
if product == "rx5700xt":
cap["storagePushConstant16"] = False
cap["storagePushConstant8"] = False
elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
cap["maxComputeSharedMemorySize"] = 65536
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["subgroupSize"] = 64
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["minSubgroupSize"] = 64
cap["maxSubgroupSize"] = 64
if arch == "rgcn5":
cap["shaderFloat16"] = True
cap["shaderFloat64"] = True
cap["storageBuffer16BitAccess"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["storagePushConstant16"] = False
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = False
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "m1":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["subgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "valhall":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 512
cap["maxComputeWorkGroupSize"] = [512, 512, 512]
cap["subgroupSize"] = 16
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Clustered",
"Quad",
]
if os == "android31":
cap["subgroupFeatures"].append("Shuffle")
cap["subgroupFeatures"].append("ShuffleRelative")
cap["shaderFloat16"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "cpu":
if product == "swiftshader":
cap["maxComputeSharedMemorySize"] = 16384
cap["subgroupSize"] = 4
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
]
elif arch in ["ampere", "turing"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["subgroupSize"] = 32
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
cap["coopmatCases"] = [
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
]
elif arch == "adreno":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
cap["subgroupSize"] = 64
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["storageBuffer16BitAccess"] = True
if os == "andorid31":
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "unknown":
cap["subgroupSize"] = 64
cap["variablePointers"] = False
cap["variablePointersStorageBuffer"] = False
else:
print(
f"Architecture {arch} not matched. Using default vulkan target device capability"
)
def get_comma_sep_str(ele_list):
l = ""
for ele in ele_list:
l += f"{ele}, "
l = f"[{l[:-2]}]"
return l
res = ""
for k, v in cap.items():
if v is None or v == False:
continue
if isinstance(v, bool):
res += f"{k} = {'unit' if v == True else None}, "
elif isinstance(v, list):
if k == "subgroupFeatures":
res += f"subgroupFeatures = {get_subgroup_val(v)}: i32, "
elif k == "maxComputeWorkGroupSize":
res += f"maxComputeWorkGroupSize = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
elif k == "coopmatCases":
cmc = ""
for case in v:
cmc += f"#vk.coop_matrix_props<{case}>, "
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
else:
res += f"{k} = {get_comma_sep_str(v)}, "
else:
res += f"{k} = {v}, "
res = res[:-2]
return res

View File

@@ -18,6 +18,7 @@ from os import linesep
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_vulkan_device_name():
@@ -26,9 +27,10 @@ def get_vulkan_device_name():
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print(
f"Found {len(vulkaninfo_list)} device names. choosing first one: {vulkaninfo_list[0]}"
)
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing first one: {vulkaninfo_list[0]}")
return vulkaninfo_list[0]
@@ -44,56 +46,115 @@ def get_os_name():
return "linux"
def get_vulkan_triple_flag(extra_args=[]):
if "-iree-vulkan-target-triple=" in " ".join(extra_args):
print(f"Using target triple from command line args")
return None
def get_vulkan_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.
Args:
device_name (str): name of the hardware device to be used with vulkan
Returns:
str or None: target triple or None if no match found for given name
"""
system_os = get_os_name()
vulkan_device = get_vulkan_device_name()
if all(x in vulkan_device for x in ("Apple", "M1")):
print(f"Found {vulkan_device} Device. Using m1-moltenvk-macos")
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
elif all(x in vulkan_device for x in ("Apple", "M2")):
print("Found Apple M2 Device. Using m1-moltenvk-macos")
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
elif all(x in vulkan_device for x in ("A100", "SXM4")):
print(
f"Found {vulkan_device} Device. Using ampere-rtx3080-{system_os}"
)
return f"-iree-vulkan-target-triple=ampere-rtx3080-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "3090")):
print(
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
)
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "4090")):
print(
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
)
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
elif all(x in vulkan_device for x in ("AMD", "7900")):
print(f"Found {vulkan_device} Device. Using rdna3-7900-{system_os}")
return f"-iree-vulkan-target-triple=rdna3-7900-{system_os}"
elif any(x in vulkan_device for x in ("AMD", "Radeon")):
print(f"Found AMD device. Using rdna2-unknown-{system_os}")
return f"-iree-vulkan-target-triple=rdna2-unknown-{system_os}"
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"
# Nvidia Targets
elif all(x in device_name for x in ("RTX", "2080")):
triple = f"turing-rtx2080-{system_os}"
elif all(x in device_name for x in ("A100", "SXM4")):
triple = f"ampere-a100-{system_os}"
elif all(x in device_name for x in ("RTX", "3090")):
triple = f"ampere-rtx3090-{system_os}"
elif all(x in device_name for x in ("RTX", "3080")):
triple = f"ampere-rtx3080-{system_os}"
elif all(x in device_name for x in ("RTX", "3070")):
triple = f"ampere-rtx3070-{system_os}"
elif all(x in device_name for x in ("RTX", "3060")):
triple = f"ampere-rtx3060-{system_os}"
elif all(x in device_name for x in ("RTX", "3050")):
triple = f"ampere-rtx3050-{system_os}"
# We use ampere until lovelace target triples are plumbed in.
elif all(x in device_name for x in ("RTX", "4090")):
triple = f"ampere-rtx4090-{system_os}"
elif all(x in device_name for x in ("RTX", "4080")):
triple = f"ampere-rtx4080-{system_os}"
elif all(x in device_name for x in ("RTX", "4070")):
triple = f"ampere-rtx4070-{system_os}"
elif all(x in device_name for x in ("RTX", "4000")):
triple = f"turing-rtx4000-{system_os}"
elif all(x in device_name for x in ("RTX", "5000")):
triple = f"turing-rtx5000-{system_os}"
elif all(x in device_name for x in ("RTX", "6000")):
triple = f"turing-rtx6000-{system_os}"
elif all(x in device_name for x in ("RTX", "8000")):
triple = f"turing-rtx8000-{system_os}"
elif all(x in device_name for x in ("TITAN", "RTX")):
triple = f"turing-titanrtx-{system_os}"
elif all(x in device_name for x in ("GTX", "1060")):
triple = f"pascal-gtx1060-{system_os}"
elif all(x in device_name for x in ("GTX", "1070")):
triple = f"pascal-gtx1070-{system_os}"
elif all(x in device_name for x in ("GTX", "1080")):
triple = f"pascal-gtx1080-{system_os}"
# Amd Targets
# Linux: Radeon RX 7900 XTX
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna2-unknown-{system_os}"
else:
triple = None
return triple
def get_vulkan_triple_flag(device_name="", extra_args=[]):
for flag in extra_args:
if "-iree-vulkan-target-triple=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
if device_name == "" or device_name == [] or device_name is None:
vulkan_device = get_vulkan_device_name()
else:
vulkan_device = device_name
triple = get_vulkan_target_triple(vulkan_device)
if triple is not None:
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
f"Found vulkan device {vulkan_device}. Using target triple {triple}"
)
print(f"Target : {vulkan_device}")
return None
return f"-iree-vulkan-target-triple={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
)
print(f"Target : {vulkan_device}")
return None
def get_iree_vulkan_args(extra_args=[]):
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
vulkan_flag = []
vulkan_triple_flag = get_vulkan_triple_flag(extra_args)
res_vulkan_flag = ["--device_allocator=caching"]
vulkan_triple_flag = None
for arg in extra_args:
if "-iree-vulkan-target-triple=" in arg:
print(f"Using target triple {arg} from command line args")
vulkan_triple_flag = arg
break
if vulkan_triple_flag is None:
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
if vulkan_triple_flag is not None:
vulkan_flag.append(vulkan_triple_flag)
return vulkan_flag
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
res_vulkan_flag.append(vulkan_target_env)
return res_vulkan_flag
def set_iree_vulkan_runtime_flags(flags):

View File

@@ -22,7 +22,7 @@ from shark.model_annotation import model_annotation
with create_context() as ctx:
module = model_annotation(ctx, input_contents=..., config_path=..., search_op=...)
2. Run model_annotation.py directly
python model_annotation.py path_to_original_mlir path_to_config_file
python model_annotation.py -model path_to_original_mlir -config_path path_to_config_file
"""
import json
@@ -39,21 +39,27 @@ def model_annotation(
*,
input_contents: str,
config_path: str,
search_op: str = "matmul",
search_op: str,
winograd: bool = False,
):
if os.path.isfile(input_contents):
with open(input_contents, "rb") as f:
input_contents = f.read()
module = ir.Module.parse(input_contents)
with open(config_path, "r") as f:
data = json.load(f)
configs = data["options"]
if config_path == "":
return module
if winograd:
with open(config_path, "r") as f:
data = json.load(f)
configs = data["c,f"]
else:
configs = load_model_configs(config_path)
# The Python API does not expose a general walk() function, so we just
# do it ourselves.
walk_children(module.operation, configs, 0, search_op)
walk_children(module.operation, configs, search_op, winograd)
if not module.operation.verify():
raise RuntimeError("Modified program does not verify!")
@@ -61,8 +67,42 @@ def model_annotation(
return module
def load_model_configs(config_path: str):
config = {}
with open(config_path, "r") as f:
for line in f:
data = json.loads(line)
if "identifier" not in data.keys():
continue
if data["identifier"] == "matmul":
matrix_size = [data["m"], data["n"], data["k"]]
elif data["identifier"] == "bmm":
matrix_size = [data["b"], data["m"], data["n"], data["k"]]
elif data["identifier"] == "generic":
matrix_size = [1, data["b"], data["m"], data["n"], data["k"]]
elif data["identifier"] == "conv":
matrix_size = [
data["n"],
data["ih"],
data["iw"],
data["c"],
data["kh"],
data["kw"],
data["f"],
data["oh"],
data["ow"],
data["d"],
data["s"],
data["p"],
]
config[shape_list_to_string(matrix_size)] = data
f.close()
return config
def walk_children(
op: ir.Operation, configs: List[Dict], idx: int, search_op: str
op: ir.Operation, configs: List[Dict], search_op: str, winograd: bool
):
if search_op == "matmul":
op_names = ["linalg.matmul", "mhlo.dot"]
@@ -70,6 +110,8 @@ def walk_children(
op_names = ["linalg.batch_matmul", "mhlo.dot_general"]
elif search_op == "conv":
op_names = ["mhlo.convolution", "linalg.conv_2d_nhwc_hwcf"]
elif search_op == "generic":
op_names = ["linalg.generic"]
elif search_op == "all":
op_names = [
"mhlo.dot",
@@ -78,6 +120,7 @@ def walk_children(
"linalg.matmul",
"linalg.batch_matmul",
"linalg.conv_2d_nhwc_hwcf",
"linalg.generic",
]
else:
raise ValueError(f"{search_op} op is not tunable.")
@@ -89,37 +132,171 @@ def walk_children(
# 'operation' and 'name' attributes.
if isinstance(child_op, ir.OpView):
child_op = child_op.operation
if child_op.name in op_names and idx < len(configs):
add_attributes(child_op, configs[idx])
idx = idx + 1
print(f"Updated op {child_op}", file=sys.stderr)
walk_children(child_op, configs, idx, search_op)
if winograd and child_op.name in [
"linalg.conv_2d_nchw_fchw",
"linalg.conv_2d_nhwc_hwcf",
]:
add_winograd_attribute(child_op, configs)
if child_op.name in op_names:
if child_op.name == "linalg.generic":
# This is for generic op that has contractionOpInterface
# which is basically einsum("mk,bkn->bmn")
op_result = str(child_op.results[0])
op_iterator = str(
child_op.attributes["iterator_types"]
)
if len(child_op.operands) != 3:
continue
if "reduction" not in op_iterator:
continue
if (
"arith.addf" not in op_result
or "arith.mulf" not in op_result
):
continue
if "arith.subf" in op_result:
continue
child_op_shape = get_op_shape(child_op, search_op)
if (
child_op_shape in configs.keys()
and configs[child_op_shape]["options"][0] != None
):
add_attributes(
child_op, configs[child_op_shape]["options"][0]
)
walk_children(child_op, configs, search_op, winograd)
def add_attributes(op: ir.Operation, config: Dict):
(
tile_sizes,
pipeline,
workgroup_size,
split_k,
pipeline_depth,
) = parse_config(config)
def get_op_shape(op: ir.Operation, search_op: str):
shape_list = []
if search_op in ["generic", "all"]:
if op.name in ["linalg.generic"]:
input1 = str(op.operands[0].type)
input2 = str(op.operands[1].type)
m = input1.split("tensor<")[1].split("x")[0]
b = input2.split("tensor<")[1].split("x")[0]
k = input2.split("tensor<")[1].split("x")[1]
n = input2.split("tensor<")[1].split("x")[2]
shape_list = [1, int(b), int(m), int(n), int(k)]
add_compilation_info(
op,
tile_sizes=tile_sizes,
pipeline=pipeline,
workgroup_size=workgroup_size,
pipeline_depth=pipeline_depth,
)
if search_op in ["matmul", "all"]:
if op.name in ["mhlo.dot"]:
op_result = str(op.results[0])
m = op_result.split("tensor<")[1].split("x")[0]
k = op_result.split("tensor<")[1].split("x")[1]
n = op_result.split("tensor<")[2].split("x")[1]
shape_list = [int(m), int(n), int(k)]
elif op.name in ["linalg.matmul"]:
op_result = str(op.results[0]).split("ins(")[1]
m = op_result.split("tensor<")[1].split("x")[0]
k = op_result.split("tensor<")[1].split("x")[1]
n = op_result.split("tensor<")[2].split("x")[1]
shape_list = [int(m), int(n), int(k)]
if split_k:
add_attribute_by_name(op, "iree_flow_split_k", split_k)
if search_op in ["bmm", "all"]:
if op.name in ["mhlo.dot_general"]:
op_result = str(op.results[0])
b = op_result.split("tensor<")[1].split("x")[1]
m = op_result.split("tensor<")[1].split("x")[2]
k = op_result.split("tensor<")[1].split("x")[3]
n = op_result.split("tensor<")[3].split("x")[3]
shape_list = [int(b), int(m), int(n), int(k)]
elif op.name in ["linalg.batch_matmul"]:
op_result = str(op.results[0]).split("ins(")[1]
b = op_result.split("tensor<")[1].split("x")[0]
m = op_result.split("tensor<")[1].split("x")[1]
k = op_result.split("tensor<")[1].split("x")[2]
n = op_result.split("tensor<")[3].split("x")[2]
shape_list = [int(b), int(m), int(n), int(k)]
if search_op in ["conv", "all"]:
if op.name in ["mhlo.convolution"]:
op_result = str(op.results[0])
dilation = (
str(op.attributes["rhs_dilation"])
.split("dense<")[1]
.split(">")[0]
)
stride = (
str(op.attributes["window_strides"])
.split("dense<")[1]
.split(">")[0]
)
pad = (
str(op.attributes["padding"]).split("dense<")[1].split(">")[0]
)
n = op_result.split("tensor<")[1].split("x")[0]
ih = op_result.split("tensor<")[1].split("x")[1]
iw = op_result.split("tensor<")[1].split("x")[2]
c = op_result.split("tensor<")[1].split("x")[3]
kh = op_result.split("tensor<")[2].split("x")[0]
kw = op_result.split("tensor<")[2].split("x")[1]
f = op_result.split("tensor<")[2].split("x")[3]
oh = op_result.split("tensor<")[3].split("x")[1]
ow = op_result.split("tensor<")[3].split("x")[2]
shape_list = [
int(n),
int(ih),
int(iw),
int(c),
int(kh),
int(kw),
int(f),
int(oh),
int(ow),
int(dilation),
int(stride),
int(pad),
]
elif op.name in ["linalg.conv_2d_nhwc_hwcf"]:
op_result = str(op.results[0]).split("ins(")[1]
dilation = (
str(op.attributes["dilations"])
.split("dense<")[1]
.split(">")[0]
)
stride = (
str(op.attributes["strides"]).split("dense<")[1].split(">")[0]
)
pad = 0
n = op_result.split("tensor<")[1].split("x")[0]
ih = op_result.split("tensor<")[1].split("x")[1]
iw = op_result.split("tensor<")[1].split("x")[2]
c = op_result.split("tensor<")[1].split("x")[3]
kh = op_result.split("tensor<")[2].split("x")[0]
kw = op_result.split("tensor<")[2].split("x")[1]
f = op_result.split("tensor<")[2].split("x")[3]
oh = op_result.split("tensor<")[3].split("x")[1]
ow = op_result.split("tensor<")[3].split("x")[2]
shape_list = [
int(n),
int(ih),
int(iw),
int(c),
int(kh),
int(kw),
int(f),
int(oh),
int(ow),
int(dilation),
int(stride),
int(pad),
]
shape_str = shape_list_to_string(shape_list)
return shape_str
def parse_config(config: Dict):
def add_attributes(op: ir.Operation, config: List[Dict]):
# Parse the config file
split_k = None
pipeline_depth = None
store_stage = None
subgroup_size = None
if "GPU" in config["pipeline"]:
pipeline = (
"LLVMGPUMatmulSimt"
@@ -139,11 +316,17 @@ def parse_config(config: Dict):
config["parallel_tile_sizes"],
config["reduction_tile_sizes"],
]
workgroup_size = config["work_group_sizes"]
if "vector_tile_sizes" in config.keys():
tile_sizes += [config["vector_tile_sizes"]]
if "window_tile_sizes" in config.keys():
tile_sizes += [config["window_tile_sizes"]]
workgroup_size = config["work_group_sizes"]
if "subgroup_size" in config.keys():
subgroup_size = config["subgroup_size"]
if "pipeline_depth" in config.keys():
pipeline_depth = config["pipeline_depth"]
if "store_stage" in config.keys():
store_stage = config["store_stage"]
else:
# For IREE CPU pipelines
pipeline = config["pipeline"]
@@ -153,40 +336,77 @@ def parse_config(config: Dict):
config["reduction_tile_sizes"],
]
workgroup_size = []
return tile_sizes, pipeline, workgroup_size, split_k, pipeline_depth
def add_compilation_info(
op: ir.Operation,
tile_sizes: List[List[int]],
pipeline: str,
workgroup_size: List[int],
pipeline_depth: int,
):
# We don't have a Python binding for CompilationInfo, so we just parse
# its string form.
if pipeline_depth:
attr = ir.Attribute.parse(
f"#iree_codegen.compilation_info<"
f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
f"translation_info = <{pipeline} pipeline_depth = {pipeline_depth}>, "
f"workgroup_size = {repr(workgroup_size)}>"
)
# Add compilation info as an attribute. We don't have a Python binding for CompilationInfo,
# so we just parse its string form.
if pipeline_depth != None:
translation_info = f"{pipeline} pipeline_depth = {pipeline_depth}"
if store_stage != None:
translation_info += f" store_stage = {store_stage}"
else:
attr = ir.Attribute.parse(
f"#iree_codegen.compilation_info<"
f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
f"translation_info = <{pipeline}>, "
f"workgroup_size = {repr(workgroup_size)}>"
)
translation_info = f"{pipeline}"
compilation_info = (
f"#iree_codegen.compilation_info<"
f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
f"translation_info = <{translation_info}>, "
f"workgroup_size = {repr(workgroup_size)} "
)
if subgroup_size != None:
compilation_info += f", subgroup_size = {subgroup_size}>"
else:
compilation_info += ">"
attr = ir.Attribute.parse(compilation_info)
op.attributes["compilation_info"] = attr
# Add other attributes if required.
if split_k:
add_attribute_by_name(op, "iree_flow_split_k", split_k)
def add_winograd_attribute(op: ir.Operation, config: List):
op_result = str(op.results[0]).split("ins(")[1]
dilation = int(
str(op.attributes["dilations"]).split("dense<")[1].split(">")[0]
)
stride = int(
str(op.attributes["strides"]).split("dense<")[1].split(">")[0]
)
if op.name == "linalg.conv_2d_nchw_fchw":
f = int(op_result.split("tensor<")[2].split("x")[0])
c = int(op_result.split("tensor<")[2].split("x")[1])
kh = int(op_result.split("tensor<")[2].split("x")[2])
kw = int(op_result.split("tensor<")[2].split("x")[3])
else:
kh = int(op_result.split("tensor<")[2].split("x")[0])
kw = int(op_result.split("tensor<")[2].split("x")[1])
c = int(op_result.split("tensor<")[2].split("x")[2])
f = int(op_result.split("tensor<")[2].split("x")[3])
if (
dilation == 1
and stride == 1
and kh == 3
and kw == 3
and [c, f] in config
):
op.attributes["iree_winograd_conv"] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(64), 1
)
def add_attribute_by_name(op: ir.Operation, name: str, val: int):
attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
op.attributes[name] = attr
def shape_list_to_string(input):
return "x".join([str(d) for d in input])
def create_context() -> ir.Context:
context = ir.Context()
ireec_trans.register_all_dialects(context)
@@ -195,15 +415,48 @@ def create_context() -> ir.Context:
if __name__ == "__main__":
import argparse
from pathlib import Path
def path_expand(s):
return Path(s).expanduser().resolve()
parser = argparse.ArgumentParser()
parser.add_argument(
"-model",
type=path_expand,
default="model.mlir",
help="Path to the input mlir file",
)
parser.add_argument(
"-config_path",
type=path_expand,
default="best_configs.json",
help="Path where stores the op config file",
)
parser.add_argument(
"-output_path",
type=path_expand,
default="tuned_model.mlir",
help="Path to save the annotated mlir file",
)
parser.add_argument(
"-search_op",
type=str,
default="all",
help="Op to be optimized. options are matmul, bmm, conv.",
)
args = parser.parse_args()
with create_context() as ctx:
module = model_annotation(
ctx,
input_contents=sys.argv[1],
config_path=sys.argv[2],
search_op="all",
input_contents=args.model,
config_path=args.config_path,
search_op=args.search_op,
)
mlir_str = str(module)
filename = "tuned_model.mlir"
with open(filename, "w") as f:
with open(args.output_path, "w") as f:
f.write(mlir_str)
print(f"Saved mlir in {filename}.")
print(f"Saved mlir in {args.output_path}.")

View File

@@ -15,24 +15,6 @@
import argparse
import os
def dir_path(path):
if os.path.isdir(path):
return path
else:
os.mkdir(path)
return path
def dir_file(path):
if os.path.isfile(path):
return path
else:
raise argparse.ArgumentTypeError(
f"readable_file:{path} is not a valid file"
)
parser = argparse.ArgumentParser(description="SHARK runner.")
parser.add_argument(
"--device",
@@ -40,12 +22,6 @@ parser.add_argument(
default="cpu",
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
)
parser.add_argument(
"--repro_dir",
help="Directory to which module files will be saved for reproduction or debugging.",
type=dir_path,
default="./shark_tmp",
)
parser.add_argument(
"--enable_tf32",
type=bool,
@@ -83,13 +59,19 @@ parser.add_argument(
)
parser.add_argument(
"--update_tank",
default=False,
default=True,
action="store_true",
help="When enabled, SHARK downloader will update local shark_tank if local hash is different from latest upstream hash.",
)
parser.add_argument(
"--force_update_tank",
default=False,
action="store_true",
help="When enabled, SHARK downloader will force an update of local shark_tank artifacts for each request.",
)
parser.add_argument(
"--local_tank_cache",
default="",
default=None,
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
@@ -108,8 +90,22 @@ parser.add_argument(
parser.add_argument(
"--enable_conv_transform",
default=False,
action="store",
action="store_true",
help="Enables the --iree-flow-enable-conv-nchw-to-nhwc-transform flag.",
)
parser.add_argument(
"--enable_img2col_transform",
default=False,
action="store_true",
help="Enables the --iree-flow-enable-conv-img2col-transform flag.",
)
parser.add_argument(
"--use_winograd",
default=False,
action="store_true",
help="Enables the --iree-flow-enable-conv-winograd-transform flag.",
)
shark_args, unknown = parser.parse_known_args()

View File

@@ -60,12 +60,12 @@ class SharkBenchmarkRunner(SharkRunner):
def __init__(
self,
mlir_module: bytes,
function_name: str = "forward",
device: str = "none",
mlir_dialect: str = "linalg",
extra_args: list = [],
):
self.device = shark_args.device if device == "none" else device
self.enable_tf32 = shark_args.enable_tf32
self.frontend_model = None
self.vmfb_file = None
self.mlir_dialect = mlir_dialect
@@ -73,7 +73,6 @@ class SharkBenchmarkRunner(SharkRunner):
SharkRunner.__init__(
self,
mlir_module,
function_name,
device,
self.mlir_dialect,
self.extra_args,
@@ -83,9 +82,8 @@ class SharkBenchmarkRunner(SharkRunner):
self.vmfb_file = export_iree_module_to_vmfb(
mlir_module,
device,
shark_args.repro_dir,
".",
self.mlir_dialect,
function_name,
extra_args=self.extra_args,
)
@@ -100,15 +98,19 @@ class SharkBenchmarkRunner(SharkRunner):
def benchmark_frontend(self, modelname):
if self.mlir_dialect in ["linalg", "torch"]:
return self.benchmark_torch(modelname)
elif self.mlir_dialect in ["mhlo", "tf"]:
return self.benchmark_tf(modelname)
def benchmark_torch(self, modelname):
import torch
import torch._dynamo as dynamo
from tank.model_utils import get_torch_model
if self.device == "cuda":
torch.set_default_tensor_type(torch.cuda.FloatTensor)
if self.enable_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
else:
torch.set_default_tensor_type(torch.FloatTensor)
torch_device = torch.device(
@@ -116,6 +118,7 @@ class SharkBenchmarkRunner(SharkRunner):
)
HFmodel, input = get_torch_model(modelname)[:2]
frontend_model = HFmodel.model
# frontend_model = dynamo.optimize("inductor")(frontend_model)
frontend_model.to(torch_device)
input.to(torch_device)
@@ -138,11 +141,26 @@ class SharkBenchmarkRunner(SharkRunner):
def benchmark_tf(self, modelname):
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
from tank.model_utils_tf import get_tf_model
tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
# tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
tf_device = "/CPU:0"
with tf.device(tf_device):
model, input, = get_tf_model(
(
model,
input,
) = get_tf_model(
modelname
)[:2]
frontend_model = model
@@ -172,11 +190,11 @@ class SharkBenchmarkRunner(SharkRunner):
def benchmark_python(self, inputs):
input_list = [x for x in inputs]
for i in range(shark_args.num_warmup_iterations):
self.run(input_list)
self.run("forward", input_list)
begin = time.time()
for i in range(shark_args.num_iterations):
out = self.run(input_list)
out = self.run("forward", input_list)
if i == shark_args.num_iterations - 1:
end = time.time()
print(
@@ -262,7 +280,8 @@ for currently supported models. Exiting benchmark ONNX."
]
def get_metadata(self, modelname):
with open("./tank/model_metadata.csv", mode="r") as csvfile:
metadata_path = os.path.join(".", "tank", "model_metadata.csv")
with open(metadata_path, mode="r") as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")
fields = next(torch_reader)
for row in torch_reader:
@@ -323,7 +342,10 @@ for currently supported models. Exiting benchmark ONNX."
else:
bench_result["shape_type"] = "static"
bench_result["device"] = device_str
bench_result["data_type"] = inputs[0].dtype
if "fp16" in modelname:
bench_result["data_type"] = "float16"
else:
bench_result["data_type"] = inputs[0].dtype
for e in engines:
(
bench_result["param_count"],

View File

@@ -14,6 +14,7 @@
import numpy as np
import os
from tqdm.std import tqdm
import sys
from pathlib import Path
from shark.parser import shark_args
@@ -33,7 +34,6 @@ def download_public_file(
dest_filename = None
desired_file = None
if single_file:
desired_file = full_gs_url.split("/")[-1]
source_blob_name = "/".join(full_gs_url.split("/")[3:-1])
destination_folder_name, dest_filename = os.path.split(
@@ -52,12 +52,18 @@ def download_public_file(
destination_filename = os.path.join(
destination_folder_name, dest_filename
)
blob.download_to_filename(destination_filename)
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(
f, "write", total=blob.size
) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
else:
continue
destination_filename = os.path.join(destination_folder_name, blob_name)
blob.download_to_filename(destination_filename)
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
input_type_to_np_dtype = {
@@ -70,29 +76,31 @@ input_type_to_np_dtype = {
"int8": np.int8,
}
# Save the model in the home local so it needn't be fetched everytime in the CI.
home = str(Path.home())
alt_path = os.path.join(os.path.dirname(__file__), "../gen_shark_tank/")
custom_path = shark_args.local_tank_cache
if os.path.exists(alt_path):
WORKDIR = alt_path
print(
f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank."
)
if custom_path:
if custom_path is not None:
if not os.path.exists(custom_path):
os.mkdir(custom_path)
WORKDIR = custom_path
print(f"Using {WORKDIR} as local shark_tank cache directory.")
elif os.path.exists(alt_path):
WORKDIR = alt_path
print(
f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank."
)
else:
WORKDIR = os.path.join(home, ".local/shark_tank/")
print(
f"shark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache= flag"
)
# Checks whether the directory and files exists.
def check_dir_exists(model_name, frontend="torch", dynamic=""):
model_dir = os.path.join(WORKDIR, model_name)
@@ -118,10 +126,7 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
and os.path.isfile(os.path.join(model_dir, "hash.npy"))
):
print(
f"""The models are present in the {WORKDIR}. If you want a fresh
download, consider deleting the directory."""
)
print(f"""Using cached models from {WORKDIR}...""")
return True
return False
@@ -146,6 +151,9 @@ def download_model(
):
print(f"Downloading artifacts for model {model_name}...")
download_public_file(full_gs_url, model_dir)
elif shark_args.force_update_tank == True:
print(f"Force-updating artifacts for model {model_name}...")
download_public_file(full_gs_url, model_dir)
else:
if not _internet_connected():
print(
@@ -161,29 +169,25 @@ def download_model(
os.path.join(model_dir, "upstream_hash.npy"),
single_file=True,
)
upstream_hash = str(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
if shark_args.update_tank == True:
print(f"Updating artifacts for model {model_name}...")
download_public_file(full_gs_url, WORKDIR)
else:
print(
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
)
try:
upstream_hash = str(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
except FileNotFoundError:
upstream_hash = None
if local_hash != upstream_hash and shark_args.update_tank == True:
print(f"Updating artifacts for model {model_name}...")
download_public_file(full_gs_url, model_dir)
elif local_hash != upstream_hash:
print(
"Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
)
model_dir = os.path.join(WORKDIR, model_dir_name)
suffix = (
"_" + frontend + ".mlir"
if tuned is None
else "_" + frontend + "_" + tuned + ".mlir"
)
tuned_str = "" if tuned is None else "_" + tuned
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
filename = os.path.join(model_dir, model_name + suffix)
if not os.path.isfile(filename):
filename = os.path.join(
model_dir, model_name + "_" + frontend + ".mlir"
)
with open(filename, mode="rb") as f:
mlir_file = f.read()

View File

@@ -55,6 +55,7 @@ class SharkImporter:
inputs: tuple = (),
frontend: str = "torch",
raw_model_file: str = "",
return_str: bool = False,
):
self.module = module
self.inputs = None if len(inputs) == 0 else inputs
@@ -65,6 +66,7 @@ class SharkImporter:
)
sys.exit(1)
self.raw_model_file = raw_model_file
self.return_str = return_str
# NOTE: The default function for torch is "forward" and tf-lite is "main".
@@ -72,10 +74,14 @@ class SharkImporter:
from shark.torch_mlir_utils import get_torch_mlir_module
return get_torch_mlir_module(
self.module, self.inputs, is_dynamic, tracing_required
self.module,
self.inputs,
is_dynamic,
tracing_required,
self.return_str,
)
def _tf_mlir(self, func_name, save_dir="./shark_tmp/"):
def _tf_mlir(self, func_name, save_dir="."):
from iree.compiler import tf as tfc
return tfc.compile_module(
@@ -85,7 +91,7 @@ class SharkImporter:
output_file=save_dir,
)
def _tflite_mlir(self, func_name, save_dir="./shark_tmp/"):
def _tflite_mlir(self, func_name, save_dir="."):
from iree.compiler import tflite as tflitec
self.mlir_model = tflitec.compile_file(
@@ -158,6 +164,7 @@ class SharkImporter:
func_name="forward",
dir=tempfile.gettempdir(),
model_name="model",
golden_values=None,
):
if self.inputs == None:
print(
@@ -177,7 +184,11 @@ class SharkImporter:
if self.frontend in ["torch", "pytorch"]:
import torch
golden_out = self.module(*self.inputs)
golden_out = None
if golden_values is not None:
golden_out = golden_values
else:
golden_out = self.module(*self.inputs)
if torch.is_tensor(golden_out):
golden_out = tuple(
golden_out.detach().cpu().numpy(),
@@ -245,12 +256,128 @@ class SharkImporter:
)
def get_f16_inputs(inputs, is_f16, f16_input_mask):
if is_f16 == False:
return inputs
if f16_input_mask == None:
return tuple([x.half() for x in inputs])
f16_masked_inputs = []
for i in range(len(inputs)):
if f16_input_mask[i]:
f16_masked_inputs.append(inputs[i].half())
else:
f16_masked_inputs.append(inputs[i])
return tuple(f16_masked_inputs)
def transform_fx(fx_g):
import torch
kwargs_dict = {
"dtype": torch.float16,
"device": torch.device(type="cpu"),
"pin_memory": False,
}
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
]:
node.kwargs = kwargs_dict
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
if node.target in [torch.ops.aten.var_mean]:
with fx_g.graph.inserting_before(node):
new_node = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
node.args = (new_node, node.args[1])
if node.name.startswith("getitem"):
with fx_g.graph.inserting_before(node):
if node.args[0].target in [torch.ops.aten.var_mean]:
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node,),
kwargs={"dtype": torch.float16},
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
new_node.kwargs = {"dtype": torch.float16}
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
# Doesn't replace the None type.
def change_fx_graph_return_to_tuple(fx_g):
for node in fx_g.graph.nodes:
if node.op == "output":
# output nodes always have one argument
node_arg = node.args[0]
out_nodes = []
if isinstance(node_arg, list):
# Don't return NoneType elements.
for out_node in node_arg:
if not isinstance(out_node, type(None)):
out_nodes.append(out_node)
# If there is a single tensor/element to be returned don't
# a tuple for it.
if len(out_nodes) == 1:
node.args = out_nodes
else:
node.args = (tuple(out_nodes),)
fx_g.graph.lint()
fx_g.recompile()
return fx_g
def flatten_training_input(inputs):
flattened_input = []
for i in inputs:
if isinstance(i, dict):
for value in i.values():
flattened_input.append(value.detach())
elif isinstance(i, tuple):
for value in i:
flattened_input.append(value)
else:
flattened_input.append(i)
return tuple(flattened_input)
# Applies fx conversion to the model and imports the mlir.
def import_with_fx(model, inputs, debug=False):
def import_with_fx(
model,
inputs,
is_f16=False,
f16_input_mask=None,
debug=False,
training=False,
return_str=False,
save_dir=tempfile.gettempdir(),
model_name="model",
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
golden_values = None
if debug:
golden_values = model(*inputs)
# TODO: Control the decompositions.
fx_g = make_fx(
model,
@@ -265,6 +392,7 @@ def import_with_fx(model, inputs, debug=False):
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
]
),
)(*inputs)
@@ -285,16 +413,29 @@ def import_with_fx(model, inputs, debug=False):
strip_overloads(fx_g)
if is_f16:
fx_g = fx_g.half()
transform_fx(fx_g)
fx_g.recompile()
if training:
change_fx_graph_return_to_tuple(fx_g)
inputs = flatten_training_input(inputs)
ts_graph = torch.jit.script(fx_g)
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
mlir_importer = SharkImporter(
fx_g,
ts_graph,
inputs,
frontend="torch",
return_str=return_str,
)
if debug:
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
if debug: # and not is_f16:
(mlir_module, func_name), _, _ = mlir_importer.import_debug(
dir=save_dir, model_name=model_name, golden_values=golden_values
)
return mlir_module, func_name
mlir_module, func_name = mlir_importer.import_mlir()
return mlir_module, func_name

View File

@@ -40,8 +40,6 @@ class SharkInference:
----------
mlir_module : str
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
function_name : str
function to execute in the given mlir_module.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -53,10 +51,10 @@ class SharkInference:
Methods
-------
run(inputs=None):
Runs the mlir_module with the given inputs, if the inputs are not
given it autogenerates the inputs. Also, the inputs should be a
numpy array.
__call__(function_name, inputs=None):
Runs the function with `function_name` within the mlir_module along
with the given inputs, if the inputs are not given it autogenerates the
inputs. Also, the inputs should be a numpy array.
input_info():
Gives the information about the inputs required by the `function_name`.
This can be expensive as it does string matching to do so.
@@ -66,18 +64,18 @@ class SharkInference:
def __init__(
self,
mlir_module: bytes,
function_name: str = "forward",
device: str = "none",
mlir_dialect: str = "linalg",
is_benchmark: bool = False,
dispatch_benchmark: str = None,
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
device_idx: int = None,
):
self.mlir_module = mlir_module
self.function_name = function_name
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.is_benchmark = is_benchmark
self.device_idx = device_idx
self.dispatch_benchmarks = (
shark_args.dispatch_benchmarks
if dispatch_benchmark is None
@@ -92,7 +90,6 @@ class SharkInference:
self.shark_runner = None
def compile(self, extra_args=[]):
if self.dispatch_benchmarks is not None:
extra_args.append(
f"--iree-hal-dump-executable-sources-to={self.dispatch_benchmarks_dir}"
@@ -113,7 +110,6 @@ class SharkInference:
self.shark_runner = SharkBenchmarkRunner(
self.mlir_module,
self.function_name,
self.device,
self.mlir_dialect,
extra_args=extra_args,
@@ -122,10 +118,10 @@ class SharkInference:
else:
self.shark_runner = SharkRunner(
self.mlir_module,
self.function_name,
self.device,
self.mlir_dialect,
extra_args=extra_args,
device_idx=self.device_idx,
)
if self.dispatch_benchmarks is not None:
@@ -138,21 +134,25 @@ class SharkInference:
os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")
# inputs are considered to be tuple of np.array.
def forward(self, inputs: tuple):
return self.shark_runner.run(inputs)
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
return self.shark_runner.run(function_name, inputs, send_to_host)
# Get all function names defined within the compiled module.
def get_functions_in_module(self):
return self.shark_runner.get_functions_in_module()
# Captures the static input information from the mlir_module.
# TODO(pashu123): Generate the input information for dynamic shapes.
def _input_info(self):
def _input_info(self, function_name):
# func_key to get the line which contains the function.
func_key = "func.func @" + self.function_name
func_key = "func.func @" + function_name
func_header = None
for line in str(self.mlir_module).splitlines():
if func_key in line:
func_header = line
break
if func_header is None:
print(f"Function: {self.function_name} not found")
print(f"Function: {function_name} not found")
import re
@@ -190,7 +190,6 @@ class SharkInference:
self.device,
dir,
self.mlir_dialect,
self.function_name,
module_name=module_name,
extra_args=extra_args,
)
@@ -198,7 +197,6 @@ class SharkInference:
# load and return the module.
def load_module(self, path, extra_args=[]):
self.shark_runner = SharkRunner(
function_name=self.function_name,
device=self.device,
compile_vmfb=False,
extra_args=extra_args,
@@ -209,6 +207,6 @@ class SharkInference:
) = load_flatbuffer(
path,
self.device,
self.function_name,
self.device_idx,
)
return

View File

@@ -39,8 +39,6 @@ class SharkRunner:
----------
mlir_module : str
mlir_module represented in string.
function_name : str
function to execute in the given mlir_module.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -50,10 +48,10 @@ class SharkRunner:
Methods
-------
run(inputs=None):
Runs the mlir_module with the given inputs, if the inputs are not
given it autogenerates the inputs. Also, the inputs should be a
numpy array.
run(function_name, inputs=None):
Runs the function with `function_name` within the mlir_module along
with the given inputs, if the inputs are not given it autogenerates the
inputs. Also, the inputs should be a numpy array.
input_info():
Gives the information about the inputs required by the `function_name`.
This can be expensive as it does string matching to do so.
@@ -62,17 +60,17 @@ class SharkRunner:
def __init__(
self,
mlir_module: bytes = None,
function_name: str = "forward",
device: str = "none",
mlir_dialect: str = "linalg",
extra_args: list = [],
compile_vmfb: bool = True,
device_idx: int = None,
):
self.mlir_module = mlir_module
self.function_name = function_name
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.device_idx = device_idx
if check_device_drivers(self.device):
print(device_driver_info(self.device))
@@ -87,14 +85,20 @@ class SharkRunner:
self.mlir_module,
self.device,
self.mlir_dialect,
func_name=self.function_name,
extra_args=self.extra_args,
device_idx=self.device_idx,
)
def run(self, inputs: tuple):
def run(self, function_name, inputs: tuple, send_to_host=False):
return get_results(
self.iree_compilation_module,
function_name,
inputs,
self.iree_config,
self.mlir_dialect,
send_to_host,
)
# Get all function names defined within the compiled module.
def get_functions_in_module(self):
return self.iree_compilation_module._vm_module.function_names

Some files were not shown because too many files have changed in this diff Show More