Compare commits

..

1230 Commits

Author SHA1 Message Date
dan
489a858af1 enforce fp32 accumulates for cpu 2023-10-29 18:59:00 +00:00
Vivek Khandelwal
b83d32fafe Fix Falcon GPTQ Pipeline 2023-10-11 20:09:32 +05:30
Vivek Khandelwal
0a618e1863 Add support for Falcon GPTQ 2023-10-11 10:47:48 +05:30
Phaneesh Barwaria
a731eb6ed4 Macos fixes (#1883)
* fix venv setup for MacOS

* allow stream fuse binding on mac

* clean iree metal args
2023-10-09 23:36:12 -07:00
Ean Garvey
2004d16945 Revert "[SDXL] Add SDXL pipeline to SHARK (#1731)" (#1882)
This reverts commit 9f0a421764.
2023-10-09 18:01:44 -07:00
Gaurav Shukla
6e409bfb77 fix else if syntax error
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-10 06:23:56 +05:30
Gaurav Shukla
77727d149c [warning] Fix dropdown warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-10 05:18:43 +05:30
Ean Garvey
66f6e79d68 Split CPU/GPU definitions conditionally outside of torch contexts. (#1879) 2023-10-09 16:46:41 -07:00
Ean Garvey
3b825579a7 (LLaMa-2) Point to int4 + f32 acc .mlir for cpu (#1878)
- fixes some issues with non-system prompt invocation

Co-authored-by: Gaurav Shukla <gauravshukla789@gmail.com>
2023-10-09 14:37:35 -05:00
Abhishek Varma
9f0a421764 [SDXL] Add SDXL pipeline to SHARK (#1731)
-- This commit adds SDXL pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-10-09 13:01:37 -05:00
Gaurav Shukla
c28682110c [chatbot] Flag to add system prompt
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-09 22:17:39 +05:30
Ean Garvey
caf6cc5d8f Switch most compile flows to use ireec.compile_file. (#1863)
* Switch most compile flows to use ireec.compile_file.

* re-add input type to compile_str path.

* Check if mlir_module exists before checking if it's a path or pyobject.

* Fix some save_dir cases
2023-10-06 23:04:43 -05:00
Ean Garvey
8614a18474 Remove tf dependencies from importer path. (#1874)
* Remove tf dependencies from import path.

* Fix formatting.
2023-10-06 12:27:12 -07:00
Jakub Kuderski
86c1c0c215 Add aggregate statistics to microbenchmark (#1871)
Print averaged results at the end of all iterations. Increase the
default number of iterations to 5.

Example:
```
Number of iterations: 5
Prefill: avg. 0.03 s, stddev 0.00
Decode: avg. 43.34 tokens/s, stdev 0.13
```

Also remove the -2 in the number of generated tokens -- I did not find
any evidence we need it.
2023-10-06 10:03:07 -07:00
Daniel Garvey
8bb364bcb8 enforce fp32 accumulates for cpu (#1873) 2023-10-06 11:34:49 -05:00
Daniel Garvey
7abddd01ec argmax inside model + brevitas pin (#1872) 2023-10-05 20:15:21 -07:00
Abhishek Varma
2a451fa0c7 [Llama2] Add a standalone utility for dynamic and combining IRs
-- This script adds a standalone utility for converting Llama IRs
   to dynamic and combining them as well.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-10-05 20:01:06 +05:30
Jakub Kuderski
9c4610b9da Add microbenchmark mode to vicuna CLI (#1864)
Add flags to enable a non-internactive mode for microbenchmarking llama
models. In this mode, the system and user prompts are specified with CLI
flags, and the number of generated tokens and iterations is fixed.

Also move the stats below the response and trim any response blankspace.
2023-10-05 00:12:08 -04:00
powderluv
a38cc9d216 Update vulkan_utils.py for Radeon 780m igpu (#1866) 2023-10-04 20:33:07 -07:00
Jakub Kuderski
1c382449ec [vulkan] Print note about module load times. NFC. (#1862)
Print a note ahead of a potentially long inactivity to set the right expectations.

Separately, we should add progress to the UI and make this loading faster.
2023-10-03 17:27:27 -04:00
Gaurav Shukla
7cc9b3f8e8 [llama cli] Fix llama cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-03 20:39:53 +05:30
Gaurav Shukla
e54517e967 [UI] Disable config generator, lora train and model manager (#1858)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-02 22:34:40 -07:00
Ean Garvey
326327a799 Collect pipeline submodules for diffusers ckpt preprocessing. (#1859) 2023-10-03 00:29:28 -04:00
Ean Garvey
785b65c7b0 Add flag for specifying device-local caching allocator heap key. (#1856) 2023-10-03 00:28:39 -04:00
Sungsoon Cho
0d16c81687 Remove unused import. (#1857) 2023-10-02 11:36:08 -05:00
Vivek Khandelwal
8dd7850c69 Add Falcon-GPTQ support 2023-10-02 16:39:57 +05:30
Gaurav Shukla
e930ba85b4 [os] Remove os dependency from vmfb naming (#1854)
Also fixes a small ui issue for chatbot.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 12:38:17 -05:00
Gaurav Shukla
cd732e7a38 [chatbot] split execution time to prefill and decode
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
Gaurav Shukla
8e0f8b3227 [ui] Update chatbot UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
Gaurav Shukla
b8210ef796 [chatbot] Re-instantiate the chatbot object if device id changes
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
PhaneeshB
94594542a9 remove use of vulkaninfo 2023-09-28 21:57:00 +05:30
Gaurav Shukla
82f833e87d [vulkan] Update vmfb naming
Update vmfb naming for vulkan devices in order to resolve naming
conflicts in the presence of multiple vulkan devices.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-28 14:52:11 +05:30
Vivek Khandelwal
c9d6870105 Modify falcon pipeline for 180b support 2023-09-28 12:39:35 +05:30
Jakub Kuderski
4fec03a6cc [vulkan] Switch from coop matrix NV to KHR (#1848) 2023-09-27 21:43:37 -04:00
harsh-nod
9a27f51378 Deprecate inference directory
This patch removes the inference directory that was no longer being used.
2023-09-27 14:29:00 -07:00
Abhishek Varma
ad1a0f35ff Fix misdirection while saving vmfb
-- Currently SHARK suggests that vmfb has been saved, while
    that is not the case and no vmfb is generated. 
    This creates a misdirection for IR/vmfbs which are of larger
    size.
-- This commit therefore fixes that misdirection.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-27 16:25:29 +05:30
Nelson Sharpe
6773278ec2 Fix checkpoint_path unexpected argument (#1832) 2023-09-24 14:17:52 -07:00
Abhishek Varma
9a0efffcca [Llama2] Fix wrong Vulkan device ID + Add Vulkan compile flags
-- This commit fixes the wrong Vulkan device being selected during
   runtime.
-- It also adds couple of IREE compilation flags to target specific
   Vulkan device.
-- It also changes the Vulkan device listing to be more in tune with
   lowering control flow.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-22 22:24:18 +05:30
gpetters94
61c6f153d9 Switch to keras-nightly to fix a Linux issue (#1835) 2023-09-21 12:33:45 -04:00
Phaneesh Barwaria
effd42e8f5 pin gradio to v3.44.3 2023-09-21 17:33:43 +05:30
Sungsoon Cho
b5fbb1a8a0 Rename the func arg save_json to avoid name collision. (#1837)
* Rename the func arg save_json to avoid name collision.

* black formatted.
2023-09-19 17:29:27 -05:00
Quinn Dawkins
ded74d09cd [vicuna.py] Keep past key values on device (#1836)
The past key values are only used within the models themselves and can
be kept on device. For vulkan int4, this gives 44 tok/s (for the first
prompt) and settles at around 26 tok/s on 7900xtx.
2023-09-19 18:17:41 -04:00
Boian Petkantchin
79267931c1 Add argument --additional_compile_args (#1119)
This allows to pass more arguemnts to the IREE compiler
Example:
python my-app.py --additional_compile_args="--mlir-pretty-debuginfo --mlir-timing"

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-09-19 11:26:03 -05:00
zjgarvey
9eceba69b7 local_tank_cache included into clear_all (#1833) 2023-09-18 00:27:23 -05:00
Ean Garvey
ca609afb6a Update README.md (#1830) 2023-09-14 10:33:57 -05:00
Gaurav Shukla
11bdce9790 [flags] Fix vulkan runtime flags as vma is dropped from iree (#1831) 2023-09-14 08:58:59 -05:00
Ean Garvey
684943a4a6 (SD) Fix tokenizers imports in pyinstaller builds. (#1828)
* Fix tokenizers metadata.

* (SD) Disable VAE lowering configs (rdna3) and add versioned tunings.

* Update sd_annotation.py

* (SD) Add cv2 to spec.

* Update stencil pipeline with the new img2img arg.
2023-09-12 12:23:48 -05:00
PhaneeshB
b817bb8455 add roles for llama2 2023-09-12 10:59:28 +05:30
Ean Garvey
780f520f02 Fix vk.target_env extensions and remove redundant SD imports. (#1826)
* Remove redundant IREE runtime imports.

* Fix vulkan target env extensions.
2023-09-11 13:42:52 -05:00
Dom
c61b6f8d65 Code refactoring (#1817)
* use join

* fix bug

* further code optimizations

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
2023-09-11 11:30:56 -05:00
Abhishek Varma
c854208d49 [Llama2] Prefetch llama2 tokenizer configs (#1824)
-- This commit prefetches llama2 tokenizer configs from shark_tank.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-08 11:29:54 -07:00
Gaurav Shukla
c5dcfc1f13 [vicuna] Exit when mlir is not present in shark tank (#1825)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-08 10:30:29 -07:00
Abhishek Varma
bde63ee8ae Add logging feature in WebUI (#1821) 2023-09-08 05:48:05 -07:00
Vivek Khandelwal
9681d494eb Update decomp list and shark trainer for DLRM 2023-09-06 21:24:50 +05:30
Gaurav Shukla
ede6bf83e2 [vicuna] Disabling the IR generation path
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-06 20:13:17 +05:30
Ean Garvey
2c2693fb7d Fix torchvision versioning in Linux importer setup. (#1809) 2023-09-05 12:57:03 -05:00
Vivek Khandelwal
1d31b2b2c6 Fix StableHLO Compilation flag 2023-09-05 21:32:33 +05:30
Gaurav Shukla
d2f64eefa3 [chatbot] Remove few outdated models from list (#1814) 2023-09-04 09:26:32 -07:00
Abhishek Varma
87ae14b6ff [SD] Add sdpfa decomposition + update IREE flag
-- This commit adds Scaled Dot Product Flash Attention's decomposition
   in shark_importer.
-- It also updates `iree-flow-enable-data-tiling` to `iree-opt-data-tiling`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-04 18:03:53 +05:30
Phaneesh Barwaria
1ccafa1fc1 fix llama2-70b rewrite tensor dim 2023-09-01 17:27:06 +05:30
jinchen62
4c3d8a0a7f Enable downloading vmfb/mlir for webui (#1807) 2023-08-31 11:05:47 -07:00
jinchen62
3601dc7c3b Fix llama2 13b combined ir (#1803) 2023-08-28 11:34:44 -07:00
Daniel Garvey
671881cf87 Llama2 70b (#1783)
* llama2 70b IR gen

* fix IR sec llama2 + debug

* llama270b

---------

Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
2023-08-25 23:04:28 -07:00
Gaurav Shukla
4e9be6be59 [chatbot] Add debug as class attribute (#1799)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-25 21:46:29 -07:00
Ean Garvey
9c8cbaf498 Add support for ROCM (Windows) in Studio + compile utils (#1770)
* WIP: MSVC ROCM support for SHARK Studio

* Make get_iree_rocm_args platform-agnostic.

* Update stable_args.py

* Update rocm arg handling in SD utils

* Guard quantization imports.

Co-authored-by: jam https://github.com/jammm
2023-08-25 20:56:05 -07:00
Ean Garvey
9e348a114e Revert changes process_skipfiles.py (#1798)
Keeps a small typo fix but reverts the rest of changes to this file from 450c231171
2023-08-25 15:31:49 -07:00
jinchen62
51f90a4d56 Update conversion passes for brevitas quant op (#1795) 2023-08-25 17:28:07 -05:00
Abhishek Varma
310d5d0a49 Fix llama2 13b crashing + add spec file for CLI execution of Llama (#1797)
* [Llama2] Add a fix for Llama2 13B downloading/crashing

-- This commit fixes downloading/crashing of llama2 13B on wrong
   .mlir file.
-- Also adds support for downloading vmfb from shark_tank in CLI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [llama2] Add a spec file to run Llama/Vicuna CLI exe

-- This commit adds a spec file to run Llama/Vicuna CLI exe.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-08-25 09:36:09 -05:00
Ean Garvey
9697981004 Pipe through a debug option to iree compile utils. (#1796)
* Update compile_utils.py

* Pipe through a flag to toggle debug options in compile utils.

* Update SharkLLMBase.py
2023-08-25 07:11:11 -07:00
Ean Garvey
450c231171 Add tokenizers to requirements.txt (#1790)
* Add tokenizers to requirements and pin version

* Update process_skipfiles.py
2023-08-24 19:44:04 -05:00
Ean Garvey
07f6f4a2f7 Add a short README for the OPT examples and small tweaks. (#1793)
* Small changes to OPT example.

* Update opt README.

* Add a few modes to batch script.

* Update README.md
2023-08-24 17:26:11 -07:00
jinchen62
610813c72f Add iree flag to strip assertions (#1791) 2023-08-24 10:51:19 -07:00
Ean Garvey
8e3860c9e6 Remove flags that are default in upstream IREE (#1785)
* Remove index bits flags now set by default

* Update shark_studio_imports.py
2023-08-24 11:57:54 -05:00
xzuyn
e37d6720eb Add Hires Fix (#1787)
* improper test hiresfix

* add sliders & use `clear_cache`

* add resample choices & fix step adjustment

* add step adjustment to img2img

* add resample options to img2img

* simplify hiresfix
- import `img2img_inf` from `img2img_ui.py` instead of just copying it into `txt2img_ui.py`

* set `hri` to None after using

* add more resample types, and don't show output until hiresfix is done

* cleaner implementation

* ran black

* ran black again with jupyter dependencies
2023-08-24 09:01:41 -07:00
Vivek Khandelwal
16160d9a7d Fix combine mlir script 2023-08-24 19:10:49 +05:30
Sungsoon Cho
79075a1a07 Opt perf (#1786)
* Define command line args, model-name, max-seq-len, platform, etc.

* Add usage example.

* Add opt_perf_comparision_batch.py.

* Use shlex instead.
2023-08-24 08:33:12 -05:00
Abhishek Varma
db990826d3 Add Llama2 13B int4 fp16 support (#1784)
Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-08-23 10:00:32 -07:00
gpetters94
7ee3e4ba5d Add stencil_unet_512 support (#1778)
This should fix any remaining issues with stencils and long prompts.
2023-08-22 12:23:46 -04:00
Vivek Khandelwal
05889a8fe1 Add LLaMa2-int4-fp16 support (#1782) 2023-08-22 07:45:50 -07:00
jinchen62
b87efe7686 Fix venv setup for brevitas (#1779) 2023-08-21 11:58:51 -07:00
gpetters94
82b462de3a Fix stencils for long prompts (#1777) 2023-08-19 00:26:51 -07:00
Daniel Garvey
d8f0f7bade replace public with private (#1776)
unload footguns
2023-08-18 14:22:46 -07:00
gpetters94
79bd0b84a1 Fix an issue with diffusers>0.19.3 (#1775) 2023-08-18 14:06:06 -04:00
jinchen62
8738571d1e Adapt the change of brevitas custom op name (#1772) 2023-08-17 14:24:43 -07:00
Gaurav Shukla
a4c354ce54 [version] Pin diffusers==0.19.3
Once the latest works with LORA train, unpin it.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
cc53efa89f [cli] Fix chatbot cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
9ae8bc921e [chatbot] Fix chatbot cli and webview warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
32eb78f0f9 [chatbot] Fix switching parameters in chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 19:14:17 +05:30
Ean Garvey
cb509343d9 Fix pytest benchmarks and shark_tank generation. (#1632)
- fix setup_venv.sh for benchmarks/imports etc.
- fix torch benchmarks in SharkBenchmarkRunner
- generate SD artifacts using build_tools/stable_diffusion_testing.py and --import_mlir
- decouple SD gen from tank/generate_sharktank for now
2023-08-16 17:48:47 -05:00
powderluv
6da391c9b1 update signtool to use /fd certHash 2023-08-15 15:11:40 -07:00
Ean Garvey
9dee7ae652 fix tkinter window (#1766) 2023-08-15 13:23:09 -07:00
Ean Garvey
343dfd901c Update SHARK-Runtime links to SRT (#1765)
* Update nightly.yml

* Update setup_venv.ps1

* Update CMakeLists.txt

* Update shark_iree_profiling.md

* Update setup_venv.sh

* Update README.md

* Update .gitmodules

* Update CMakeLists.txt

* Update README.md

* fix signtool flags

* Update nightly.yml

* Update benchmark_utils.py

* uncomment tkinter launch
2023-08-15 12:40:44 -07:00
Ean Garvey
57260b9c37 (Studio) Add hf-hub to pyinstaller metadata (#1761) 2023-08-14 23:01:50 -05:00
Ean Garvey
18e7d2d061 Enable vae tunings for rdna3. (#1764) 2023-08-14 21:00:14 -07:00
Stanley Winata
51a1009796 Add Forward method to SHARKRunner and fix examples. (#1756) 2023-08-14 19:20:37 -07:00
Daniel Garvey
045c3c3852 enable iree-opt-const-expr-hoisting in vicuna (#1742)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-08-14 18:43:42 -07:00
Ean Garvey
0139dd58d9 Specify max allocation size in IREE compile args. (#1760) 2023-08-14 15:43:09 -05:00
Ean Garvey
c96571855a prevents recompiles for cuda benchmarks + update benchmark_module path (#1759)
* xfail resnet50_fp16

* Fix cuda benchmarks and prevent recompilation.
2023-08-14 15:30:32 -05:00
PhaneeshB
4f61d69d86 add support passing iree flags for LLMs 2023-08-15 00:22:56 +05:30
Phaneesh Barwaria
531d447768 set default allocator for metal device creation (#1755) 2023-08-14 06:17:52 -07:00
Vivek Khandelwal
16f46f8de9 Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
c4723f469f Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
d804f45a61 Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
d22177f936 Update requirements.txt 2023-08-14 14:32:19 +05:30
George Petterson
75e68f02f4 Remove CUDNN 2023-08-14 14:32:19 +05:30
Gaurav Shukla
4dc9c59611 [chatbot] Add tokens generated per second (#1753) 2023-08-13 11:25:41 -07:00
Gaurav Shukla
18801dcabc [chat] Update chatbot ui
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-13 18:39:22 +05:30
Gaurav Shukla
3c577f7168 [vicuna] fix shard config generator script (#1747)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-10 11:26:03 -07:00
Stefan Kapusniak
f5e4fa6ffe UI/Web - Revert tab order (#1724)
* Revert ui tab order

* Reverts the tab order, so that SD, LLM, and Experimental are grouped
together again as far as is possible.
* Labelled "Generate Sharding Config" as experimental as pressing the
'Get Model Config' errors for me.

* Fix formatting in index.py
2023-08-10 11:25:36 -07:00
powderluv
48de445325 Enable caching and disable vma (#1746)
* Enable caching allocator by default

Going to toggle VMA off too and this is required for performance.  Will have to monitor in the wild reports.

* Disable VMA

Disable VMA
2023-08-10 10:49:44 -07:00
Gaurav Shukla
8e90f1b81a [vicuna] add default config in case of sharded vicuna
Signed-Off-by: Gaurav Shukla<gaurav@nod-labs.com>
2023-08-10 21:28:08 +05:30
Vivek Khandelwal
e8c1203be2 Fix vicuna script (#1745) 2023-08-10 06:11:14 -07:00
Vivek Khandelwal
e4d7abb519 Final patch for fixing Langchain token streaming issue (#1744) 2023-08-09 10:09:41 -07:00
powderluv
96185c9dc1 pin safetensors to 0.3.1 (#1740) 2023-08-08 19:24:44 -07:00
powderluv
bc22a81925 re-enable constant folding (#1739)
Tested and works well. (modulo unrelated driver issue)
2023-08-08 17:17:38 -07:00
Eliasj42
5203679f1f Bandaid fix 2 (#1728)
* download all mlirs

* fixed install method

* download all mlirs (#1727)

Co-authored-by: Elias Joseph <elias@nod-labs.com>

* added taggs

* fix name check for file existence

* Remove SD from all_models.csv (#1706)

Removes SD from pytests as it has its own test suite.

* gpt_langchain.py fixes for pydantic (#1722)

* removed dead code

---------

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>
2023-08-08 12:14:57 -05:00
Vivek Khandelwal
bf073f8f37 [Langchain] Expand pipelines to fix token streaming issue 2023-08-08 10:27:23 +05:30
Stella Laurenzo
cec6eda6b4 Optimize device enumeration overhead and log details on long operations. (#1734)
* Optimize device enumeration overhead and log details on long operations.

* Various fixes to add `@functools.cache` to what should be one time, expensive, device enumeration and setup activities. Cuts several seconds off of initialization on my machine.
* Add detailed tracing to actual invocations if they exceed a certain timeout or have an exception.
* Add detailed tracing to loading status.
* By default detail logging is only printed if an operation takes an excessive amount of time. All logging/timing can be printed by setting the variable `$env:SHARK_DETAIL_TRACE = "1"`

* Remove cache from unhashable functions
2023-08-07 17:20:53 -07:00
Stella Laurenzo
9e37e03741 Clearly differentiate phases of loading modules to better understand if things are taking a long time. (#1733) 2023-08-07 14:03:12 -07:00
Stefan Kapusniak
9b8c4401b5 gpt_langchain.py fixes for pydantic (#1722) 2023-08-07 00:55:38 -07:00
Ean Garvey
a9f95a218b Remove SD from all_models.csv (#1706)
Removes SD from pytests as it has its own test suite.
2023-08-05 15:55:52 -05:00
PhaneeshB
872bd72d0b fix name check for file existence 2023-08-05 21:33:53 +05:30
Eliasj42
fd1c4db5d0 download all mlirs (#1727)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-08-04 18:22:06 -05:00
Daniel Garvey
759664bb48 add py files to pyinstaller for shark (#1723) 2023-08-04 14:10:43 -07:00
Daniel Garvey
14fd0cdd87 add missing subprocess import (#1721) 2023-08-04 15:15:22 -05:00
Daniel Garvey
a57eccc997 fix lint (#1720) 2023-08-04 14:54:33 -05:00
Daniel Garvey
a686d7d89f temporarily disable langchain stuff in webui (#1719)
its breaking the exe
2023-08-04 12:48:06 -07:00
Eliasj42
ed484b8253 added functionality for int8 vicuna and 4 shards (#1712)
combined vicuna_4_shards.py and vicuna.py to reduce code duplication

Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-08-04 14:05:05 -05:00
gpetters94
7fe57ebaaf Add vector database and add support on the web UI (#1699) 2023-08-04 13:47:19 -04:00
Nithin Meganathan
c287fd2be8 Add GPU ID's in model_confg.json by default for manual annotation (#1718) 2023-08-04 12:46:27 -05:00
Gaurav Shukla
51ec1a1360 [vicuna] Integrate sharded vicuna in web (#1717)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-04 11:46:53 -05:00
Gaurav Shukla
bd30044c0b [Shard] Add sharding generation in shark studio
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-04 21:51:14 +05:30
Ean Garvey
c9de2729b2 Add flag for toggling constant folding. (#1714) 2023-08-04 04:55:52 -07:00
Vivek Khandelwal
a5b13fcc2f [Langchain] Patch for fixing streaming of tokens (#1709) 2023-08-03 10:06:49 -07:00
Stefan Kapusniak
6bb329c4af Unsharded Vicuna: Fix Memory Error compiling mlir for lmsys/vicuna-7b-v1.3 fp16 with 64 GiB (#1702) 2023-08-01 06:07:56 -07:00
Vivek Khandelwal
98fb6c52df Expand pipelines to fix streaming of tokens 2023-07-31 22:11:01 +05:30
Stefan Kapusniak
206c1b70f4 UI/Web: Reorder tabs to separate SD and LLM (#1701)
Shuffle the tabs around so that:

* All the SD tabs are together
* All the LLM tabs are together
* All the experimental tabs are together
2023-07-29 22:25:30 -04:00
PhaneeshB
cdb037ee54 use shark_args for vulkan debug utils flag 2023-07-30 07:54:26 +05:30
PhaneeshB
ce2fd84538 fix cpu device name for SharkStudio 2023-07-30 07:54:26 +05:30
PhaneeshB
4684afad34 update upscalar example 2023-07-28 21:06:28 +05:30
PhaneeshB
8d65456b7a Move vulkan runtime flags to shark_args 2023-07-28 21:06:28 +05:30
PhaneeshB
d6759a852b add vulkan vma alloc flag 2023-07-28 21:06:28 +05:30
Daniel Garvey
ab57af43c1 Couple of fixes for vicuna.py (#1696)
* mega vicuna merge pt 2

* add fallback to ensure compile is called
2023-07-27 15:53:05 -07:00
jinchen62
4d5c55dd9f Fix vicuna script (#1697) 2023-07-27 17:24:26 -05:00
Vivek Khandelwal
07399ad65c [Langchain] Remove unused code (#1698) 2023-07-27 11:59:54 -05:00
Vivek Khandelwal
776a9c2293 Fix for Langchain (#1694)
For CPU, remove max time stopping criteria
Fix web UI issue
2023-07-26 09:00:23 -07:00
Eliasj42
9d399eb988 fixed bug where device_idx was hardcoded (#1693)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-07-25 19:00:13 -05:00
Vivek Khandelwal
927b662aa7 Add Langchain SHARK Compilation support for all paths 2023-07-25 22:15:42 +05:30
Abhishek Varma
47f8a79c75 [MiniGPT4] Add MiniGPT4 to SHARK (#1554)
* [MiniGPT4] Add MiniGPT4 to SHARK

-- This is the first installment of MiniGPT4 in SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* Add int8 support for MiniGPT4

-- This commit adds int8 support for MiniGPT4.

Signed-off-by: Abhishek Varma <abhishek@nod-lab.com>

* Update .spec for MiniGPT4's config files

* black format MiniGPT4

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Signed-off-by: Abhishek Varma <abhishek@nod-lab.com>
2023-07-25 09:42:27 -07:00
Stefan Kapusniak
289f983f41 SD - Implement seed arrays for batch runs (#1690)
* SD Scripts and UI tabs that support batch_count can now take a
string containing a JSON array, or a list of integers, as their seed
input.
* Each batch in a run will now take the seed specified at the
corresponding array index if one exists. If there is no seed at
that index, the seed value will be treated as -1 and a random
seed will be assigned at that position. If an integer rather than
a list or json array has been, everything works as before.
* UI seed input controls are now Textboxes with info lines about
the seed formats allowed.
* UI error handling updated to be more helpful if the seed input is
invalid.
2023-07-24 19:22:34 -07:00
Daniel Garvey
453e46562f mega vicuna merge pt 2 (#1685) 2023-07-24 12:42:20 -05:00
Gaurav Shukla
5497af1f56 [config] Add support for uploading sharding config file in chatbot (#1689)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-07-24 10:18:03 -07:00
Vivek Khandelwal
f3cb63fc9c Fix Langchain multiple device isssue (#1688) 2023-07-24 08:03:46 -07:00
Vivek Khandelwal
d7092aafaa Fix multiple issue for Langchain
This commit fixes the following issue for the Langchain:
1.) Web UI not able to fetch results.
2.) For each query model getting reloaded.
3.) SHARK module not using user provided device and precision.
4.) Create a class for main Langchain code.
5.) Misc issues
2023-07-21 21:56:27 +05:30
Vivek Khandelwal
a415f3f70e Fix Langchain Prompt issue and add web UI support (#1682) 2023-07-21 06:36:55 -07:00
Vivek Khandelwal
c292e5c9d7 Add Langchain CPU support and update requirements 2023-07-20 18:53:34 +05:30
Vivek Khandelwal
03c4d9e171 Add support for Llama-2-70b for web and cli, and for hf_auth_token 2023-07-20 14:57:48 +05:30
jinchen62
3662224c04 Update brevitas requirement (#1677)
also clean up useless args

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-19 22:03:32 -07:00
Vivek Khandelwal
db3f222933 Revert "Add Llama2 70B option in CLI and WebUI (#1673)" (#1679)
This reverts commit 41e5088908.
2023-07-19 22:02:48 -07:00
Stefan Kapusniak
68b3021325 Fixes cosmetic problems with Gradio 3.37.0 (#1676)
* Fix nod-ai logo having a white border
* Fix control labels having a black background
* Remove extra lower border below Save Prompt checkboxes in Txt2Img UI
2023-07-19 17:28:53 -07:00
AyaanShah2204
336469154d added copy-metadata for pyyaml (#1678) 2023-07-19 17:27:25 -07:00
Abhishek Varma
41e5088908 Add Llama2 70B option in CLI and WebUI (#1673) 2023-07-19 10:41:42 -07:00
PhaneeshB
0a8f7673f4 Add README for CodeGen server 2023-07-19 23:10:23 +05:30
PhaneeshB
c482ab78da fix second vic clearing for low mem device 2023-07-19 23:10:23 +05:30
Vivek Khandelwal
4be80f7158 Add support for the Llama-2 model 2023-07-19 20:57:08 +05:30
AyaanShah2204
536aba1424 unpinned torch_mlir (#1668)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-19 06:28:00 -07:00
Ean Garvey
dd738a0e02 small changes to opt_perf_comparison.py (#1670)
* Use longer prompts for OPT comparison script

* small tweaks
2023-07-19 06:26:50 -07:00
Daniel Garvey
8927cb0a2c set optional vmfb download (#1667) 2023-07-18 10:57:28 -07:00
Daniel Garvey
8c317e4809 fix cli for vicuna (#1666) 2023-07-18 10:03:40 -07:00
Vivek Khandelwal
b0136593df Add support for different compilation paths for DocuChat (#1665) 2023-07-18 09:49:44 -07:00
Vivek Khandelwal
11f62d7fac Minor fixes for MiniLM Training 2023-07-18 17:16:44 +05:30
powderluv
14559dd620 Update DocuChat as experimental (#1660) 2023-07-17 22:12:05 -07:00
AyaanShah2204
e503a3e8d6 fixed joblib import error (#1659) 2023-07-17 12:56:10 -07:00
AyaanShah2204
22a4254adf fixed pyinstaller path for langchain imports (#1658) 2023-07-17 12:19:21 -07:00
Vivek Khandelwal
ab01f0f048 Add Langchain model in SHARK (#1657)
* Add H2OGPT

* Add UI tab for h2ogpt

* Add source files from h2ogpt

* Add the rest of the files

* Add h2ogpt support

* Add SHARK Compilation support for langchain model for cli mode

---------

Co-authored-by: George Petterson <gpetters@protonmail.com>
2023-07-17 09:58:15 -07:00
Phaneesh Barwaria
c471d17cca codegen API (#1655) 2023-07-16 20:00:39 -07:00
Stefan Kapusniak
a2a436eb0c SD - Add repeatable (batch) seeds option (#1654)
* Generates the seeds for all batch_count batches being run up front
rather than generating the seed for a batch just before it is run.
* Adds a --repeatable_seeds argument defaulting to False
* When repeatable_seeds=True, the first seed for a set of batches will
also be used as the rng seed for the subsequent batch seeds in the run.
The rng seed is then reset.
* When repeatable_seeds=False, batch seeding works as currently.
* Update scripts under apps/scripts that support the batch_count
argument to also support the repeatable_seeds argument.
* UI/Web: Adds a checkbox element on each SD tab after batch count/size
for toggling repeatable seeds, and update _inf functions to take
this into account.
* UI/Web: Moves the Stop buttons out of the Advanced sections and next
to Generate to make things not fit quite so badly with the extra UI
elements.
* UI/Web: Fixes logging to the upscaler output text box not working
correctly when running multiple batches.
2023-07-15 16:22:41 -07:00
powderluv
1adb51b29d Update docker README.md 2023-07-15 14:31:56 -07:00
anush elangovan
aab2233e25 Add a dev Ubuntu 22.04 docker image 2023-07-15 16:25:37 +00:00
jinchen62
e20cd71314 Change to a separate pass to unpack quantized weights (#1652) 2023-07-15 04:54:53 -07:00
powderluv
5ec91143f5 add a HF accelerate requirement (#1651) 2023-07-14 05:56:12 -07:00
Ean Garvey
7cf19230e2 add perf comparison script for opt. (#1650) 2023-07-13 13:29:48 -05:00
powderluv
1bcf6b2c5b pin diffusers to 0.18.1 (#1648) 2023-07-13 01:02:24 -07:00
jinchen62
91027f8719 Remove done TODOs, a sup PR for #1644 (#1647) 2023-07-12 23:30:45 -07:00
powderluv
a909fc2e78 add tiktoken to spec file (#1646) 2023-07-12 16:12:02 -07:00
jinchen62
247f69cf9d Apply canonicalize for unpacking int4 (#1644)
- tested it unpacks int4 as expected
- tested it doesn't make difference on int8
2023-07-11 19:41:09 -07:00
PhaneeshB
3b8f7cc231 Add codegen support in UI + lint 2023-07-11 21:58:01 +05:30
PhaneeshB
6e8dbf72bd mlir/vmfb path fixes for vic pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
38e5b62d80 adapt UI to send model details to pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
1c7eecc981 add codegen support in vic pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
be417f0bf4 fix precision for fp16 2023-07-11 21:58:01 +05:30
AyaanShah2204
a517e217b0 Added support for building ZIP distributions (#1639)
* added support for zip files

* making linter happy

* Added temporary fix for NoneType padding

* Removed zip script

* Added shared imports file

* making linter happy
2023-07-09 06:45:36 -07:00
Ranvir Singh Virk
9fcae4f808 Metal testing (#1595)
* Fixing metal_platform and device selection

* fixing for metal platform

* fixed for black lint formating
2023-07-08 15:22:53 -07:00
Stefan Kapusniak
788d469c5b UI/Web Refix remaining gradio deprecation warning (#1638) 2023-07-08 13:48:36 -07:00
Stefan Kapusniak
8a59f7cc27 UI/Web add 'open folder' button to output gallery (#1634)
* Adds a button that opens the currently selected subdirectory using
the default OS file manager
* Improve output gallery handling of having images deleted out from
under it.
* Don't show VAE or LoRA lines in parameter info panel when their
value is 'None'
* Use a css class for small icon buttons on the output gallery
tab instead using the same id for multiple buttons
2023-07-08 12:44:59 -07:00
Stefan Kapusniak
1c2ec3c7a2 Some Fixes for Gradio 3.36.1 (#1637)
* Clear .style deprecation warnings.
* Re-remove download button from Nod logos.
* Add work around for `container=false` not doing what it did before on
dropdowns to the output gallery CSS
2023-07-08 11:20:34 -07:00
powderluv
af0f715e20 Unpin gradio 2023-07-08 09:41:14 -07:00
jinchen62
47ec7275e6 Fix brevitas quantize argument (#1633) 2023-07-07 11:30:31 -07:00
powderluv
3a24cff901 change binary names 2023-07-06 23:59:14 -07:00
powderluv
1f72907886 Fix the pyinstaller for chatbots (#1631) 2023-07-06 23:30:01 -07:00
Daniel Garvey
06c8aabd01 remove local-sync from webui (#1629) 2023-07-06 13:58:59 -07:00
Phaneesh Barwaria
55a12cc0c4 cpu name in device (#1628)
* show cpu name in devices

* change device order for chatbot
2023-07-06 12:00:09 -07:00
Ean Garvey
7dcbbde523 Xfail models for data tiling flag changes (#1624) 2023-07-06 06:57:17 -07:00
Abhishek Varma
1b62dc4529 [Vicuna] Revert the formatting for Brevitas op (#1626)
-- This commit reverts the formatting for Brevitas op.
-- It also excludes vicuna.py script from `black` formatter.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-07-06 06:56:17 -07:00
Daniel Garvey
c5a47887f4 Revert revert negative prompt change (#1625)
* revert default flag changes

* revert revert negative prompt change

* revert revert negative prompt change
2023-07-05 22:09:06 -07:00
Daniel Garvey
c72d0eaf87 revert default flag changes (#1622) 2023-07-05 15:43:26 -05:00
powderluv
c41f58042a Update compile_utils.py (#1617)
* Update compile_utils.py

* Update compile_utils.py

* Update compile_utils.py
2023-07-05 10:06:48 -07:00
xzuyn
043e5a5c7a fix a mistake I made, and more formatting changes, and add ++/Karras (#1619)
* fixed missing line break in `stablelm_ui.py` `start_message`
- also more formatting changes

* fix variable spelling mistake

* revert some formatting cause black wants it different

* one less line, still less than 79

* add ++, karras, and karras++ types of dpmsolver.

* black line length 79

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-05 09:00:16 -07:00
Abhishek Varma
a1b1ce935c int8 e2e for WebUI (#1620) 2023-07-05 07:08:36 -07:00
jinchen62
bc6fee1a0c Add int4/int8 vicuna (#1598) 2023-07-05 07:01:51 -07:00
xzuyn
91ab594744 minor fix, some changes, some additions, and cleaning up (#1618)
* - fix overflowing text (a janky fix)
- add DEISMultistep scheduler as an option
- set default scheduler to DEISMultistep
- set default CFG to 3.5
- set default steps to 16
- add `xzuyn/PhotoMerge` as a model option
- add 3 new example prompts (which work nicely with PhotoMerge)
- formatting

* Set DEISMultistep in the cpu_only list instead

* formatting

* formatting

* modify prompts

* resize window to 81% & 85% monitor resolution instead of (WxH / 1.0625).

* increase steps to 32 after some testing. somewhere in between 16 and 32 is best compromise on speed/quality for DEIS, so 32 steps to play it safe.

* black line length 79

* revert settings DEIS as default scheduler.

* add more schedulers & revert accidental DDIM change
- add DPMSolverSingleStep, KDPM2AncestralDiscrete, & HeunDiscrete.
- did not add `DPMSolverMultistepInverse` or `DDIMInverse` as they only output latent noise, there are a few I did not try adding yet.
- accidentally set `upscaler_ui.py` to EulerDiscrete by default last commit while reverting DEIS changes.
- add `xzuyn/PhotoMerge-inpainting` as an in or out painting model.

* black line length 79

* add help section stuff and some other changes
- list the rest of the schedulers in argument help section.
- replace mutable default arguments.
- increased default window height to 91% to remove any scrolling for the main txt2img page (tested on a 1920x1080 monitor). width is the same as its just enough to have the image output on the side instead of the bottom.
- cleanup
2023-07-04 18:51:23 -07:00
Eliasj42
4015793f84 changed method of compiling vicuna to remove first and second vicuna (#1611)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-03 12:12:43 -07:00
Ean Garvey
d63ce76dd8 Use sortable image filenames for SD outputs. (#1528) 2023-07-03 10:30:47 -07:00
Prashant Kumar
1c32915570 Add the shark compile downstream due to https://github.com/pytorch/pytorch/pull/104185#issuecomment-1615110613 (#1615) 2023-07-01 08:30:58 -07:00
Ean Garvey
6d286c0609 Enable tuning for rectangle sizes on rdna2. (#1608) 2023-06-30 22:28:24 -07:00
Stefan Kapusniak
7392b22731 UI/Web Reduce animation of default --progress_bars (#1613) 2023-06-30 21:12:10 -07:00
jinchen62
534de05791 Update precision check for vicuna (#1610) 2023-06-29 16:16:33 -05:00
Daniel Garvey
5779e8c039 int4/int8 vicuna download support (#1609)
* set task_topology_max_group to cpu_count

by default. Can be overriden with a flag of the same str

* add download for int4/int8 mlir
2023-06-29 13:35:51 -07:00
Abhishek Varma
d496053590 [SHARK] Add a compile API to use for quick testing of inference (#1606) 2023-06-28 08:40:28 -07:00
gpetters94
6274a813c9 Add unet512 support for the other StableDiffusion pipelines (#1602) 2023-06-27 12:28:57 -07:00
Gaurav Shukla
1d6a1f9f8a [vicuna] Add tokens streaming(step=3) (#1600)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-27 08:59:27 -07:00
Daniel Garvey
75672c0e28 set task_topology_max_group to cpu_count (#1594)
by default. Can be overriden with a flag of the same str
2023-06-26 14:54:06 -07:00
Prashant Kumar
74a7202173 Make the tensors contiguous. 2023-06-26 17:29:54 +05:30
Prashant Kumar
27a08735db Add the shark backend for torch.compile API. (#1596) 2023-06-26 03:53:32 -07:00
Stefan Kapusniak
eaa49cce17 UI/App - Allow text selection (#1593)
* When run in app mode on windows, allows selection of text from
non-input controls, which is the same behaviour as web mode.
2023-06-26 02:16:53 -07:00
powderluv
10657d6fb1 Disable upx 2023-06-25 07:28:52 -07:00
Stefan Kapusniak
e3ab844cd1 Fix output gallery for csv format inc. VAE & LoRA (#1591) 2023-06-24 06:20:53 -07:00
powderluv
5ce6001b41 Update stablelm_ui.py to default to fp16 2023-06-23 22:55:47 -07:00
powderluv
501d0ca52e Add sentencepiece to webui for pyinstaller 2023-06-23 22:52:06 -07:00
powderluv
b444528715 Pin torch-mlir for windows too 2023-06-23 19:19:28 -07:00
Ean Garvey
6e6c90f62b Pin torch-mlir and use local-task in OPT. (#1592) 2023-06-23 19:17:05 -07:00
AyaanShah2204
8cdb38496e Final REST API Fixes (#1590)
* fixed outpaint api and added tests

* fixed text2img api

* more elegant generator to subscriptable conversion

* final fixes
2023-06-23 16:46:47 -07:00
powderluv
726d73d6ba Revert "[vicuna] Add streaming of tokens (#1587)" (#1588)
This reverts commit 4d55e51d46.
2023-06-23 10:29:00 -07:00
Gaurav Shukla
4d55e51d46 [vicuna] Add streaming of tokens (#1587)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-23 08:20:46 -07:00
Prashant Kumar
6ef78ee7ba Add cpu compile time flags. (#1585) 2023-06-23 07:23:26 -07:00
jinchen62
4002da7161 Add int4/int8 options to chatbot webui (#1586) 2023-06-23 07:18:34 -07:00
powderluv
ecb5e8e5d8 Update txt2img_ui.py 2023-06-23 06:42:12 -07:00
PhaneeshB
28e0919321 Add AMD cpu device 2023-06-23 18:47:04 +05:30
Daniel Garvey
28f4d44a6b downloader was double downloading (#1580) 2023-06-22 18:30:27 -07:00
AyaanShah2204
97f7e79391 [Blender Integration] Fixed Inpainting REST API (#1577)
* fixed inpaint api

* added inpainting test

* fixed linter errors

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-22 16:08:26 -07:00
Nelson Sharpe
44a8f2f8db Include VAE & LoRA data into PNG metadata (#1573)
* include custom lora and vae data in png metadata

* include pycharm settings

* lint with black
2023-06-22 16:05:54 -07:00
Eliasj42
8822b9acd7 added ability to use config file to shard vicuna (#1565)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-06-22 17:40:35 -05:00
Daniel Garvey
0ca3b9fce3 fix some mmap and vicuna bugs (#1576) 2023-06-22 17:39:55 -05:00
Nithin Meganathan
045f2bb147 Add dispatch-level config file generator for manual annotation (#1566) 2023-06-22 15:11:41 -07:00
Prashant Kumar
a811b867b9 Add shark_eager mode.
-- Eager mode with step by step op compilation and execution.
2023-06-22 22:59:14 +05:30
Abhishek Varma
cdd505e2dd [SharkInference-SharkRuntime] Adds capability to mmap vmfbs
-- This commit is based on [VmModule.mmap() API](https://github.com/openxla/iree/pull/14124).
-- It thereby adds capability to mmap vmfbs in SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-22 20:43:40 +05:30
powderluv
1b0f39107c Move torch_mlir import to the top (#1574) 2023-06-21 22:31:35 -07:00
powderluv
b9b8955f74 exclude vulkan on macos 2023-06-21 22:22:27 -07:00
powderluv
6f7a85eee3 switch to metal backend for CI 2023-06-21 22:17:11 -07:00
Ranvir Singh Virk
18c8e9e51e Metal typo fix (#1572)
* fixing typos for metal changes

* black formating
2023-06-21 21:56:11 -07:00
Daniel Garvey
a202bb466a fp16 fixes for webui (#1571) 2023-06-21 20:24:02 -07:00
Ranvir Singh Virk
07c1e1d712 Adding metal_utils for iree_utils (#1561)
* Adding metal_utils for iree_utils

* Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)

-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)

* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.

* Fixing iree-metal-target-platform

* adding metal to txt2img pipeline

* Fixing Copyright date

* removing debug prints

* black lint formating

* fixing device dump

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <avarma094@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-21 19:09:03 -07:00
Ranvir Singh Virk
18daec78c8 Added check for python version (#1570)
* Added check for python version

* Update for PYTHON_VERSION_X_Y
2023-06-21 18:56:47 -07:00
Ean Garvey
1a8e2024d6 Exclude non-square sizes from use_tuned on rdna2 (#1568) 2023-06-21 11:36:55 -05:00
AyaanShah2204
d61b6641fb Rest API: Resolved Generator Object not Subscripatable error (#1556) 2023-06-20 19:27:41 -07:00
Phaneesh Barwaria
88cc2423cc Enable Vicuna fp16 cpu (#1562)
* fix second vic mlir gen

* fp16 mlir/vmfb download from shark_tank
2023-06-20 13:43:21 -05:00
Ean Garvey
ccf944c1bd Enable tuner for upscaler unet. (#1563) 2023-06-20 13:40:13 -05:00
Ean Garvey
0def74f520 [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)
* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.
2023-06-20 10:26:36 -07:00
Abhishek Varma
3fb72e192e Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)
-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-20 10:04:17 -07:00
Vivek Khandelwal
855435ee24 Fix for the user input for Falcon pipeline 2023-06-20 18:09:32 +05:30
Elias Joseph
6f9f868fc0 fixed a bug where designating device for vicuna didn't work 2023-06-20 17:09:32 +05:30
powderluv
fb865f1b99 Move to checkout@v3
This will break Windows again but we have to fix it up since the old node.js is now deprecated.
2023-06-19 18:44:36 -07:00
rprasad2
3e5c50f07b changes for tuning (#1542)
* Add tuning sizes for rdna3
2023-06-19 15:29:08 -05:00
powderluv
a544f30a8f Move mega to the shark examples (#1555) 2023-06-19 11:10:51 -07:00
Abhishek Varma
1fe56d460a [MEGABYTE] Add script to compile MEGABYTE through SHARK (#1553)
-- Usage: `python mega_test.py`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-19 11:00:35 -07:00
Vivek Khandelwal
fafd713141 Minor change to falcon pipeline 2023-06-19 22:36:32 +05:30
Vivek Khandelwal
015d0132c3 Modify falcon pipeline to add fp16 support (#1551) 2023-06-19 09:57:13 -07:00
powderluv
20ddd96ef7 unpin diffusers (#1550) 2023-06-18 13:45:55 -07:00
powderluv
ee33cfd2d1 Add PIL in main index.py (#1549)
* Add PIL in main index.py

This is to ensure pyinstaller picks it up

* Update index.py
2023-06-18 11:51:44 -07:00
Stefan Kapusniak
a3cba21d5b Fix load of unet512 vmfb fail on get of iree opts (#1546)
* Change retrieval of Iree options used when loading an existing
unet512 vmfb to look up the "unet" options rather than attempt to
find a non-existent set of options for "unet512"

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:42:20 -07:00
Stefan Kapusniak
a7b6ec4095 Fix unet512 always being used when --max_length=77 (#1547)
* Switches a few places in the SD pipeline where an assumption of
max_length=64 was being made, to using the actual max_length
as passed into the pipeline. This prevents unet512 always being
used and producing different images than previously when
--max_length=77
2023-06-18 06:41:25 -07:00
Ean Garvey
d80b087d95 Add PIL hidden imports to sd spec. (#1544)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:39:08 -07:00
Stefan Kapusniak
297a209608 Remove workarounds for gradio tempfile bugs (#1548) 2023-06-17 19:50:36 -07:00
gpetters94
b204113563 Add UNet512 (#1504)
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-06-17 03:46:25 -04:00
Chi_Liu
f60ab1f4fa Add Deberta to stablehlo in shark tank (#1545) 2023-06-16 13:24:44 -07:00
Surya Jasper
b203779462 Added Adreno target triples to vulkan_utils (#1543) 2023-06-15 16:42:59 -07:00
Stefan Kapusniak
38570a9bbb Some Fixes for update to gradio 3.34.0 (#1538)
* Fixes randomize seed buttons that stopped working.
* Update now deprecated method to set initial colums for output
gallery to the newer undeprecated one.
2023-06-15 01:10:36 -07:00
dependabot[bot]
a5c882f296 Bump gradio from 3.15.0 to 3.34.0 (#1518)
Bumps [gradio](https://github.com/gradio-app/gradio) from 3.15.0 to 3.34.0.
- [Release notes](https://github.com/gradio-app/gradio/releases)
- [Changelog](https://github.com/gradio-app/gradio/blob/main/CHANGELOG.md)
- [Commits](https://github.com/gradio-app/gradio/compare/v3.15.0...v3.34.0)

---
updated-dependencies:
- dependency-name: gradio
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-14 18:13:48 -07:00
Ean Garvey
eb6d11cfed Change mlir dialects for tf tests to stablehlo. (#1535)
* Change mlir dialects for tf tests to stablehlo

* Update shark_runner.py
2023-06-14 10:43:49 -07:00
Vivek Khandelwal
46184a81ac Add Falcon pipeline (#1534) 2023-06-14 09:39:16 -07:00
PhaneeshB
149165a2f0 add multi-device mutli-precision vmfb names 2023-06-14 22:08:24 +05:30
dan
bec82a665f mega vicuna merge
single endpoint in apps/language/models/scripts/vicuna.py
removed main functions from pipelines
replaced divergent utils compile with shark_importer
adds support for different precisions
2023-06-14 19:06:29 +05:30
Ean Garvey
9551490341 Remove deprecared --iree-mhlo-demote-164-to-132 flag usage. (#1533) 2023-06-13 22:40:47 -05:00
Ean Garvey
49b3ecdbca (pytest) don't run redundant tests in cpu suite (#1532) 2023-06-13 22:40:33 -05:00
Ean Garvey
f53e3594c3 OPT Refactor (#1516)
* Change script to 1.3b model and add pytorch comparison

* fix CLI command

* Match OPT transformers model updates + numerics against latest version

* Cleanup OPT sentence completion script.

* Fix formatting and add standalone validation scripts.

* Add minimal OPT wrapper and example with import_with_fx

* Rename OPT full model wrapper.

* Cleanup test scripts for OPT.
2023-06-13 22:40:07 -05:00
Ean Garvey
5562d1dfda Fix xfails for cpu pytest cases (#1527)
Adding cpu-sync and cpu-task device configs was allowing respective tests to bypass the xfail conditional for cpu pytests marked in tank/all_models.csv. This commit updates the conditional to xfail those cases for cpu-sync and cpu-task as well.
2023-06-13 17:01:51 -07:00
Stefan Kapusniak
c7b0c2961e UI/Web Improve output gallery temp file handling (#1531)
* On startup report that cleaning up of temp files is taking place, in
case it takes a long time.
* Have the output gallery tab delete any zero length temporary files
generated by gradio < 3.32.0 for its gallery control whenever it
needs to update that control with images. This prevents such
files multiplying out of control.
2023-06-13 16:25:37 -05:00
Ean Garvey
44273b0791 Fix conditional in transform_fx() (#1530) 2023-06-13 16:24:53 -05:00
Prashant Kumar
0a4c8fcb3e Minor changes in the fx transforms. 2023-06-13 21:23:35 +05:30
Stefan Kapusniak
2fec3c8169 re-indents add_upcast in shark importer (#1523)
* The two with blocks in add_upcast appear to be underindented making
SD 1.4 break on rdna3, I've pushed them out one more tab, and then
everything appears to work again.
2023-06-12 14:41:10 -05:00
Gaurav Shukla
5e7d5930dd [vicuna] Add device and precision propagation in vicuna (#1520)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-12 12:14:43 -05:00
Prashant Kumar
b6dbd20250 Modify the fx transforms. (#1521)
- The bounds are set properly.
- The upcasting and downcasting is done for vicuna.
2023-06-12 09:40:14 -07:00
Nithin Meganathan
34f1295349 Add a model config generator (#1511)
Model config generator takes a PyTorch model as input and generates a JSON file with model layers and other propperties that define sharding on a particular hardware.
2023-06-09 15:32:00 -07:00
Phaneesh Barwaria
1980d7b2c3 Cpu device map (#1515)
* update cpu iree device

* fix vmfb paths vic unsharded
2023-06-09 11:27:02 -05:00
powderluv
2cfacc5051 fix osx torch_mlir (#1513)
* fix osx torch_mlir

* Update index.py

* Update index.py
2023-06-09 00:57:26 -07:00
Phaneesh Barwaria
436f58ddc4 cli using generate and mem fixes (#1509) 2023-06-08 13:13:32 -05:00
Phaneesh Barwaria
6b29bd17c8 Enable compilation vicuna (#1507)
* add cli for unsharded vic

* enable mlir download and compile
2023-06-07 13:08:22 -07:00
Ean Garvey
2c3485ca3e Add standalone OPT sentence completion script. (#1506) 2023-06-07 10:58:03 -07:00
Daniel Garvey
f206ecc635 reenable compilation in vicuna pipeline, add flags (#1505)
* replace vicuna.py backend with pipeline

* add some memory management to fist vicuna compile

reenable compilation
2023-06-07 09:49:27 -07:00
Stefan Kapusniak
a187e05ae6 Prevent having no cuda devices breaking the UI (#1503)
Don't break the UI when the LLM tab only wants cuda devices but there
aren't any.
2023-06-06 11:41:16 -07:00
Gaurav Shukla
8c21960486 [vicuna] Set only cuda devices in vicuna UI for now
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 22:15:20 +05:30
Gaurav Shukla
be62fce676 [vicuna] Fix vicuna chatbot (#1499)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 09:23:32 -07:00
PhaneeshB
f23b778a6c remove old vicuna scripts 2023-06-06 21:35:58 +05:30
PhaneeshB
436edf900d add vic sharded pipeline 2023-06-06 21:35:58 +05:30
Gaurav Shukla
ed58c2553f [vicuna] Integrate vicuna in shark studio
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 20:57:48 +05:30
Stefan Kapusniak
f2ca58e844 Add .csv and .json param info to output gallery (#1495) 2023-06-06 07:08:34 -07:00
Ean Garvey
1dbcc736eb [SD] (RDNA2) Enable new tuning for sd1.4 (#1498) 2023-06-06 06:48:58 -07:00
Phaneesh Barwaria
a83808ddc5 Vicuna cuda on A100 40G (#1496)
* vic chat with memory management (precompiled vmfb)

* fix vmfb path and download
2023-06-06 15:10:33 +05:30
Ean Garvey
a07fe80530 Update OPT, ResNet example scripts. (#1492)
* Update API in OPT example.

* fix resnet50 script

* Add OPT1.3b test script.
2023-06-05 20:19:35 -07:00
Ean Garvey
d0ba3ef8fa disable use_tuned on SD1.4 for rdna2 (#1490)
this is a temporary measure while we retune SD1.4 for rdna2. The current config fails during iree-compile.
2023-06-05 19:46:16 -05:00
Stefan Kapusniak
8400529c2c Fix output gallery not using shark_tmp (#1493)
This fix the gallery component of the  output gallery dumping temporary
files into the standard folders rather than shark_tmp so those files never
got cleared out on restart and would build up.
2023-06-05 16:23:49 -05:00
powderluv
7eaee9c242 update SHARK to nodai SHARK 2023-06-05 00:44:49 -07:00
powderluv
8230eebce5 Switch to CPU torch builds for shark.whl 2023-06-05 00:36:03 -07:00
Ean Garvey
6296ea4be9 fix config handling for sd1.4 on rdna2 (#1489) 2023-06-05 00:02:30 -07:00
Ean Garvey
4151ec3a8f (pytest) tag efficientnet, mobilenet as xfails on vulkan (#1488) 2023-06-04 23:22:32 -07:00
powderluv
a2467e8d43 Enable SHARK whl packages 2023-06-04 23:21:22 -07:00
Ean Garvey
e677178bcc Replace RDNA2 SD lowering configs. (#1486) 2023-06-05 00:57:43 -05:00
Anush Elangovan
7ef1bea953 XFAIL some macos tests 2023-06-04 15:27:03 -07:00
Chi_Liu
ad89bb1413 Add distilgpt2 to stablehlo in shark tank (#1481) 2023-06-02 16:44:46 -05:00
Ean Garvey
218ed78c40 Change instances of input_type='mhlo' to 'auto' (#1482) 2023-06-02 16:43:47 -05:00
Stefan Kapusniak
6046f36ab6 UI/Web: Fix upscaler stop button (mostly) (#1479)
* UI/Web: Fix upscaler stop button

* Hook the cancel_sd function up to the Stop button.
* Adds checks for SD_STATE_CANCEL in the upscaler ui inference function.
* Set and check for SD_STATE_IDLE, SD_STATE_CANCEL in the upscaler
pipeline.

* UI/Web: lint fixes for upscaler stop button fix

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-01 22:26:55 -07:00
Foxlum
5915bf7de3 Add to and tweak vulkan configuration environments. (#1475)
* Update vulkan_target_env_utils.py

* Update vulkan_target_env_utils.py

Adjust target environment capabilities.

* Update vulkan_target_env_utils.py

black linted?
2023-06-01 22:25:20 -07:00
Phaneesh Barwaria
f0a4e59758 LLM Pipeline Wrapper (#1477)
* [LLM] Add LLM pipeline

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

* add base pipeline and stableLM

* StableLM on UI - full block

* add SLM default model name

* add vicuna with pipeline

* add one token gen api for vic

* Fix stableLM bugs

* debug vic memory

* lint fix

---------

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-31 10:17:20 -07:00
Stefan Kapusniak
1ddef26af5 Web/UI: Add an Output Gallery tab for SD (#1470)
* WebUI: Adds an Output Gallery tab

Adds an new Output Gallery tab to the ui/webui with these features:

* Subdirectory select dropdown listing subdirectories at any depth below
the <output_dir>/generated_imgs directory,
* Large, full height, gallery area displaying the images in the selected
subdirectory. Shows nod logo when no images are in the selected
subdirectory.
* Slider that changes the number of columns of images that the gallery
displays from between 1 to 16 columns (defaults to 4).
* Expandable parameter info panel showing any generation parameters
saved in the file of the selected image for PNGs, alternatively the
image's EXIF data for JPEGs
* Send to buttons for txt2img, img2img, inpaint, outpaint and upscaler.
* Auto update of gallery and gallery label (to show generation status),
when a new image is generated by any of the stable diffusion tabs, and
is outputted to the currently selected subdirectory.
* Command line option for enabling and disabling the output gallery
(defaults to enabled)
* Command line option for following symlinks when getting entries
for the subdirectory list (defaults to off, as Python os.walk doesn't
check for circular references if following symlinks)

* Reformat with black

Reformat changes with black and then adjust some places where black's
formatting then needed some rephrasing of the code to make things
clearer.

* Add back transformers and sd_cancel imports

Adds back the transformers import in index.py needed for .exe
generation. Add comment so it doesn't get mistakenly removed
next time.
Adds back sd_cancel import in upscaler.py that is currently unused
but should be being used for the 'Stop' button.
2023-05-30 13:47:48 -07:00
Chi_Liu
ba8eddb12f Add GPT3/OPT to Stablehlo in shark tank (#1468)
Co-authored-by: AmosLewis <Amos_Lewsi@foxmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-05-29 21:58:39 -07:00
yzhang93
47b346d428 Modify the lowering config format for SPIRVMatmulPromoteVectorize pipeline (#1471) 2023-05-29 21:53:48 -07:00
Ean Garvey
1b4f4f5f4d Fix download path for SD1.4 Unet. (#1469) 2023-05-26 11:59:51 -07:00
Elias Joseph
73cd7e8320 added full vicuna to vicuna.py 2023-05-26 22:06:40 +05:30
Ean Garvey
19c0ae3702 Cleanup SD pipeline utils (#1466) 2023-05-25 12:50:11 -05:00
Ean Garvey
54e57f7771 Revive SD downloads from shark_tank. (#1465) 2023-05-25 12:03:21 -05:00
PhaneeshB
6d64b8e273 vic and slm common generation base 2023-05-25 20:29:41 +05:30
PhaneeshB
a8ea0326f5 correct SLM saved vmfb naming 2023-05-25 20:29:41 +05:30
PhaneeshB
58e9194553 add Lists import 2023-05-25 20:29:41 +05:30
PhaneeshB
eb360e255d remove unused imports 2023-05-25 20:29:41 +05:30
PhaneeshB
a6f88d7f72 refactor mlir compile 2023-05-25 20:29:41 +05:30
Prashant Kumar
8e571d165f Enable cpu f16 dtype tracing for the vicuna model. (#1461) 2023-05-24 09:37:57 -07:00
Ean Garvey
3cddd01b10 Update OPT tokenizer and xfail a few more large tests on macos CI (#1459)
* Update opt_torch_test.py

* Update all_models.csv
2023-05-23 14:36:57 -07:00
Chi_Liu
64c2b2d96b Add gpt2 to stablehlo support in shark tank (#1447)
- Add torch decomposition support when generating shark tank
- Add gpt2 stablehlo
2023-05-22 10:45:51 -07:00
Phaneesh Barwaria
f5ce121988 SLM on Sharkstudio (#1454)
* localize import, fix file reading, device cpu

* extract out model args
2023-05-19 11:21:08 -07:00
Ean Garvey
991f144598 Add iree hidden imports to SD spec (#1456)
* Add iree hidden imports to SD spec

* Update shark_sd_cli.spec
2023-05-19 11:19:16 -07:00
PhaneeshB
09bea17e59 fix #2 SLM in SharkStudio 2023-05-18 00:56:22 +05:30
Daniel Garvey
aefcf80b48 swap to cpu an remove hardcoded paths (#1448)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-05-17 10:53:34 -07:00
PhaneeshB
512235892e fix SLM for SharkStudio 2023-05-17 22:34:30 +05:30
PhaneeshB
6602a2f5ba add continuous output for CLI 2023-05-17 18:33:46 +05:30
Boian Petkantchin
20114deea0 In MiniLM JAX example verify MLIR result against JAX 2023-05-16 09:54:07 -07:00
Boian Petkantchin
9acf519078 Add option to skip venv creation in setup script 2023-05-16 09:54:07 -07:00
Boian Petkantchin
bdf37b5311 If device/backend is unknown pass it to IREE verbatim 2023-05-16 09:54:07 -07:00
powderluv
8ee2ac89f8 Rename sharded_vicuna_fp32_web.py to vicuna_web.py 2023-05-16 09:41:35 -07:00
powderluv
60cb48be2e Rename sharded_vicuna_fp32.py to vicuna.py 2023-05-16 09:40:51 -07:00
powderluv
86a215b063 Delete sharded_vicunia.py 2023-05-16 09:37:39 -07:00
powderluv
d6e3a9a236 Delete standalone_vicuna.py 2023-05-16 09:37:26 -07:00
Chi_Liu
a0097a1ead Add mlir_type for torch_model_list.csv (#1428)
- Enable stablehlo/tosa mlir output for torch model
- Add BERT stablehlo support
2023-05-15 10:23:54 -07:00
Ean Garvey
a9bae00606 Fix vulkan device selection at compile time and adapt to IREE python changes. (#1407)
* Add support for vulkan device selection at compile time.

* Don't convert device ID to int and fix .exe imports
2023-05-12 23:31:50 -07:00
Daniel Garvey
4731c1a835 prevent loading tokenizer on import (#1432)
also adds sentencepiece dep for exe
moved vicuna imports to after an if statement
in general we should avoid importing files that load whole models as
global variables
2023-05-12 19:11:45 -07:00
Ean Garvey
4c07e47e8c Specify a few models for expected failure on CUDA CI. (#1430) 2023-05-12 17:03:37 -05:00
Gaurav Shukla
e0cc2871bb [SD] Yield 2 tokens at a time in vicuna
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 23:49:01 +05:30
Gaurav Shukla
649f39408b [SD] Fix vicuna response
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 18:06:21 +05:30
Gaurav Shukla
c142297d73 [SD] Fix gradio to 3.22.0 version
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com
2023-05-11 18:05:55 +05:30
Gaurav Shukla
9e07360b00 [SD] Standalone vicuna with web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 17:23:44 +05:30
Gaurav Shukla
7b74c86e42 [SD] Fix SAMPLE_INPUT_LEN import issue
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 15:41:43 +05:30
Eliasj42
fa833f8366 fixed spacing issue with chat-bot (#1417)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-10 16:07:50 -07:00
Gaurav Shukla
fcb059aa38 [SD] Integrate vicuna in the web (#1410) 2023-05-10 11:30:22 -07:00
PhaneeshB
517c670f82 vicuna chat cli 2023-05-10 22:55:06 +05:30
Eliasj42
59df14f18b added vicuna demo (#1408)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-09 21:18:20 -07:00
Ean Garvey
6c95ac0f37 Revert dialect registration in model annotator (#1406)
Matches https://github.com/nod-ai/SHARK-Runtime/pull/58
2023-05-09 11:50:19 -07:00
Daniel Garvey
7a4a51ae73 vulkan vic f16 (#1404)
Co-authored-by: dan <dan@nod-labs.com>
2023-05-08 16:46:53 -07:00
powderluv
d816cc015e Revert "added standalone vicuna script (#1399)" (#1402)
This reverts commit 0e4a8ca240.
2023-05-05 16:08:05 -07:00
Eliasj42
54ce3d48ca added standalone vicuna script (#1401)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-05 18:05:52 -05:00
Eliasj42
0e4a8ca240 added standalone vicuna script (#1399)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-05 15:46:05 -07:00
Daniel Garvey
6ca1298675 maximizes window size for webview launch (#1394) 2023-05-04 20:43:06 -07:00
jinchen62
bbef7a6464 Redesign model manager webui (#1391) 2023-05-04 20:41:29 -07:00
Ean Garvey
cdf2d61d53 Remove imports from iree.compiler.transforms from model annotator. (#1392) 2023-05-04 20:40:19 -07:00
Ean Garvey
6c14847d1f xfail some large tests on macOS builder and switch to hash updates. (#1341)
* Update test-models.yml

* Disable large tests on macOS builder
2023-05-04 19:47:03 -05:00
Gaurav Shukla
68ecdd2a73 [SD] Add LoRA as experimental tab
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-04 22:30:25 +05:30
Gaurav Shukla
3f4d444d18 [SD] Fix stable LM chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-04 22:30:25 +05:30
m68k-fr
e473d0375b [Web] Models folders cleanup (#1365) 2023-05-03 16:13:20 -05:00
Ean Garvey
e38d96850f Fix input image loading in img2img rest API (#1388) 2023-05-03 15:51:00 -05:00
Gaurav Shukla
fed63dfd4b [SD] Add stableLM chatbot (#1383)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-05-03 15:37:20 -05:00
Boian Petkantchin
eba4d06405 In MiniLM JAX example do not hardcode device (#1385)
* In MiniLM JAX example do not hardcode device

* In MiniLM JAX example don't use bytecode MLIR

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-05-03 10:34:42 -07:00
Boian Petkantchin
4cfba153d2 Add example JAX MiniLM inference (#1380)
* Do not hardcode the name of the VM module in get_iree_module

* Add example JAX MiniLM inference

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-05-02 15:03:54 -07:00
jinchen62
307c05f38d Convert original vae to diffusers (#1382) 2023-05-02 01:27:28 -07:00
jinchen62
696df349cb Fix curl issue (#1369) 2023-04-28 09:31:14 -07:00
jinchen62
cb54cb1348 Add model manager tab for SD webui (#1368) 2023-04-28 02:43:40 -07:00
Daniel Garvey
9bdb86637d add tkinter launch for webui (#1364) 2023-04-27 19:17:55 -05:00
jinchen62
fb6f26517f Fix webui note (#1367) 2023-04-27 16:14:43 -07:00
Chi_Liu
aa8ada9da9 Add support for torch to stablehlo and tosa in shark_importer (#1360) 2023-04-27 08:09:45 -07:00
powderluv
1db906a373 Revert "Add model manager tab for webui (#1359)" (#1362)
This reverts commit 9d1d1617d8.
2023-04-26 22:25:26 -07:00
jinchen62
9d1d1617d8 Add model manager tab for webui (#1359) 2023-04-26 13:38:18 -07:00
jinchen62
7112789cb8 Add support of using civitai model download url (#1357) 2023-04-25 23:39:52 -07:00
jinchen62
d6b8be2849 Add drawing canvas for img2img stencil scribble (#1355) 2023-04-25 14:41:01 -07:00
powderluv
822171277c Revert "[SD] Add FastChat as part of SD WebUI (#1349)" (#1350)
This reverts commit a5ae9d9f02.
2023-04-24 15:22:25 -07:00
Abhishek Varma
a5ae9d9f02 [SD] Add FastChat as part of SD WebUI (#1349)
-- This commit includes FastChat as part of SD WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-24 11:12:58 -07:00
powderluv
09e3f63d5b Fix pascal (#1346)
* Add fp32 for upscaler VAE

* Plumb Pascal vulkan support
2023-04-23 20:28:25 -07:00
powderluv
d60a5a9396 Add fp32 for upscaler VAE (#1345) 2023-04-23 15:27:55 -07:00
m68k-fr
90df0ee365 [Web] Gallery set to a 768px reference for high-end desktop users (#1344) 2023-04-23 11:48:06 -07:00
nirvedhmeshram
133c1bcadd add device to scheduler model names (#1338) 2023-04-22 20:13:56 -05:00
powderluv
caadbe14e9 Revert VAE to use im2col (#1339) 2023-04-22 15:23:41 -07:00
Ean Garvey
5f5823ccd9 Fix inference object imports for SD apps. (#1334) 2023-04-21 13:40:48 -05:00
Vivek Khandelwal
d2f7e03b7e Add StableLM model (#1331) 2023-04-21 09:51:02 -07:00
Gaurav Shukla
0b01bbe479 [SD] Add txt2img/upscaler/inpaint/outpaint Rest API (#1325)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-21 09:06:06 -07:00
yzhang93
25c5fc44ae Modify tuner.py to take vulkan target triple flag (#1328) 2023-04-20 14:31:32 -07:00
Daniel Garvey
7330729c92 enable sd pytest (#1322) 2023-04-19 22:11:30 -05:00
Ean Garvey
ce16cd5431 Create local shark_tank if needed for tuning configs. (#1321)
Now that --clear_all successfully deletes local shark_tank cache, we need to make sure it exists before trying to use it.
2023-04-19 11:44:21 -05:00
Ean Garvey
598dc5f79d Don't dump image data on img2img api call. (#1320) 2023-04-19 21:24:46 +05:30
Abhishek Varma
1f8e332cbe [SD] Fix img2img API bug for custom_vae argument (#1319)
-- https://github.com/nod-ai/SHARK/pull/1314 misses to add `custom_vae`
   parameter to img2img_if's invocation within img2img_api.
-- This commit adds a fix to the same.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-19 10:39:52 -05:00
Abhishek Varma
17b9632659 [SD] Adapted SHARK's v1 img2img API for SdPaint + updated Stencil model ID (#1318) 2023-04-19 06:29:36 -07:00
jinchen62
bda92a54ab Fix custom vae path (#1317) 2023-04-18 20:50:43 -07:00
jinchen62
747ed383b1 Add custom vae dropdown in webui (#1314) 2023-04-18 17:24:02 -07:00
Ean Garvey
1afe07c296 Disable winograd on VAE with rdna2 and fix unet tuning. (#1313)
* Disable winograd on VAE with rdna2 and fix unet tuning.

* Fix batch size 1 downloads and clear_all on windows.
2023-04-18 15:55:10 -05:00
jinchen62
b70919b38d Fix memory leak with ondemand (#1312)
support ondemand for outpainting and multi batch_count
2023-04-18 13:03:16 -05:00
m68k-fr
4e513d647f Update list of scheduler available for inferences (#1298) 2023-04-17 22:37:00 -05:00
jinchen62
94cd2a0fed Fix outpainting config (#1310) 2023-04-17 10:48:52 -07:00
Kyle Herndon
606029c01c Fix LoRA device format bug and allow LoRA to resume from a previous training 2023-04-17 13:19:46 +05:30
powderluv
1aa85222e9 Add AMD W7900 target triple (#1304)
This maps to RDNA3
2023-04-16 00:14:21 -07:00
m68k-fr
1b3f468c04 [Web] Style Fixes for Gradio V3.25.0 (#1300) 2023-04-13 18:40:42 -05:00
m68k-fr
35de7e27fa [Web] remove txt2img ui dependencies from png import metadata (#1275) 2023-04-12 07:32:47 -10:00
yzhang93
467f900759 Add auto-tuner to SD apps (#1291) 2023-04-12 09:21:17 -07:00
Ean Garvey
0bd9d582c7 Add documentation for using SHARK with AI-Render (#1296) 2023-04-12 03:09:34 -10:00
jinchen62
428cfe8dae Fix low vram mode issues (#1295)
- add ondemand back to img2img
- workaround memory leak for batch count
2023-04-11 17:59:09 -07:00
Ean Garvey
f17915bedc Fix batch size appending to model name. (#1294)
* Update shark_downloader.py

* Update shark_downloader.py
2023-04-11 15:34:25 -05:00
Gaurav Shukla
1b49b5149a [SD] Add Img2Img rest API
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-11 23:06:58 +05:30
jinchen62
3002793301 Unload clip on demand and workaround memory leak (#1283) 2023-04-10 16:59:03 -07:00
Phaneesh Barwaria
d25ef5529f Add fix for vae fp32 Upscalar (#1284)
- fixes size mismatch error for upscalar vae
2023-04-07 14:36:40 -05:00
Ean Garvey
308856a947 Touch unet if base cfg needed for SD pipeline init (#1281) 2023-04-05 03:02:29 -05:00
m68k-fr
151b4e142f [SD] Fix encoder error for model_max_length not beeing 77 (#1278)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-04-04 22:39:29 -07:00
Ean Garvey
e5a69a7c36 pin diffusers to e47459c (#1279) 2023-04-04 18:29:21 -07:00
m68k-fr
450b6cafc4 [SD] Add weight emphasis to prompts encoder (#1276) 2023-04-04 09:47:04 -07:00
Daniel Garvey
237d26baa2 update model db to reflect changes (#1277)
* remove 1/1 tqdm progress bar

* update model_db to reflect changes
2023-04-04 11:46:55 -05:00
Daniel Garvey
67d6ee1104 remove 1/1 tqdm progress bar (#1274) 2023-04-03 22:30:09 -05:00
Ean Garvey
98b069488e Add tank_version.json (#1272) 2023-04-03 18:36:23 -07:00
jinchen62
e0f227643a Fix webui circular import issue (#1271) 2023-04-03 16:00:10 -07:00
jinchen62
a0af3bb0cb xload and unload models (#1242) 2023-04-03 14:42:18 -07:00
powderluv
2cd61a5b96 strip source map (#1270) 2023-04-03 14:41:32 -07:00
Gaurav Shukla
f49d41a807 [SD] Add Stable diffusion text2image rest API (#1265)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-03 12:02:24 -07:00
Ean Garvey
2191fc8952 Separate pytest benchmark modes and fix model updates for SHARK downloader / pytest. (#1264)
* Only xfail windows models in CI

* downloader: make model updates more robust.

* Separate baseline and native benchmarks in pytest.

* Fix native benchmarks

* Fix torchvision model utils.
2023-04-03 08:24:21 -07:00
PhaneeshB
aea7796e60 add gradio client to spec 2023-04-03 18:57:19 +05:30
Abhishek Varma
a376619f1e [SD] Improve vmfb caching algo and retry mechanism (#1248)
-- This commit gets rid of the all-or-nothing vmfb caching mechanism
   and improves the retry mechanism by providing lower-level granularity
   for compiling each model units.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-03-31 09:38:14 -07:00
powderluv
02d52bb626 Add Intel ARC A770 target triple (#1263)
This just enables the plumbing. It generates black images.
2023-03-29 14:49:05 -07:00
Abhishek Varma
3b63645f79 [SD] Fix custom model path for WebUI (#1260)
-- This commit fixes custom model path for WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-29 09:48:11 -07:00
Ean Garvey
d6f740b998 allow pytest to retry getting model artifacts + disable autotuning for pytorch benchmarks (#1257)
* Adds a few xfails to enable macOS builder

* Convert string batch sizes to ints where needed.

* allow pytest to retry getting model artifacts

* Reduce attempts and add assert msg.
2023-03-28 23:38:45 -05:00
Daniel Garvey
594c6b8ea2 fix ckpt dir (#1258) 2023-03-28 14:31:01 -07:00
Ean Garvey
96b1560da5 Make batch size configurable via pytest and fix sharktank generation. (#1227)
* Fix sharktank generation and add batch_size pytest option for torch.

* Disable torch dynamo until py3.11 supported

* Compile torchmodel without dynamo if torch.compile fails

* Use release versions of TF/Keras for importer.

* Pin torchvision and remove debug prints.

* Remove duplicates from torch model list.

* Update generate_sharktank.py

* xfail a few models that fail sharktank generation/ numerics
2023-03-28 14:33:39 -05:00
Abhishek Varma
0ef6a0e234 [SD] Fix Stencil scribble crash by updating image resize (#1255)
-- This commit updates Stencil resize feature to cap the size of
   images within [128,768] as supported by the SD pipeline.
-- This solves the issue of scribble crashing on larger image.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-28 10:13:11 -07:00
Gaurav Shukla
641d535f44 [SD] Fix device path issue for cpu (#1256)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-28 10:09:49 -07:00
Daniel Garvey
5bb7846227 single entry point exe for all cli apps (#1158)
usage:
add --app="img2img" (or "inpaint" "outpaint" "txt2img")
2023-03-28 11:15:21 -05:00
yzhang93
8f84258fb8 Fix check for use_tuned conditions (#1252) 2023-03-27 11:21:25 -07:00
Ean Garvey
7619e76bbd Disable and xfail some models that fail validation/compilation. (#1251)
* Rollback T5 models for torch as the inputs give some issues that aren't trivial to resolve
* xfail efficientnet-b0 on torch+cuda -- see CUDA requesting shared memory size larger than allowed size openxla/iree#12771
2023-03-27 12:42:53 -05:00
Daniel Garvey
9267eadbfa disable openjourney gen for nightly (#1249) 2023-03-27 11:55:34 -05:00
Phaneesh Barwaria
431132b8ee Fix img2img mode switch (#1247)
* add updated scheduler value in global config

* clear scheduler global variable with others
2023-03-27 07:01:22 -07:00
cstueckrath
fb35e13e7a fix Python version detection bug (#1246)
* fix Python version detection bug

* Update setup_venv.ps1
2023-03-27 07:00:40 -07:00
yzhang93
17a67897d1 Add SD v2.1 768x768 tuned model (#1244)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-24 10:39:15 -07:00
Gaurav Shukla
da449b73aa [SD] Disable lora training tab for now (#1241)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-24 09:16:24 -07:00
Kyle Herndon
0b0526699a Fix incorrect device argument initialization for LoRA training by extracting the device type and number and formatting it for pytorch (#1237)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
2023-03-24 01:10:50 -07:00
Boian Petkantchin
4fac46f7bb In models testing fix paths to be relative to the script dir not cwd (#1128)
authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-03-22 15:26:52 -05:00
Daniel Garvey
49925950f1 fix false positives (#1193) 2023-03-22 15:25:39 -05:00
Thomas
807947c0c8 Remove deprecated cli option iree-hal-cuda-disable-loop-nounroll-wa (#1235) 2023-03-22 12:05:15 -05:00
Abhishek Varma
593428bda4 [SD] Fix for transformers/__init__.py issue in PyInstaller (#1233)
-- This commit fixes the transformers/__init__.py issue in PyInstaller.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-22 08:43:53 -07:00
Abhishek Varma
cede9b4fec [SD] Fix custom_vae as a required parameter in inpaint (#1232) 2023-03-22 04:30:17 -07:00
Prashant Kumar
c2360303f0 Add the int8 quantized model. 2023-03-22 16:28:13 +05:30
jinchen62
420366c1b8 Move schedulers to global obj (#1225) 2023-03-21 22:40:43 -07:00
Ean Garvey
d31bae488c Set iree-input-type to tm_tensor for SD (#1228) 2023-03-21 19:07:31 -07:00
Kyle Herndon
c23fcf3748 Fix incorrect device argument initialization for LoRA training (#1231)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-21 19:07:18 -07:00
jinchen62
7dbbb1726a Fix SD obj not defined if fail to get models from pretrained (#1222) 2023-03-21 07:55:17 -07:00
Abhishek Varma
8b8cc7fd33 [SD] Update LoRA inference to handle various checkpoints (#1215) 2023-03-21 06:52:20 -07:00
Ean Garvey
e3c96a2b9d Move sentencepiece to importer requirements. (#1218) 2023-03-21 00:39:57 -05:00
Ean Garvey
5e3f50647d Set --vulkan_large_heap_block_size default to 2gb. (#1220) 2023-03-20 21:07:09 -07:00
gpetters94
7899e1803a Add fix for attention slicing fp16 (#1217) 2023-03-20 19:11:29 -07:00
mariecwhite
d105246b9c Fix t5 models 2023-03-21 10:39:59 +11:00
mariecwhite
90c958bca2 Add T5-base and T5-large Torch and TF Models (#1116) 2023-03-20 17:32:50 -05:00
mariecwhite
f99903e023 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:22:05 +11:00
mariecwhite
c6f44ef1b3 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:14:45 +11:00
mariecwhite
8dcd4d5aeb Make batch size configurable 2023-03-20 18:03:17 -04:00
Phoenix Meadowlark
d319f4684e Add peak memory reporting for IREE, TF and PyTorch (#1216) 2023-03-20 15:40:49 -05:00
Ean Garvey
54d7b6d83e Generate model artifacts in pytests if they don't exist in the cloud. (#1121)
* Add gen_shark_files fn to shark_downloader for OTF artifact generation

* add generate_sharktank as a tank/ python module.

* Fix some paths in tank generation.
2023-03-20 12:13:19 -05:00
m68k-fr
4a622532e5 [Web] Stop images (#1212) 2023-03-19 14:37:30 -07:00
cstueckrath
650b2ada58 add pytorch_lightning to requirements (#1211)
* add pytorch_lightning to requirements

this will additionally add lightning-utilities and torchmetrics

* Update shark_sd.spec

* Update shark_sd_cli.spec
2023-03-19 12:29:54 -07:00
m68k-fr
f87f8949f3 [Web] CSS fix for gradio V3.22.1 (#1210) 2023-03-19 06:13:59 -07:00
m68k-fr
7dc9bf8148 [Web] Move "stop Batch" button to "Advanced Options" toggle (#1209) 2023-03-18 20:54:42 -07:00
Kyle Herndon
ba48ff8d25 Implement LoRA training and UI for training and UI for inference in img2img, inpaint, outpaint (#1200)
txt2img inference UI is already committed.

Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-17 12:54:56 -07:00
Gaurav Shukla
638840925c [SD] Add support for larger size upscaling (#1204)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-17 10:20:48 -07:00
m68k-fr
b661656c03 [Web] Fix custom model path for upscaler (#1199) 2023-03-16 15:57:23 -07:00
Gaurav Shukla
0225434389 [SD] Add sendTo Upscaler
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
Gaurav Shukla
7ffe20b1c2 [SD] Release memory used by upscaler when not in use
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
Gaurav Shukla
d8f0c4655d [SD] Add Upscaler web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
Gaurav Shukla
7e8d3ec0df [SD] Add upscalar pipeline
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
jinchen62
9c08eec565 Clear memory cache when switching model and mode (#1194) 2023-03-15 22:18:26 -07:00
m68k-fr
2d2c523ac5 [Web] Upgrade Gradio to v3.21.0 (#1188)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-15 10:14:49 -07:00
Abhishek Varma
f17b3128c0 [SD] Add LoRA inference to SD pipeline (#1189)
-- This commit adds LoRA inference to SD pipeline.
-- It also modifies txt2img to incorporate the new feature.
   img2img, inpaint, outpaint, etc using Unet can also be extended in a
   similar way.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-15 10:13:45 -07:00
Abhishek Varma
7c7e630099 [SD] Add fix for using latest diffusers + add scribble variant to Stencil (#1191)
* [SD] Add Scribble variant in Stencil

-- This commit adds scribble variant in Stencil.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Use latest diffusers

-- This commit points back to the latest diffusers and updates the
   processing script to tackle the Pix2Pix import issue.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-15 10:13:20 -07:00
m68k-fr
2dd1491ec1 [Web] Add clear queue button (#1192) 2023-03-15 10:12:59 -07:00
Daniel Garvey
236357fb61 add missing import for shark_sd.spec (#1190)
L
2023-03-15 09:23:29 -05:00
Phoenix Meadowlark
7bc38719de Add benchmark artifacts to .gitignore (#1186) 2023-03-14 15:19:06 -07:00
Daniel Garvey
bdbe992769 Add IREE_SAVE_TEMPS for import_debug command (#1184)
based on hf_model_id. Works on windows
2023-03-14 11:40:23 -07:00
Abhishek Varma
e6b925e012 [SD] Add Openpose to Stencil + image size issue fix (#1181)
-- This commit adds openpose model variant to stencil.
-- Fixes image size issue.
-- Also includes fix for the .exe bug introduced by https://github.com/nod-ai/SHARK/pull/1175

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-14 10:30:52 -07:00
cstueckrath
771120b76c workaround Gradio issue (#1183)
https://discord.com/channels/973663919757492264/975522729564446740/1085109774758191164
2023-03-14 01:27:24 -07:00
Boian Petkantchin
a8ce7680db Add flag to augment the device allocator (#1182)
Example:
$ python my_app.py --device_allocator caching debug
This will wrap the device allocator with first caching allocator then
debug allocator.

$ python my_app.py --device_allocator caching
Only wrap with caching allocator.

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-03-13 15:49:26 -07:00
Phaneesh Barwaria
b6dcf2401b Stencil perf improvement (#1179)
* remove conditioning strength multiplier

* mod diffusers lib to v0.14.0
2023-03-13 14:37:38 -07:00
Daniel Garvey
62b5a9fd49 generate sharktank for apps dir (#966)
* merge confix resolution

* add support to other scripts

---------

Co-authored-by: dan <dan@nod-labs.com>
2023-03-13 10:54:15 -07:00
m68k-fr
2f133e9d5c Fix png metadata (#1178) 2023-03-12 22:43:39 -07:00
powderluv
f898a1d332 Update README.md 2023-03-12 16:54:42 -07:00
m68k-fr
b94266d2b9 [Web] Randomize seed to -1 (#1176) 2023-03-12 12:42:31 -07:00
m68k-fr
1b08242aaa [Web] Improve dropdowns ux (#1175) 2023-03-12 12:41:51 -07:00
Abhishek Varma
691030fbab [SD] Improve Stencil feature to handle general image sizes
-- Currently stencil feature works with 512x512 images only.
-- This commit relaxes this constraint and adds support for various
   image sizes.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-11 21:48:31 +05:30
m68k-fr
16ad7d57a3 [WebUi] txt2img_ui: Import png metadata (#1147) 2023-03-10 16:26:34 -08:00
Anush Elangovan
c561ebf43c Drop the torch-mlir pin
Seems to work now with top of master
2023-03-10 15:39:04 -08:00
Prashant Kumar
97fdff7f19 Add instructions how to run the LLaMA model. (#1168)
* Add instructions how to run the LLaMA model.

* Update README.md
2023-03-10 12:36:37 -08:00
Anush Elangovan
ce6d82eab2 Fix bloom lint 2023-03-10 11:53:08 -08:00
Abhishek Varma
b8f4b18951 [SD] Use dynamic stencil HF repo id
-- This commit removes the hardcoded HF ID for Stencil and instead
   utilizes a dynamic instantiation of HF model.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-10 23:31:45 +05:30
Eliasj42
b23d3aa584 added more memory efficient method to run large bloom models with sharded blooms (#1165)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-03-10 09:32:56 -08:00
Vivek Khandelwal
495670d9b6 Fix SD fine tuning script device arg usage 2023-03-10 18:37:53 +05:30
Boian Petkantchin
815e23a0b8 Update iree-compile flags --iree-llvm-xxx -> --iree-llvmcpu-xxx (#1164) 2023-03-09 11:31:50 -08:00
Boian Petkantchin
783538fe11 Move linting opts from github workflow to config files
This helps development where you can be sure that running locally

black .
flake8 .

will do the same as in the github job.
2023-03-09 10:46:30 -08:00
Boian Petkantchin
996c645f6a In SD don't include device path in vmfb filename
Include only the driver name instead.
2023-03-09 10:45:32 -08:00
m68k-fr
1f7d249a62 Use utf-8 format for imgs_details.csv 2023-03-09 16:15:58 +05:30
jinchen62
7f6c9a2dc2 Add an inpainting option for only masked area (#1154) 2023-03-07 09:46:05 -08:00
Eliasj42
93891984f3 made sharded bloom example more user friendly (#1153)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-03-06 10:23:48 -08:00
Vivek Khandelwal
cc0ef54e0e Fix Stable diffusion fine tuning script 2023-03-06 17:52:16 +05:30
Daniel Garvey
812152485d temporarily xfail tiny convnext macos (#1142) 2023-03-03 13:30:56 -06:00
Vivek Khandelwal
0816fb403a Add Stable diffusion fine tuning script
This commit adds the sd fine tuning script which runs through the
torchdynamo path.
2023-03-03 21:59:00 +05:30
Gaurav Shukla
4f171772be [SD] Fix SD web flags
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-03 21:55:40 +05:30
mariecwhite
a52331d4aa Install IREE pre-releases (#1139) 2023-03-02 23:17:56 -06:00
yzhang93
ad821a1fc8 Use old torch-mlir package to avoid crash on rdna2 (#1137) 2023-03-02 18:16:58 -08:00
Ean Garvey
116b128802 Use nightly shark_tank for test-models (#1133)
* Use nightly shark_tank for test-models

* Update all_models.csv
2023-03-02 12:33:36 -06:00
Gaurav Shukla
b118f183d1 [SD] Fix few things in sendTo feature (#1132) 2023-03-02 09:11:55 -08:00
Gaurav Shukla
911dff16f1 [SD] Add sendTo feature in stable diffusion (#1131)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-02 08:42:38 -08:00
Abhishek Varma
de59a66ae4 [SD] Update diffusers to point to the fix for Stencil + add opencv-python (#1130) 2023-03-02 08:19:29 -08:00
Daniel Garvey
23f1468cc6 disable most models on windows pytest (#1125) 2023-03-02 01:37:50 -06:00
jinchen62
080350d311 Make loading custom inpainting models general (#1126) 2023-03-01 22:14:04 -08:00
Phaneesh Barwaria
7f3f92b9d5 remove extra return arg (#1123)
* remove extra return arg

txt2img expects only 3 mlirs

* add venv reqs for stencils
2023-03-01 11:45:24 -08:00
Abhishek Varma
be3cdec290 [SD] Add Stencil feature to SD pipeline (#1111)
* [WIP] Add ControlNet to SD pipeline

-- This commit adds ControlNet to SD pipeline.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Add ControlNet to img2img + fix bug for img2img scheduler

-- This commit adds ControlNet execution to img2img.
-- It restructures the addition of ControlNet variants.
-- It also fixes scheduler selecting bug for img2img pipeline.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* add shark models for stencilSD

* Add Stencil controlled SD in img2img pipeline (#1106)

* use shark stencil modules

* adjust diffusers change

* modify to use pipeline

* remove control from unet

* pump stencils through unet

* complete integration in img2img

* fix lint and comments

* [SD] Add ControlNet pipeline + integrate with WebUI + add compiled flow execution

-- This commit creates a dedicated SD pipeline for ControlNet.
-- Integrates it with img2img WebUI.
-- Integrates the compiled execution flow for ControlNet.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Stencil execution

* Remove integration setup

* [SD] Fix args.use_stencil overriding bug + vmfb caching issue

-- This commit fixes args.use_stencil overriding issue which caused
   img2img pipeline to pick wrong set of modules.
-- It also fixes vmfb caching issue to speed up the loading time
   and pick right set of modules based on a mask.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
2023-03-01 10:44:40 -08:00
m68k-fr
f09574538c [WebUi] Remove unsupported full_width parameter, Reactivate gallery nav while multiple images are generated 2023-03-01 23:17:12 +05:30
Daniel Garvey
b1113ab551 disable benchmark on windows for pytest (#1100) 2023-02-28 18:10:29 -06:00
powderluv
ef756389e3 Revert "add cv2 and nod diffusers (#1112)" (#1114)
This reverts commit cb17d017df.
2023-02-28 14:31:40 -08:00
Phaneesh Barwaria
cb17d017df add cv2 and nod diffusers (#1112) 2023-03-01 01:33:43 +05:30
Gaurav Shukla
798f231792 [SD] Update metadata info and canvas size (#1109)
* [SD] Save missing metadata in case of img2img and outpaint

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

* [SD] Update the canvas size for inpaint/outpaint

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

* [SD] Update output gallery on each inference

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

---------

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-28 11:25:30 -08:00
m68k-fr
7136890da3 [Fix] Unsupported width and height argument error 2023-02-28 23:32:58 +05:30
mariecwhite
d567192fd3 Fix call to Torch Inductor 2023-02-28 00:35:57 -08:00
jinchen62
dcc4025c78 Fix loading custom inpainting models (#1103) 2023-02-27 17:06:09 -08:00
yzhang93
c6c8ec36a1 Enable tuned models for inpainting (#1102) 2023-02-27 16:46:57 -08:00
Quinn Dawkins
1344c0659a Add doc on profiling with Shark (#1101)
* Add doc on profiling with Shark

* Rename doc
2023-02-27 11:31:27 -08:00
powderluv
973f6d20f4 Try pre-pix2pix 2023-02-25 00:09:05 -08:00
powderluv
8b5c9c51e7 Revert "Update diffusers (#1094)" (#1096)
This reverts commit 0064cc2a6e.
2023-02-24 19:27:56 -08:00
jinchen62
bae208bcc4 Fix outpainting params (#1089) 2023-02-24 14:41:32 -08:00
Daniel Garvey
b6c14ad468 Make sd tests output performance metrics into csv (#1085)
* make some paths windows friendly (#1066)

* add csv output to builder script

and reduce number of models tested
2023-02-24 16:27:52 -06:00
powderluv
0064cc2a6e Update diffusers (#1094) 2023-02-24 14:09:19 -08:00
Gaurav Shukla
0a0567e944 [SD] Avoid unnecessary temp file creations (#1092)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-24 10:53:34 -08:00
gpetters94
694b1d43a8 Add attention slicing support (#1087) 2023-02-24 02:43:02 -08:00
Ean Garvey
e7eb116bd2 use tf-nightly for importer (#1077) 2023-02-23 23:14:48 -06:00
yzhang93
596499a08c Disable tuned configs on all inpainting models (#1086) 2023-02-23 13:15:22 -08:00
naveen raj
2a2e460df2 Add DEISMultistep scheduler #1076 (#1084)
* Add DEISMultistep scheduler #1076

* line lenght lint fix
2023-02-23 10:15:05 -08:00
jinchen62
a9039b35ed Add outpainting web UI (#1083) 2023-02-23 01:02:25 -08:00
jinchen62
a01154a507 Add SD outpainting (#1072)
python apps/stable_diffusion/scripts/outpaint.py --prompt="Face of a yellow cat, high resolution, sitting on a park bench" --img_path=test_imgs/overture-creations-5sI6fQgYIuo.png --import_mlir --hf_model_id="stabilityai/stable-diffusion-2-inpainting" --pixels=128 --mask_blur=8 --left --right --top --bottom --steps=20
2023-02-22 23:16:05 -08:00
powderluv
1d9204282d Update README.md 2023-02-22 23:12:41 -08:00
Eliasj42
5ff40a0d2d added an example to run sharded bloom (#1079)
added ability to compile sharded mlir files from hugingface models

Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-02-22 22:48:58 -08:00
jinchen62
fab6d2e4e0 Resize input image and mask for SD inpainting (#1082) 2023-02-22 22:46:59 -08:00
powderluv
abab59c25f Update nightly.yml 2023-02-22 18:44:43 -08:00
powderluv
c25840b585 Update nightly.yml 2023-02-22 18:34:37 -08:00
powderluv
1b3f9125bb Update nightly.yml 2023-02-22 18:23:44 -08:00
powderluv
b5d9f5ba49 Update nightly.yml 2023-02-22 18:20:31 -08:00
powderluv
1c22aa9c8f Resolve __init__.py issues (#1080)
Also drop torchvision. The test passed and didn't fail but
we can't be sure it fixes the __init__.py issue yet.
2023-02-22 18:17:00 -08:00
Daniel Garvey
e1d7fb879c make some paths windows friendly (#1066) 2023-02-22 14:44:55 -06:00
powderluv
e912c42bf0 update the openxla links 2023-02-22 12:10:23 -08:00
powderluv
e6841acf36 Publish nightlies as pre-releases
So stable versions can be marked on the Releases page
2023-02-22 12:05:28 -08:00
Gaurav Shukla
bc4459b6f4 [SD] Add inpainting web UI (#1069)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-22 11:01:18 -08:00
cstueckrath
9b544491e0 Update setup_venv.ps1 (#1073)
* Update setup_venv.ps1

fix a bug that occurs, when Python is installed but no py.exe is available

* Update setup_venv.ps1
2023-02-22 07:52:59 -08:00
m68k-fr
9c5415b598 [WebUi] css fix for Gradio v3.19.0 (#1059)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-02-21 23:50:54 -08:00
powderluv
040dbc317f unpin diffuser to latest (#1071)
Currently 0.13.x
2023-02-21 23:47:19 -08:00
powderluv
65775046d8 update IREE pip links 2023-02-21 19:31:23 -08:00
Daniel Garvey
b18bc36127 force creation of workdir (#1070) 2023-02-21 18:10:36 -08:00
cstueckrath
f01c526efd Update setup_venv.ps1 (#1064) 2023-02-21 14:13:04 -05:00
Gaurav Shukla
16168ab6b3 [SD] Update need_vae_encode correctly
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-21 20:26:06 +05:30
Gaurav Shukla
4233218629 [SD] Reset args.img_path to None in txt2img to avoid vae_encode
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-21 18:46:15 +05:30
RaINi_
b63fb36dc0 Use path.join for the winograd config directory (#1065) 2023-02-20 22:04:25 -06:00
Daniel Garvey
4e92304b89 remove annoying accelerate warning (#1056)
disables usage of low_cpu_mem_usage=True in from_pretrained() calls.
Can be re-enabled by using flag --low_cpu_mem_usage
defaults to False to avoid spam as we don't include accelerate in our
requirements.txt
2023-02-20 14:46:26 -06:00
Ean Garvey
2ae047f1a8 Update importer/benchmark setup for python3.11 (#1043) 2023-02-20 11:29:00 -06:00
Ean Garvey
6d2a485264 Add --benchmark_dispatches option to pytest. (#800)
* Add --benchmark_dispatches option to pytest.

* Update README.md and fix filepath for dispatch benchmarks
2023-02-19 12:16:18 -06:00
Daniel Garvey
4f045db024 disable anythingv3 until issue is resolved (#1053) 2023-02-18 23:47:21 -05:00
yzhang93
5b33597b6d Enable v1.5 to use tuned configs (#1049) 2023-02-18 16:54:26 -05:00
m68k-fr
962470f610 [WebUi] Minor interface cleanup and Ui cosmetics 2023-02-17 22:00:47 +05:30
cstueckrath
ba8c116380 add KDPM2Discrete and a force flag for setup_venv (#1044)
* add KDPM2Discrete and a force flag for setup_venv

* add KDPM2Discrete and a force flag for setup_venv
also made sure that Python 3.11 is used for the venv as 3.10
doesn't work anymore

* add KDPM2Discrete and a force flag for setup_venv
also made sure that Python 3.11 is used for the venv as 3.10
doesn't work anymore
2023-02-17 07:19:56 -05:00
jinchen62
ad7330eae4 Add inpainting test (#1011) 2023-02-16 22:17:10 -06:00
yzhang93
cf126e4839 Use tuned configs on custom models with ckpt_loc (#1038) 2023-02-16 17:06:21 -08:00
powderluv
c96d25c3e2 Delete stable_diffusion_amd.md
All instructions are common now and on the main page.
2023-02-16 14:57:32 -08:00
powderluv
006aa0dae2 Update README.md 2023-02-16 14:54:00 -08:00
Daniel Garvey
5b204bee86 temporarily xfail microsoft resnet50 (#1037)
Co-authored-by: dan <dan@nod-labs.com>
2023-02-16 16:14:51 -06:00
Phaneesh Barwaria
d98b2afbe9 img2img denoise strength (#1040) 2023-02-16 13:40:20 -08:00
Daniel Garvey
681332ef32 fix tests after default flag changes (#1009)
* fix tests after default flag changes

also adds support for import-mlir

* Update setup_venv.ps1

---------
2023-02-16 12:57:50 -06:00
mariecwhite
c3a4fdcbfc Add bert-large-uncased TF model 2023-02-15 21:42:44 -08:00
mariecwhite
aac5de5b02 Add bert-large-uncased Torch model 2023-02-15 21:25:32 -08:00
powderluv
13a255afad Update nightly.yml 2023-02-15 17:11:38 -08:00
powderluv
3bffda52f9 Pin to latest diffusers (#1031) 2023-02-15 14:23:10 -08:00
Daniel Garvey
d4e62ce557 add an import-mlir fallback in case of failure (#1030)
may not cover all cases. will observet

Co-authored-by: dan <dan@nod-labs.com>
2023-02-15 16:15:23 -06:00
yzhang93
9738483b18 [SD] Map v2_1 to v2_1_base until fix (#1029) 2023-02-15 13:44:41 -08:00
Abhishek Varma
143492fe94 [SD] Add support for standalone Vae checkpoints (#1020)
-- This commit adds support for standalone Vae checkpoints.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-15 12:17:32 -08:00
Gaurav Shukla
ecc5c662c4 [SD] Save output images to different loc every day (#1027)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 12:16:36 -08:00
yzhang93
d973ba191d Add conditions to force use --import_mlir (#1028) 2023-02-15 10:37:09 -08:00
Gaurav Shukla
0198b183a2 [SD] Img2Img works for limited schedulers.
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 23:06:28 +05:30
Gaurav Shukla
0d44a3527b [SD][web] Add strength UI for img2img
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 22:47:41 +05:30
Gaurav Shukla
2147b6a397 [SD] Move some common code to utility
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 22:47:41 +05:30
Gaurav Shukla
6b5b4ba27b [SD] Add batch count in Image2Image
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 22:47:41 +05:30
Gaurav Shukla
67005bf57c [SD] Update iree-vulkan-target-triple after device switch
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 22:47:41 +05:30
PhaneeshB
0430c741c6 add strength param 2023-02-15 20:59:03 +05:30
powderluv
1ce02e365d Update README.md 2023-02-15 01:22:28 -08:00
m68k-fr
eae862adc2 Fix lint and path for gradio_tmp_imgs_folder 2023-02-15 14:27:29 +05:30
drumicube
dffa89524a Save gradio tmp images to shark_tmp folder and clean it at launch 2023-02-15 14:27:29 +05:30
yzhang93
2af1102441 [SD] Merge configs of different max lengthes from the same variant to one config file (#1019) 2023-02-15 00:25:29 -08:00
powderluv
c4b472842a Update stable_diffusion_amd.md 2023-02-14 19:02:20 -08:00
powderluv
750a7d806f update docs to 3.11 2023-02-14 17:12:09 -08:00
powderluv
bc7333f1e5 Remove forcing LLPC setting (#1018)
also fix logo paths
2023-02-14 17:09:03 -08:00
powderluv
55ae50f991 Update inpaint.py 2023-02-14 14:12:05 -08:00
powderluv
a590c331ef Update img2img.py 2023-02-14 14:11:50 -08:00
powderluv
8c241b06cb Update txt2img.py 2023-02-14 14:11:36 -08:00
powderluv
9c072c8068 Update index.py 2023-02-14 14:11:20 -08:00
powderluv
ebd8b5122a Update stable_diffusion_amd.md 2023-02-14 14:09:34 -08:00
powderluv
055e484a40 Update README.md 2023-02-14 14:06:46 -08:00
powderluv
912c4a1d12 Update shark_sd.spec 2023-02-14 13:21:29 -08:00
Abhishek Varma
c203b65bf1 Fix __file__ AttributeError + Remove --enable_stack_trace (#1015) 2023-02-14 07:55:02 -08:00
powderluv
307f0334ee Drop im2col for VAE since it crashes the driver (#1010)
This is for untuned models.
2023-02-13 19:02:51 -05:00
yzhang93
5167df08b9 [SD] Fix cuda OTF annotation (#1008) 2023-02-13 12:32:50 -08:00
Gaurav Shukla
dd2e482214 [SD] Fix multiple call to device check (#1007)
- Also makes the dark theme default.
- Fix custom_vae parameter in img2img.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-13 11:57:52 -08:00
Eliasj42
87fd13d8eb added an example to run sharded bloom (#1003)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-02-13 10:37:47 -08:00
yzhang93
dd423bc6de [SD] Using --compile-to to dump mlir for OTF annotation (#1004)
* [SD] Using --compile-to to dumpmlir for preprocessing

* Use python api for dumping process
2023-02-13 09:17:59 -08:00
powderluv
899cb9cc1f Temporarily disable signing of exe 2023-02-12 20:37:42 -08:00
drumicube
0464c7e558 Add support for command arguments to the WebUi (#1000)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-02-11 19:20:21 -08:00
powderluv
f64e1fb926 Fix dark theme again for exe builds (#1001) 2023-02-11 19:08:17 -08:00
powderluv
ef7d31293d Update tests to 3.11 2023-02-11 15:38:27 -08:00
powderluv
6d54eb68dc update to support 3.11 2023-02-11 15:23:18 -08:00
powderluv
30eb10c990 Update to 3.11 2023-02-11 03:47:14 -08:00
Abhishek Varma
591bbcd058 [SD] Fix vmfb locating bug
-- This commit fixes a bug in vmfb caching due to vae_encoder and also
   involves a minor NFC change in the code.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-10 23:33:47 +05:30
Abhishek Varma
99aa77d036 [SD] Add a common way to name vmfbs including custom_vae
-- This commit adds a common way to name vmfbs and adds to it `custom_vae`
   support as well.
-- This was required to make a common place to change vmfbs name
   without breaking any feature support AND also tackle the caching
   of vmfbs gracefully.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-10 23:33:47 +05:30
Abhishek Varma
9c13f1e635 Add custom vae support using --custom_vae flag
-- This commit adds custom vae support to SD wherein the user can
   point to a model's checkpoint file whose Vae needs to be plugged
   into the main model.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-10 23:33:47 +05:30
Gaurav Shukla
24af983cfb [SD] Fix input image type
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-10 23:27:52 +05:30
Gaurav Shukla
67842a7525 [SD] Fix parameters in img2img
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-10 22:03:33 +05:30
PhaneeshB
3159a6f3e1 add support for img1img 2023-02-10 21:29:02 +05:30
Gaurav Shukla
b2f3c96835 [SD][web] Add Img2Img UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-10 21:27:31 +05:30
jinchen62
6582475955 Add SD inpainting
python apps/stable_diffusion/scripts/inpaint.py --prompt="prompt" --img_path=path/to/img --mask_path=path/to/mask --import_mlir --max_length=77 --hf_model_id="stabilityai/stable-diffusion-2-inpainting"
2023-02-10 15:33:20 +05:30
Anush Elangovan
41ee65b377 Revert "Enable --device_allocator=caching"
This reverts commit 83fe477066.
2023-02-09 23:00:06 -08:00
Anush Elangovan
83fe477066 Enable --device_allocator=caching 2023-02-09 22:58:46 -08:00
yzhang93
4ca84ee4ee Revert "Delete unnecessary arg setting (#978)" (#985)
This reverts commit 83c69ecd49.
2023-02-09 16:44:26 -08:00
Ean Garvey
c28cc4c919 Fix local_tank_cache handling in shark_downloader. (#981) 2023-02-09 14:52:03 -06:00
yzhang93
e9864cb3f7 Modify the annotation OTF to return bytecode module (#980) 2023-02-08 14:29:43 -08:00
yzhang93
83c69ecd49 Delete unnecessary arg setting (#978) 2023-02-08 10:30:18 -08:00
Prashant Kumar
3595b4aaff Incorporate latest changes in the shark_dynamo backend. 2023-02-08 20:37:30 +05:30
Abhishek Varma
3a9cfe113a Fix SD restart error in exe file (#975)
-- This commit fixes SD restart error in exe file by creating
   variants.json in CWD instead of a relative path.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-08 06:14:08 -08:00
yzhang93
c9966127da Fix iree flags to be able to run on rdna2 (#972) 2023-02-07 16:39:32 -08:00
Ean Garvey
51300d33a7 Remove non-SD args from generate_sharktank.py (#970) 2023-02-07 13:29:55 -06:00
Gaurav Shukla
5af124c5a5 [SD] Add batch count in stable diffusion
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-07 23:26:46 +05:30
Abhishek Varma
eeb20b531a Fix restart SD session error + override args.use_tuned temporarily
-- This commit fixes the session restart error for SD.
-- It also overrides `args.use_tuned` for `import_mlir`, and sets
   `use_tuned` as `False`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-07 19:50:48 +05:30
cstueckrath
9dca842c22 Update .gitignore to exclude models (#967)
the models folder will be stashed along with other changes and most likely kill git doing so.
2023-02-07 01:48:36 -08:00
Ean Garvey
1eb9436836 Fix generate_sharktank args. 2023-02-07 14:06:07 +05:30
Ean Garvey
9604d9ce81 make --update_tank update only if hash mismatch 2023-02-07 14:06:07 +05:30
Ean Garvey
481d0553d8 Remove unnecessary repro_dir / shark_tmp usage 2023-02-07 14:06:07 +05:30
powderluv
60035cd63a Add css in exe (#963)
exe should now default to dark theme too
2023-02-06 15:26:08 -08:00
drumicube
d35f992ace Bring back the --runs options for the cmd command and fix wrong seed/model reported in json, csv and png (#962) 2023-02-06 15:16:50 -06:00
Daniel Garvey
157ae64f9d print to stdout for test visibility (#937)
Co-authored-by: dan <dan@nod-labs.com>
2023-02-06 01:03:27 -08:00
powderluv
ffa17f6057 Update sd_dark_theme.css 2023-02-06 01:01:50 -08:00
drumicube
d695a43e32 Make the dark theme default while launching web server (#954) 2023-02-05 07:25:45 -08:00
powderluv
01f6b4e6f0 Update README.md 2023-02-04 23:40:13 -08:00
yzhang93
7cf31a6ae4 Fix iree-benchmark flag names (#952) 2023-02-04 22:24:18 -08:00
Quinn Dawkins
fbd6224b04 Revert "Revert pipelines (#948)" (#951)
This reverts commit 8115b26079.
Additionally fixes img2col by adding detach elementwise from named op
passes.
2023-02-04 22:44:08 -05:00
powderluv
8115b26079 Revert pipelines (#948)
* Revert "[SD] Modify the flags to use --iree-preprocessing-pass-pipeline (#914)"

This reverts commit a783c089a9.

* Revert "Fix iree flags due to the change in shark-runtime (#944)"

This reverts commit 1d38d49162.
2023-02-04 07:09:51 -08:00
powderluv
820586ac68 Update README.md 2023-02-04 01:01:11 -08:00
powderluv
4a7441ed07 Update profiling_with_iree.md 2023-02-04 00:47:57 -08:00
powderluv
383741f284 Update stable_diffusion_amd.md 2023-02-04 00:40:47 -08:00
powderluv
2bbc4e0e9f Update README.md 2023-02-04 00:35:40 -08:00
powderluv
a7237244b0 Send users to the .exe file first 2023-02-04 00:30:32 -08:00
yzhang93
1d38d49162 Fix iree flags due to the change in shark-runtime (#944) 2023-02-03 21:34:02 -08:00
yzhang93
a783c089a9 [SD] Modify the flags to use --iree-preprocessing-pass-pipeline (#914)
* [SD] Modify the flags to use --iree-preprocessing-pass-pipeline

* Fix flags in sd_annotation
2023-02-03 15:08:02 -08:00
powderluv
e7907dc532 Disable tuned models for sm_89 (#943)
Looks like tuning on A100 doesn't necessarily translate to 40xx.
2023-02-03 14:30:46 -08:00
powderluv
394413679d Fix ckpt_dir (#939) 2023-02-03 12:54:19 -08:00
powderluv
37189f14cb roll to 492 2023-02-03 11:59:18 -08:00
powderluv
0b1ee81901 Minor webui changes (#938) 2023-02-03 11:26:45 -08:00
Gaurav Shukla
00cf73f9b8 [SD] Merge model id dropdown and .ckpt dropdown (#936)
- use_tuned is set to False for custom checkpoints.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-03 10:43:33 -08:00
Abhishek Varma
5a5f285493 [apps-SD] Prepone loading of vmfbs + restructure the SD pipeline
-- This commit prepones loading of vmfbs, if present, for all sub-models.
-- It also involves restructuring the SD pipeline to achieve the loading
   of vmfbs smoothly and postpones processing of checkpoint files only when
   required.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-03 20:21:24 +05:30
powderluv
7f2ea454b6 revert /base variants as they are different (#929)
sd2_1base is different from VAE base (for older cards)
2023-02-03 01:27:32 -08:00
Daniel Garvey
7c14002118 Map 2_1 to 2_1_base (#927)
* fix broken paths for older models

* adds a mapping from sd_2_1 to sd_2_1_base

we only have models in models_db for 2_1_base.
now that diffusers is fixed we can actually generate
2_1 itself, but until we add support for that in the tank
we should fetch 2_1_base for no-import_mlir

---------

Co-authored-by: dan <dan@nod-labs.com>
2023-02-02 19:03:19 -08:00
powderluv
3e9554f0a1 roll to 487 2023-02-02 19:02:39 -08:00
Daniel Garvey
e11ffec544 fix broken paths for older models (#926)
Co-authored-by: dan <dan@nod-labs.com>
2023-02-02 15:48:19 -08:00
powderluv
8a47ddbe99 Update models/ location in UI (#925)
default to png metadata on
2023-02-02 15:28:39 -08:00
powderluv
821108c7bd Fix models path (#924) 2023-02-02 15:16:00 -08:00
Gaurav Shukla
339738f8a3 [SD][web] Populate checkpoints as dropdown UI (#918)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-02 13:59:50 -08:00
powderluv
9b90672f63 Fix LLPC env var (#920) 2023-02-02 11:45:08 -08:00
Ean Garvey
ba07e94a5e disable Torch Inductor autotuner in benchmarks (#919) 2023-02-02 13:25:43 -06:00
aldesilv
b3fc0f29cc enable additional flags for tank test models (#866)
Co-authored-by: Alex <alexander@nod-labs.com>
2023-02-02 11:19:33 -08:00
Gaurav Shukla
5c7deb3611 [SD] Fix output image location (#917)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-02 09:50:37 -08:00
Daniel Garvey
15604e374f change bytecode model paths (#913)
Co-authored-by: dan <dan@nod-labs.com>
2023-02-02 11:12:13 -06:00
Abhishek Varma
7cfc0fa55b [APPS-SD] Fix a few bugs and bring it up to speed with SD CLI (#908) 2023-02-02 07:12:01 -08:00
Ean Garvey
a90812133b Enable pytests on Windows (#901) 2023-02-01 18:36:41 -06:00
powderluv
e26a70aa4f Drop old cli and webui (#911) 2023-02-01 13:13:46 -08:00
Daniel Garvey
6a32a4e26c move ci sd stuff to apps (#912)
Co-authored-by: dan <dan@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-02-01 12:15:07 -08:00
powderluv
e853abf98b Update stable_diffusion_amd.md 2023-02-01 11:11:58 -08:00
powderluv
51e81e6ef8 update main readme 2023-02-01 11:09:00 -08:00
powderluv
e355000ceb Drop torchvision 2023-02-01 10:26:37 -08:00
Daniel Garvey
e374074013 Windows test (#896)
* add generate_sharktank for stable_diffusion model defaults

* add windows test for sd

---------

Co-authored-by: dan <dan@nod-labs.com>
2023-02-01 12:03:54 -06:00
powderluv
81e3d1c2c6 switch to apps/ 2023-02-01 06:54:20 -08:00
powderluv
ab0cbb4475 Add PyInstaller for apps/ webui and cli (#909)
tested webui, cli and webui exe and cli exe
2023-02-01 06:51:27 -08:00
powderluv
1c64e40722 Add PyInstaller for apps/ (#907)
Build with pyinstaller.exe .\apps\stable_diffusion\web\shark_sd.spec

normal flow works. exe is missing a few json files
2023-02-01 06:04:49 -08:00
Evan Guan
8cafe56eb4 Added flags for metadata information. (#894) 2023-02-01 05:16:11 -08:00
Eliasj42
3eceeb7b23 fixed a bug that would sometimes cause intel-gpu to appear unsupported (#899)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-01-31 22:32:05 -08:00
powderluv
1a37675435 Revert "move beta to release (#898)" (#905)
This reverts commit 7edcaf5a06.
2023-01-31 20:31:41 -08:00
powderluv
198ebede8d Revert "replace new model_db.json (#902)" (#904)
This reverts commit 842adef29c.
2023-01-31 20:29:40 -08:00
Ean Garvey
a504903dd5 Fix formatting issues. (#903) 2023-02-01 09:12:45 +05:30
Daniel Garvey
842adef29c replace new model_db.json (#902) 2023-01-31 18:55:22 -08:00
Daniel Garvey
7edcaf5a06 move beta to release (#898)
Co-authored-by: dan <dan@nod-labs.com>
2023-01-31 17:14:08 -06:00
Gaurav Shukla
c124b76328 [SD] Reorganize the stable diffusion model. (#806)
The stable diffusion codebase has been reorganized to make it more
modular so that the same script can be used for web as well as cli,
instead of duplicating the whole codebase.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-31 14:42:41 -08:00
aldesilv
e9c744ee5d find rocm arch used in rocminfo (#893)
Co-authored-by: Alex <alexander@nod-labs.com>
2023-01-31 10:22:31 -08:00
Ean Garvey
83302930d8 Update generate_sharktank.py (#897) 2023-01-31 10:21:22 -08:00
Daniel Garvey
a4634632ba add generate_sharktank for stable_diffusion model defaults (#742)
Co-authored-by: dan <dan@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-31 09:44:54 -08:00
Abhishek Varma
d17e8dc5ad [NFC] Rename SD negative_prompts flag
-- This commit renames SD `negative-prompts` -> `negative_prompts` flag.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-31 21:38:59 +05:30
powderluv
9fe63de4d4 Pin macOS SDK to 216 2023-01-31 01:09:44 -08:00
Eliasj42
8111f8bf35 added ability to select gpu (#891)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-01-30 13:39:12 -08:00
Abhishek Varma
fcd62513cf [SD-CLI] Add support for .safetensors + Use diffusers pipeline to load SD
-- This commit uses `load_pipeline_from_original_stable_diffusion_ckpt`
   as exposed due to [Diffusers PR](https://github.com/huggingface/diffusers/pull/2019).
-- It also adds a support for the end users to use `.safetensors` along
   with `.ckpt` file.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-31 00:00:37 +05:30
Abhishek Varma
c3c701e654 Update requirements.txt + README.md of SD
-- This commit includes two python modules as part of requirements.txt.
-- It also updates README.md to also inclue `--no-use_tuned` for users to
   be able to try `hf_model_id` or `ckpt_loc` without any issue.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-30 14:12:54 +05:30
Daniel Garvey
6bf991edf6 adding more robust main.py testing (#889)
Co-authored-by: dan <dan@nod-labs.com>
2023-01-30 00:14:26 -08:00
yzhang93
9644e78545 Fix CUDA tuned model annotation (#880) 2023-01-27 11:35:18 -08:00
dymil
c911189ef0 Add note about latest RDNA3 driver support (#881)
Also tweak other wording
2023-01-27 09:39:19 -08:00
Abhishek Varma
1118b4b651 [SD-CLI] Clean up vmfbs if a retry method fails
-- This commit cleans up vmfb files generated as a result of retry method.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-27 21:55:36 +05:30
PhaneeshB
4be75d4418 fix seed values in SD json and filename 2023-01-27 18:40:26 +05:30
Ean Garvey
fb6beae27c Adds pytest-forked dependency to fix pytest memory accumulation issues. (#876)
* Minor improvements to test-models workflow

- cleaned up pytest command line args in Validate Models job scripts.
- Removed -s flag to provide more readable logs
- Changed shark_cache location to within github workspace and removed --update_tank flag from Linux workflows.

* Use pytest-forked for managing pytest memory usage.
2023-01-26 18:20:15 -06:00
yzhang93
fee73b0b63 Add SD model annotation on fly (#869)
* Add SD model annotation on fly

* Move tuned_compile_through_fx to utils

* Fix SD compilation flags
2023-01-26 11:46:36 -08:00
powderluv
9bbffa519e Add an option to respect LLPC env var (#875)
Also add OSX paths
2023-01-25 13:56:55 -08:00
jinchen62
c3a641f0ab Address TODOs for dataset annotator (#872)
- add args usage, pass gs_url by CL flag
- add support for no existing prompts
2023-01-25 09:28:23 -08:00
yzhang93
aafe7c4701 Add more cuda devices to use tuned model (#868) 2023-01-25 06:36:17 -08:00
Abhishek Varma
9a0b082cf8 [SD-CLI] Add batch_size command-line arg + prompt processing
-- This commit adds `batch_size` command-line arg.
-- It also involves replicating the prompt `batch_size` no. of times.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-25 19:21:25 +05:30
powderluv
8265e34a29 Add SHARK SD CLI tool (#870) 2023-01-24 23:14:32 -08:00
powderluv
8ef8ae097f Update to build 469 2023-01-24 22:16:13 -08:00
powderluv
c3d14293c0 Update sample results 2023-01-24 22:14:06 -08:00
powderluv
d55d8be504 Add signing of release builds 2023-01-24 21:32:21 -08:00
powderluv
03543030d3 use pefile 2023-01-24 18:35:51 -08:00
powderluv
fc6b474b92 Add ordlookup to requirements.txt 2023-01-24 18:30:16 -08:00
powderluv
a5db785dd7 checkoutv2 on windows 2023-01-24 18:23:22 -08:00
powderluv
1c1c5cd611 Build Windows nightly on 7950x 2023-01-24 16:21:56 -08:00
Abhishek Varma
6ed02f70ec [SD-CLI] Make using ckpt_loc and hf_model_id easier
-- Currently we require users to specify the base model on which the custom
   model (.ckpt) is tuned on. Even for running a HuggingFace repo-id, we
   require the users to go a tedious way of adding things to variants.json.

-- This commit aims to address the above issues and will be treated as a
   starting point for a series of design changes which makes using SHARK's SD
   easier.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-24 23:03:46 +05:30
Prashant Kumar
cb78cd8ac0 Add the support for the batch size parameter. 2023-01-24 22:33:13 +05:30
Ean Garvey
0c4590b45a Update generate_sharktank.py 2023-01-24 10:18:03 +05:30
jinchen62
d2e2ee6efa Add multiple prompts support for dataset annotator (#862) 2023-01-23 18:40:36 -08:00
powderluv
6a380a0b48 Add more nvidia cards 2023-01-23 17:07:45 -08:00
powderluv
e5d5acbf1f Remove torchvision requirements from web (#860) 2023-01-23 13:48:53 -08:00
powderluv
00e38abbf0 Add 4080 support 2023-01-23 09:56:34 -08:00
Abhishek Varma
e3e4ea5443 Update README.md
-- Make usage of `hf_model_id` clearer.
2023-01-23 23:25:23 +05:30
Prashant Kumar
a3e4ea3228 Remove the dependency of the torchvision. (#858)
Remove the dependency of torchvision library for the conversion
of tensor layout format to what PIL library expects.
2023-01-23 08:49:57 -08:00
powderluv
56f16d6baf Update SD readme 2023-01-23 06:51:54 -08:00
Abhishek Varma
7a55ab900e [SD-CLI] Fix CKPT script + add more variants + update README.md
-- This commit fixes CKPT script to rely on the previous CKPT to Diffusers
   script.
   TODO: Let go of the script once the CKPT is included in next release
         of diffusers.
-- It also adds many variants as part of `variants.json` and updates
   `README.md` to reflect change in default `hf_model_id`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-23 18:34:24 +05:30
Abhishek Varma
137643fe72 [SD-CLI] Update README.md of custom models to include hf_model_id 2023-01-23 11:37:13 +05:30
Anush Elangovan
d6e59c6241 black format comments 2023-01-22 16:34:40 -08:00
powderluv
458eb5d34c detect RX 7900 better 2023-01-22 16:32:27 -08:00
Erkin Alp Güney
8259f08864 Collapsibles for Win10 and Linux users (#851)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-22 09:50:33 -08:00
Prashant Kumar
b3ab0a1843 Add width and height support for the scheduler. 2023-01-22 23:16:50 +05:30
dependabot[bot]
f09f217478 Bump tensorflow from 2.10 to 2.10.1 (#853)
Bumps [tensorflow](https://github.com/tensorflow/tensorflow) from 2.10 to 2.10.1.
- [Release notes](https://github.com/tensorflow/tensorflow/releases)
- [Changelog](https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md)
- [Commits](https://github.com/tensorflow/tensorflow/compare/v2.10.0...v2.10.1)

---
updated-dependencies:
- dependency-name: tensorflow
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-01-22 06:40:17 -08:00
Daniel Garvey
e842c8c19b add main.py testing for sdiff (#836)
Co-authored-by: dan <dan@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-22 01:16:17 -08:00
powderluv
f6c3112d44 Revert "potential fix to pre-load DLL dir for torch-mlir (#848)" (#852)
This reverts commit 6c470d8131.
2023-01-22 00:09:35 -08:00
yzhang93
7059610632 Modify the default for --hf_model_id flag 2023-01-21 11:21:47 +05:30
powderluv
2d272930d9 Update to signed build 455 2023-01-20 16:50:42 -08:00
powderluv
6c470d8131 potential fix to pre-load DLL dir for torch-mlir (#848)
Doesn't regress the main.py script but system already pre-loaded
the DLL so needs more testing.
2023-01-20 14:48:45 -08:00
jinchen62
30b29ce8cd Add readme for dataset annotator (#847) 2023-01-20 01:03:33 -08:00
jinchen62
1a9933002f Add dataset annotation tool (#835) 2023-01-19 16:56:08 -08:00
stanley
c4a9365aa1 [Shark][Training] Refresh SharkTrainer to latest APIs. 2023-01-19 20:30:15 +00:00
Prashant Kumar
9d3af37104 bugfix related to the height width params. 2023-01-20 00:21:44 +05:30
Prashant Kumar
7b3d57cff7 Add height and width as args. 2023-01-19 23:43:29 +05:30
Abhishek Varma
a802270da9 [SD-CLI] Update README.md about variants.json 2023-01-19 22:46:54 +05:30
Abhishek Varma
dd194a8758 [SD-CLI] Reorder loading of opt_params when needed
-- This commit reorders loading of opt_params when `import_mlir` is not used.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 22:02:51 +05:30
Abhishek Varma
6de02de221 [SD-CLI] Make using custom models easier
-- This commit makes using custom models easier using a combination of
   `import_mlir`, `ckpt_loc` and `hf_model_id`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 22:02:36 +05:30
Abhishek Varma
85259750bf [SD-CLI] Fix variants.json mapping
-- This commit fixes variants.json's mapping.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 22:02:36 +05:30
Prashant Kumar
1249f0007d Remove args.variant and args.version with args.custom_model. 2023-01-19 19:55:12 +05:30
Abhishek Varma
db0514d3fa [SD-CLI] Fix get_model_configuration to use max_length
-- This commit fixes `get_model_configuration` to use `max_length`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 19:10:04 +05:30
Abhishek Varma
dce42a7fad [SD-CLI] Fix args.max_length range check
This commit fixes args.max_length range check.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 18:26:23 +05:30
Prashant Kumar
ec0b380194 Refactor shark_tank models and custom models.
The custom models shouldn't depend on shark_tank in anyway.
2023-01-19 13:56:11 +05:30
Ean Garvey
7f27b61c98 Update setup_venv.sh to install triton if BENCHMARK=1 2023-01-19 00:26:46 -06:00
Guy Nachshon
f0b3557b02 fix: replace malicious and deleted package (#833) 2023-01-18 13:41:05 -08:00
xzuyn
2a1d1c1001 make jpeg optimized and progressive (#820)
* GUI make jpeg optimized and progressive

* CLI make jpeg optimized and progressive
2023-01-17 16:35:36 -08:00
Abhishek Varma
df7eb80e5b [SD-CLI] Make custom_model take highest priority for generating models if present
-- This commit makes `custom_model` take highest priority for generating models if present.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-17 22:50:58 +05:30
Fraser Humphries
b9d947ce6f style: 🎨 Restore whitespace 2023-01-17 17:45:32 +05:30
Fraser Humphries
e6589d2454 fix: 🏗️ Add demo.css to spec file datas 2023-01-17 17:45:32 +05:30
Fraser Humphries
0f5ac6afcf fix: 🐛 resolve css file path relative to __file__
issues-816
2023-01-17 17:45:32 +05:30
Abhishek Varma
bc1bb1d188 [SD-CLI] Fix vmfb naming + update README.md for custom_model
-- This commit introduces a fix for .vmfb naming to strip away any
   non-alphanumeric characters from `custom_model` path.
-- It also updates the README.md to include the `custom_model` arg.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-17 16:27:54 +05:30
Abhishek Varma
3af2dd10ce [SD-CLI] Add CKPT support to update models irrespective of import_mlir flag
-- This commit adds CKPT support to update models irrespective of `import_mlir` flag.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-17 13:24:27 +05:30
yzhang93
dd22c65855 Add CUDA tuned models for SD variants (#814) 2023-01-16 09:38:27 -08:00
PhaneeshB
48137ced19 add png as default format 2023-01-16 18:37:36 +05:30
Phaneesh Barwaria
6eb47c12d1 add multi-run in single execution (#812) 2023-01-13 11:12:43 -08:00
Prashant Kumar
5a1fc6675a This PR adds --import-mlir for f16 tensors without cuda. 2023-01-13 22:19:53 +05:30
Prashant Kumar
6f80825814 Modify import_with_fx to import with dtype=f16. 2023-01-13 22:19:53 +05:30
PhaneeshB
f0dd48ed2a remaining disk space warning 2023-01-13 19:34:05 +05:30
Gaurav Shukla
15e2df0db0 [SD][web] Add a UI textbox to show the output location
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-13 19:33:04 +05:30
Fraser Humphries
4ad0109769 fix: 🐛 Extract demo css string to css file
fix: 🐛 Extract demo css string to css file

issues/807

fix: 🐛 Revert background colors
2023-01-13 16:42:05 +05:30
PhaneeshB
ee0009d4b8 pythonize uname for cpu target triple in windows 2023-01-12 22:39:49 +05:30
PhaneeshB
9d851c3346 small fixes 2023-01-12 22:32:24 +05:30
xzuyn
5d117af8ae Increase JPEG output quality & disable subsampling (#801)
* Increase JPEG output quality & disable subsampling

Increased to JPEG95 from the default JPEG75 which is way too compressed. Output image size is now ~100kb. Previously was ~20kb.

* Increase JPEG output quality & disable subsampling

Add jpeg quality increase on cli

* line length changes

* line length changes
2023-01-11 23:06:11 -08:00
yzhang93
bb41c2d15e Add VAE cuda tuned model (#796) 2023-01-11 14:15:03 -08:00
powderluv
eba138ee4a Revert "Change address for connection test (#785)" (#797)
This reverts commit 187f0fa70c.
2023-01-11 12:01:37 -08:00
Gaurav Shukla
3b2bbb74f8 [SD][web] Add support for saving generated images
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-11 22:47:32 +05:30
fokin33
dbc0f81211 Add simple telegram bot (#787) 2023-01-11 09:20:23 -06:00
mariecwhite
d0b613d22e Enable Torch-Inductor Part 2 2023-01-10 20:15:29 -08:00
Ean Garvey
72f29b67d5 Add Resnet50 fp16 variant to pytests. (#760) 2023-01-10 16:31:11 -08:00
Quinn Dawkins
9570045cc3 Fix tuned model selection for non-vulkan devices (#792) 2023-01-10 19:04:21 -05:00
Phaneesh Barwaria
e4efdb5cbb add json data for each image (#790)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-10 13:13:07 -08:00
calcifer11
187f0fa70c Change address for connection test (#785)
Some ISP's (like mine) reserves 1.1.1.1 for internal testing, meaning _internet_connected(); needlessly retries for a minute until it fails even though my connection is fine.
Propose 8.8.8.8 instead as this is also publically available and not normally blocked by ISPs.
2023-01-10 10:51:30 -08:00
Gaurav Shukla
472185c3e4 [SD][web] Fix device key error
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-10 20:51:01 +05:30
Gaurav Shukla
f94a571773 [SD] Update spec file to include model_config.json
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-10 20:38:10 +05:30
mariecwhite
183e447d35 Enable Torch Inductor (#784) 2023-01-10 20:57:58 +11:00
xzuyn
12f844d93a Git pull through argument in setup_venv (#623) 2023-01-09 15:42:13 -08:00
yzhang93
47a119a37f [SD] Add CUDA A100 tuned model (#773) 2023-01-09 15:22:27 -08:00
Gaurav Shukla
ee56559b9a [SD][web] Add a json file for model configuration
This cleans model_wrappers.py file.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-10 00:05:46 +05:30
Gaurav Shukla
00e594deea [SD][web] Add version number in performance details
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-09 21:32:34 +05:30
George Petterson
6ad9b213b9 Add GCN4
(cherry picked from commit 3be072b3c09c9b38bc2d79ad6e6900eefee49a1c)
2023-01-09 21:09:50 +05:30
PhaneeshB
e4375e8195 Add support for vulkan target env 2023-01-09 21:09:50 +05:30
mariecwhite
487bf8e29b Enable TF32 in Torch if specified (#768) 2023-01-09 06:48:57 -08:00
Prashant Kumar
fea1694e74 Delete the cached objects explicitly. 2023-01-06 23:04:52 +05:30
Prashant Kumar
4102c124a9 Add the shark upscaler model. (#759) 2023-01-05 14:07:20 -08:00
yzhang93
135bad3280 [SD] Update v1.4 tuned model (#758) 2023-01-05 11:04:30 -08:00
Gaurav Shukla
b604f36881 [SD][web] Add flags for global URL and server port
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-05 15:30:30 +05:30
yzhang93
782b449c71 Add script to auto annotate SD models and variants (#751)
* Add script to auto annotate SD models and variants

* Add model config files

* Add script to auto annotate SD models and variants

* Add model config files

* Move config files to shark_tank
2023-01-04 15:53:10 -08:00
jinchen62
017dcab685 Add target triple support for TITAN RTX (#756) 2023-01-04 15:39:00 -08:00
Abhishek Varma
e60b4568c6 [SharkInference] Make SharkInference compile the entire module (#708)
* [SharkInference] Make SharkInference compile the entire module

-- Previously SharkInference was compiling and providing run APIs
   for a harcoded function with function name "forward".
-- This commit makes the compiling functionality generic and now
   any function being defined within the module can be run.
-- It also creates an API to fetch all the function names defined
   within the compiled module.
-- This commit updates both web and command-line execution of Stable
   Diffusion to use new API of  SharkInference.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-03 23:25:23 +05:30
powderluv
4ee3d95a5a Update to build 423
Post pytorch security breach
2023-01-01 12:10:23 -08:00
Graham
f18725bacc replaced <username> with %username% for easy copy/paste (#744) 2022-12-31 21:29:37 -08:00
jinchen62
f6064a2b84 Add a prototype of the model compilation configs for SD (#734) 2022-12-28 15:14:36 -08:00
Quinn Dawkins
2e90cb7b95 Set default warmup count to 0 (#736) 2022-12-28 12:27:43 -06:00
powderluv
2c09d63cd9 Update to build 417 2022-12-27 14:25:20 -08:00
powderluv
cc6fbdb0c3 Add sm_89 and point to nvcuda.dll (#731) 2022-12-26 10:54:38 -08:00
powderluv
ecfdec12f3 Update requirements.txt 2022-12-25 15:39:20 -08:00
Gaurav Shukla
45af40fd14 [SD][web] Add openjourney and dreamlike in SD web UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-26 01:59:36 +05:30
Phaneesh Barwaria
d11cf42501 Add support for dreamlike diffusion (#725)
* Add support for dreamlike diffusion

* model wrapper to support 77 dreamlike

* lint fix
2022-12-26 01:35:17 +05:30
Gaurav Shukla
c3c1e3b055 [SD] Add bucket info in the model_db.json
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 20:38:33 +05:30
Gaurav Shukla
7c5e3b1d99 [SD] Fix flags for cuda devices
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 19:03:02 +05:30
Gaurav Shukla
ed6cec71e7 [SD] Fix clip inference time
Fix clip inference time by adding default warmup_count to 5.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 18:16:53 +05:30
Tobby "GTD-Carthage" Ong
d6bcdd069c - Added missing double linebreak from linting 2022-12-25 12:07:43 +05:30
Tobby "GTD-Carthage" Ong
a26347826d - Revised code to also use get_schedulers function instead 2022-12-25 12:07:43 +05:30
Tobby "GTD-Carthage" Ong
5d1c099b31 [SD] Add Euler Ancestral scheduler as option to WebUI 2022-12-25 12:07:43 +05:30
Gaurav Shukla
220bee1365 [SD][web] Add device support in the SD web UI
1. Now device selection is available through UI.
2. Models reloading will only happen when there will be a change in the
   settings(variant + device).

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 01:45:07 +05:30
PhaneeshB
1261074d95 Add tuned models for av3 and ad 2022-12-24 22:56:15 +05:30
Stanley Winata
136021424c [SD] Change default VMA large heap block size for windows perf. (#715)
Windows perform can boost from 2.67s/image to 2.4523s/image.
While Linux stays the same.
2022-12-24 01:40:58 +07:00
PhaneeshB
fee4ba3746 Add openjourney 2022-12-23 23:34:22 +05:30
Gaurav Shukla
a5b70335d4 [SD][web] Add variant support in the web UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-23 23:18:27 +05:30
Stanley Winata
5cf4976054 [Vulkan][utils] Add GTX Pascal support. (#709) 2022-12-22 15:24:15 -08:00
PhaneeshB
1aa3255061 Add vaebase for av3 and ad 2022-12-23 04:17:17 +05:30
Daniel Garvey
b01f29f10d add support for clear_all (#691) 2022-12-22 11:25:03 -06:00
Boian Petkantchin
2673abca88 Fix concurrency issue in stress_test for CUDA devices 2022-12-22 08:54:19 -08:00
Gaurav Shukla
7eeb7f0715 [SD] Update all the utilities to make web and CLI codebase closer (#707)
At this point, all the utilities of SD web and CLI are exactly same.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-22 02:49:48 -08:00
powderluv
37262a2479 Remove spurious characters 2022-12-21 19:23:54 -08:00
Gaurav Shukla
de6e304959 [SD] Fix the resource location in shark_sd.spec (#706) 2022-12-21 14:41:56 -08:00
Quinn Dawkins
234475bbc7 Add base_vae entries for variant models (#705) 2022-12-21 14:35:08 -08:00
Quinn Dawkins
abbd9f7cfc [SD] Set unet flags for cuda (#704) 2022-12-21 13:22:04 -08:00
Gaurav Shukla
dfd6ba67b3 [SD] Update SD CLI to use model_db.json
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-22 02:13:04 +05:30
yzhang93
1595254eab Modify model annotation tool to walk through ops by shape (#692) 2022-12-21 10:46:30 -08:00
PhaneeshB
6964c5eeba encapsulate relevant methods in one method 2022-12-21 23:56:17 +05:30
PhaneeshB
2befe771b3 Add support for automatic target triple selection for SD 2022-12-21 22:38:06 +05:30
Prashant Kumar
b133a035a4 Add the download progress bar. 2022-12-21 15:47:33 +05:30
Gaurav Shukla
726c062327 [SD] Update spec files
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-21 14:16:04 +05:30
Gaurav Shukla
9083672de3 [SD][web] Tuned models only for stablediffusion/fp16 and rdna3 cards
Currently tuned models are only available for stablediffusion/fp16 and
rdna3 cards.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-21 14:15:39 +05:30
Quinn Dawkins
cdbaf880af [SD] [web] Add model variants to web 2022-12-21 13:42:22 +05:30
Quinn Dawkins
9434981cdc Add random seed generation for seed = -1 in cli (#689) 2022-12-20 17:15:22 -05:00
Phaneesh Barwaria
8b3706f557 Add Anything v3 and AnalogDiffusion variants of SD (#685)
* base support for anythingv3

* add analogdiffusiont

* Update readme

* keep max len 77 till support for 64 added for variants

* lint fix
2022-12-20 13:08:13 -08:00
Gaurav Shukla
0d5173833d [SD] Add a json file for model names information. (#687)
This commit simplifies the code to identify the model name for a
particular set of flags. This is achieved by introducing a json file
that stores the model names information. The models are uploaded in
gcloud with these names.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-20 11:47:31 -08:00
powderluv
bf1178eb79 roll to build 400 2022-12-20 10:34:31 -08:00
yzhang93
abcd3fa94a [SD] Set model max length 64 as default (#681) 2022-12-19 21:13:04 -08:00
Quinn Dawkins
62aa1614b6 [SD] Add --use_base_vae flag to do conversion to pixel space on cpu (#682) 2022-12-19 21:09:39 -08:00
Quinn Dawkins
7027356126 [SD] Fix warmup for max length 64 (#680) 2022-12-19 21:04:44 -05:00
yzhang93
5ebe13a13d Add Unet len 64 tuned model (#679) 2022-12-19 16:24:08 -08:00
Gaurav Shukla
c3bed9a2b7 [SD][web] Add flag to disable the progress bar animation
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-20 02:50:04 +05:30
yzhang93
f865222882 Update VAE 19dec tuned model (#676) 2022-12-19 12:42:28 -08:00
powderluv
e2fe2e4095 Point to 398 2022-12-19 12:08:30 -08:00
powderluv
0532a95f08 Update stable_diffusion_amd.md 2022-12-19 12:04:42 -08:00
Quinn Dawkins
ff536f6015 [SD] Deduplicate initial noise generation (#677) 2022-12-19 14:38:41 -05:00
Gaurav Shukla
097d0f27bb [SD][web] Add 64 max_length support in SD web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-20 00:00:58 +05:30
Prashant Kumar
2257f87edf Update opt_params.py 2022-12-19 23:43:30 +05:30
PhaneeshB
a17800da00 Add 64 len f16 untuned mlir 2022-12-19 22:53:17 +05:30
Prashant Kumar
059c1b3a19 Disable vae --use_tuned version. 2022-12-19 22:45:45 +05:30
Stanley Winata
9a36816d27 [SD][CLI] Add a warmup phase (#670) 2022-12-20 00:14:23 +07:00
Gaurav Shukla
7986b9b20b [SD][WEB] Update VAE model and wrapper
This commit updates VAE model which significantly improves performance
by an order of ~300ms.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-19 22:32:05 +05:30
Gaurav Shukla
b2b3a0a62b [SD] Move initial latent generation out of inference time
The initial random latent generation is not taken into account
for total SD inference time.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-19 22:32:05 +05:30
Prashant Kumar
3173b7d1d9 Update VAE model and wrapper. 2022-12-19 19:54:50 +05:30
Gaurav Shukla
9d716d70d6 [SD][web] Fix performance issues on shark scheduler
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-19 17:44:37 +05:30
Stanley Winata
e1901a8608 [SD][CL] Disable print at every iteration. (#664)
Printing might incur extra time to runtime. Hence, we add a flag to hide it. To disable printing please set this flag `--hide_steps`.

Co-authored-by: Stanley <stanley@MacStudio.lan>
2022-12-19 15:39:57 +07:00
Quinn Dawkins
7d0cbd8d90 [SD][web] Set default tuned unet to v2 (#663) 2022-12-19 11:50:08 +07:00
Quinn Dawkins
59358361f9 [SD] Make clip batch 2 for positive and negative prompts (#662)
Combines the forward passes for each input prompt type into a single batched clip pass.
2022-12-18 23:46:21 -05:00
Quinn Dawkins
7fea2d3b68 [SD] update default large heap size for web as well (#661) 2022-12-18 21:50:26 -05:00
Quinn Dawkins
b6d3ff26bd [SD] Change default VMA large heap block size (#660) 2022-12-18 21:41:46 -05:00
Stella Laurenzo
523e63f5c1 Fix NoneType exception if vulkan tuning flags not detected. (#659)
(This goes on to produce compilation errors, but one step at a time)
2022-12-18 16:40:56 -08:00
Stella Laurenzo
10630ab597 Add config stanza for NVIDIA RTX 2080. (#658)
Just happened to have this card on my Windows machine and verified that the SD demo works on it.

```
Average step time: 144.26142692565918ms/it
Clip Inference Avg time (ms) = (205.001 + 44.000) / 2 = 124.501
VAE Inference time (ms): 281.001

Total image generation time: 7.856997728347778sec
```

I'd love to add an API upstream to derive compiler tuning flags from a host device.
2022-12-18 16:40:47 -08:00
Quinn Dawkins
2bc6de650d [SD] Add support for a compiled version of the discrete Euler scheduler (#657)
* Add Shark version of euler scheduler

* Add Shark version of euler scheduler to web ui
2022-12-17 19:25:43 -08:00
powderluv
ffef1681e3 Update stable_diffusion_amd.md 2022-12-17 03:40:08 -08:00
yzhang93
d935006a4a Update Unet tuned model to v2 (#656) 2022-12-16 22:10:15 -08:00
powderluv
660cb5946e Update to 392 release 2022-12-16 16:00:49 -08:00
Gaurav Shukla
10160a066a [SD][WEB] Add vae tuned model in the SD web (#653)
1. Add tuned vae model in the SD web.
2. Use tuned models in case of rdna3 cards.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-16 15:29:48 -08:00
Anush Elangovan
72976a2ece Import env vars first 2022-12-16 15:12:28 -08:00
Phaneesh Barwaria
831f206cd0 Revert "Add target triple selection for multiple cards" (#655)
This reverts commit acb905f0cc.
2022-12-16 15:01:45 -08:00
Gaurav Shukla
72648aa9f2 Revert "[SD][WEB] Deduce vulkan-target-triple in the presence of multiple cards"
This reverts commit 35e623deaf.
2022-12-17 04:28:18 +05:30
Gaurav Shukla
35e623deaf [SD][WEB] Deduce vulkan-target-triple in the presence of multiple cards
1. Get the correct vulkan-target-triple for a specified device in the
   presence of multiple cards.
2. Use tuned unet model for rdna3 cards.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-17 03:04:47 +05:30
Anush Elangovan
6263636738 Fix more lints 2022-12-16 13:26:15 -08:00
Anush Elangovan
535d012ded Fix lint 2022-12-16 13:24:51 -08:00
yzhang93
c73eed2e51 Add VAE winograd tuned model (#647) 2022-12-16 13:01:45 -08:00
Anush Elangovan
30fdc99f37 Set to enable llpc
Use an env var to enable llpc
2022-12-16 12:57:30 -08:00
PhaneeshB
acb905f0cc Add target triple selection for multiple cards 2022-12-17 02:24:37 +05:30
Gaurav Shukla
bba06d0142 [SD][WEB] Avoid passing args to utils APIs
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-17 01:41:33 +05:30
Ean Garvey
a14a47af12 Move most xfails to entries in tank/all_models.csv and temporarily remove multiprocessing and TF gpu support. (#646)
-Adds date variable back to nightly.yml so shark_tank uploads are dated again
-added specification for nightly pytests to not run tests on metal (vulkan is sufficient)
-added some paths/filetypes to be ignored when triggering workflow runs. (no test-models on changes to .md files or anything in the shark/examples/ directory or its subdirectories.
-pytest only picks up tank/test_models.py, so no need to specify which file to run when running pytest from SHARK base directory.
-Cleaned up xfails so that they can be added to models as csv entries. Columns 7-9 in all_models.csv trigger xfails with cpu, cuda, vulkan, respectively, and row 10 can be populated with a reason for the xfails.
-Fixed a few defaults for shark_args and pytest args (defined in conftest.py)
-Fixes --update_tank option in shark_downloader
removes some multiprocessing in pytest / TF+CUDA support because it breaks pytest and false passes, leaving regressions at large.
-Adds xfails for and removes albert torch from gen_sharktank list (tank/torch_model_list.csv).
-Cleans up xfails for cpu, cuda, vulkan (removing old ones)
2022-12-16 12:56:32 +05:30
Phaneesh Barwaria
73457336bc add flag for toggling vulkan validation layers (#624)
* add vulkan_validation_layers flag

* categorize SD flags

* stringify true and false for flag
2022-12-15 20:40:59 -06:00
Ean Garvey
a14c53ad31 Remove albert-base-v2 since it fails torch_mlir.compile() (#644) 2022-12-15 16:05:19 -06:00
Gaurav Shukla
e7e763551a [WEB][SD] Make unet tuned model default for rdna3 devices (#642) 2022-12-15 12:02:03 -08:00
nirvedhmeshram
2928179331 Add more NVIDIA targets (#640) 2022-12-15 11:24:38 -06:00
Stanley Winata
24a16a4cfe [Stable Diffusion] Disable binding fusion to work with moltenVK on mac. (#639)
Co-authored-by: Stanley <stanley@MacStudio.lan>
2022-12-16 00:22:49 +07:00
Phaneesh Barwaria
6aed4423b2 add vulkan lib path (#638) 2022-12-15 19:48:29 +07:00
yzhang93
6508e3fcc9 Update tuned model SD v2.1base (#634) 2022-12-14 16:02:35 -05:00
Gaurav Shukla
a15cb140ae [WEB] Display the 512x512 image size
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-14 22:43:03 +05:30
Prashant Kumar
898bc9e009 Add the stable diffusion v2.1 version. 2022-12-14 20:19:41 +05:30
Gaurav Shukla
e67ea31ee2 [SHARK][SD] Add --local_tank_cache flag in the stable diffusion
This flag can be used to set local shark_tank cache directory.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-14 20:00:25 +05:30
Gaurav Shukla
986c126a5c [SHARK][SD] Add support for negative prompts
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-14 18:20:09 +05:30
Gaurav Shukla
0eee7616b9 [WEB] Launch only one SD version at a time
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-14 17:30:24 +05:30
powderluv
5ddce749b8 lint fix 2022-12-13 22:02:32 -08:00
powderluv
d946cffabc Revert "Move most xfails to entries in tank/all_models.csv and temporarily remove multiprocessing and TF gpu support. (#602)" (#622)
This reverts commit fe618811ee.
2022-12-13 21:49:46 -08:00
Ean Garvey
fe618811ee Move most xfails to entries in tank/all_models.csv and temporarily remove multiprocessing and TF gpu support. (#602)
* Move most xfails to entries in tank/all_models.csv

* enable usage of pytest without specifying tank/test_models.py

* add dict_configs.py to gitignore.

* Pin versions for runtimes and torch-mlir for setup.
2022-12-13 18:11:17 -08:00
powderluv
09c45bfb80 clean up cache printf 2022-12-13 14:11:14 -08:00
Boian Petkantchin
e9e9ccd379 Add stress test 2022-12-13 13:21:51 -08:00
Boian Petkantchin
a9b27c78a3 Return dynamic model if specified when downloading from the tank 2022-12-13 13:21:51 -08:00
Boian Petkantchin
bc17c29b2e In get_iree_runtime_config get the specific device instead of the default 2022-12-13 13:21:51 -08:00
Boian Petkantchin
aaf60bdee6 Simplify iree_device_map 2022-12-13 13:21:51 -08:00
Gaurav Shukla
d913453e57 [WEB] Update models to 8dec and also default values (#620)
1. Update the models to 8 dec.
2. precision is default to `fp16` in CLI.
3. version is default to `v2.1base` in CLI as well as web.
4. The default scheduler is set to `EulerDiscrete` now.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-13 13:08:33 -08:00
powderluv
08e373aef4 Update stable_diffusion_amd.md 2022-12-13 11:47:29 -08:00
Prashant Kumar
4cb50a3d06 Update the models to 8th Dec version. 2022-12-14 00:01:46 +05:30
Gaurav Shukla
b03038222d [SHARK] Update dependencies
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-13 22:12:00 +05:30
Gaurav Shukla
5f5e0766dd [WEB] Add SD2.1 web support
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-13 21:36:01 +05:30
powderluv
48ec11c514 Build wheels (#613)
* Build wheels

* Update nightly.yml

* Update nightly.yml

* Update nightly.yml
2022-12-12 20:53:08 -08:00
Prashant Kumar
8ae76d18b5 Add euler scheduler. Also, make it default for sd2.1. 2022-12-13 00:03:45 +05:30
Prashant Kumar
e5be1790e5 Enable the v2.1 base version with --version="v2.1base". (#611) 2022-12-12 07:02:01 -08:00
powderluv
e64aa40b17 Add Windows nightly builder 2022-12-11 19:31:02 -08:00
mariecwhite
eb8114ece8 Initialize TF models locally (#610) 2022-12-12 11:35:34 +11:00
Ean Garvey
616ee9b824 Don't include baseline benchmarks if setup without IMPORTER=1. (#607) 2022-12-10 14:58:29 -06:00
Stanley Winata
57c94f8f80 [vulkan] Add "radeon" check to the default AMD triple (#604) 2022-12-10 09:05:48 -08:00
powderluv
2a59c4f670 Update stable_diffusion_amd.md 2022-12-09 16:54:47 -08:00
Boian Petkantchin
192ff487c4 Fix wrong path to script in tank readme (#598) 2022-12-09 11:51:17 -06:00
Gaurav Shukla
b62ee3fcb9 [WEB] Add schedulers in the web UI (#594)
1. Add schedulers option in web UI.
2. Remove random seed checkbox as the same functionality can be achieved
   by passing -1(or any negative number) to the seed.

Signed-Off-by: Gaurav Shukla

Signed-off-by: Gaurav Shukla
2022-12-08 13:53:20 -08:00
Ean Garvey
0225292a44 Remove print statements from compile utils (#593) 2022-12-08 13:40:47 -08:00
Ean Garvey
589a7ed02f Print a message when a model is downloaded via shark_downloader. (#595) 2022-12-08 15:27:58 -06:00
Quinn Dawkins
b3a42cd0b1 Don't do nchw-to-nhwc transpose for stable diffusion models (#592) 2022-12-08 12:19:23 -05:00
Gaurav Shukla
e3e1ca7cc6 [WEB] Fix seed when out of uint32 range
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-08 22:46:33 +05:30
Gaurav Shukla
57e417d174 [WEB] Fix web performance
Set the iree flags before compilation.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-08 19:57:20 +05:30
Ean Garvey
1699db79b5 Disable SHARK-Runtime flags if USE_IREE=1 specified during setup. (#588)
* Disable SHARK-Runtime flags if USE_IREE=1 specified during setup.

* Update setup_venv.sh

* Autodetect cpu count for runtime flags.
2022-12-08 02:31:31 -06:00
Quinn Dawkins
dab9403b8f Fix slow conversion to image in SD web gui (#586) 2022-12-08 00:35:51 -05:00
Ean Garvey
9a14298146 Revert changes to multiprocessing (#585) 2022-12-07 19:59:17 -06:00
Ean Garvey
40eea21863 Enable conv nchw-to-nhwc flag by default for most models + minor fixes (#584) 2022-12-07 16:24:02 -08:00
Ean Garvey
d2475ec169 Add mnasnet to torch models and minor fixes. (#577)
* Minor fixes to benchmark runner

* Add Mnasnet to tank.
2022-12-07 22:30:58 +05:30
Ean Garvey
b3bcf4bf44 Update expected failures in pytest suite. (#574) 2022-12-06 23:05:12 -08:00
Stanley Winata
6049f86bc4 [Vulkan][Utils] Automatic platform/OS detection (#569)
To enable AMD gpus on macOS, we need this detection to let the compiler know that we would be needing moltenVK to use this GPU.
2022-12-07 12:05:00 +07:00
mariecwhite
ff649b52ef Add TF EfficientNet Model (#502) 2022-12-06 13:51:59 -06:00
Gaurav Shukla
e9e138c757 [WEB] Add random seed checkbox
When True, it will not use user specified seed, instead will generate a
random seed.

Signed-Off-by: Gaurav Shukla
2022-12-06 21:44:22 +05:30
Phaneesh Barwaria
1096936a15 Enable f32 path for SD (#567) 2022-12-06 19:29:12 +05:30
Gaurav Shukla
29cc478525 [WEB] Add command line args to shark web
1. Now the server can be launched with command line args.
2. The `precision` and `scheduler` parameters are now part of command
   line args instead of UI.
3. Add vae encode model wrapper.

Signed-Off-by: Gaurav Shukla
2022-12-06 17:21:05 +05:30
Stanley Winata
05e9eb40b5 [Misc] Ignore vmfbs from getting tracked by git. (#566) 2022-12-06 00:01:52 -08:00
Stanley Winata
c4444ff695 [vulkan][utils] Add rdna3 detection (#565) 2022-12-05 23:56:06 -08:00
Anush Elangovan
27b34f3929 Add gcs instead of gsutil
Test .exe on AMD hardware.
2022-12-05 22:17:58 -08:00
powderluv
2b8d784660 update latest sd build 2022-12-05 22:16:13 -08:00
Daniel Garvey
18f447d8d8 fix hash comparison (#563)
Co-authored-by: dan <dan@nod-labs.com>
2022-12-05 21:43:05 -08:00
Daniel Garvey
d7e1078d68 remove nodcloud from client (#562)
Co-authored-by: dan <dan@nod-labs.com>
2022-12-05 23:13:19 -06:00
Daniel Garvey
6be592653f remove gsutil_flags and fix download (#559) 2022-12-05 20:29:00 -08:00
Daniel Garvey
8859853b41 Revert "Revert "find gsutil on linux (#557)" (#560)" (#561)
This reverts commit 3c46021102.
2022-12-05 20:27:43 -08:00
Daniel Garvey
3c46021102 Revert "find gsutil on linux (#557)" (#560)
This reverts commit bba8646669.
2022-12-05 21:53:47 -06:00
Daniel Garvey
bba8646669 find gsutil on linux (#557)
* find gsutil on linux

* cleaned up downloader and ditched gsutil

Co-authored-by: dan <dan@nod-labs.com>
2022-12-05 19:03:48 -08:00
Daniel Garvey
b0dc19a910 revert parallel downloads to 1 (#555)
Co-authored-by: dan <dan@nod-labs.com>
2022-12-05 15:42:42 -08:00
Daniel Garvey
df79ebd0f2 replace gsutil with variable path for pyinstaller (#541)
Co-authored-by: dan <dan@nod-labs.com>
2022-12-05 15:08:57 -08:00
Quinn Dawkins
e19a97f316 Don't do a numpy copy on the results from compiled vm (#543) 2022-12-05 14:21:47 -05:00
Harish Anand
482ffd6275 Move discord link from advanced instructions (#542) 2022-12-04 06:15:34 -08:00
Quinn Dawkins
5117e50602 Revert "Enable the clip f16 model." until correctness is fixed 2022-12-04 19:17:34 +05:30
powderluv
83b138208d Add gradio to requirements.txt 2022-12-03 16:06:52 -08:00
Quinn Dawkins
1870cb4557 Add a note to the Stable Diffusion README about clearing vulkan cache (#545) 2022-12-03 15:12:45 -08:00
Prashant Kumar
42ad5b9c5c Enable the clip f16 model.
-- Enabled the clip f16 model.
-- Updated the location of sdv2 model.
2022-12-03 18:50:40 +05:30
yzhang93
333975eb8f Update Unet fp16 tuned model and Vae flag (#539) 2022-12-02 23:21:18 -05:00
Gaurav Shukla
aa0195e4ef [SHARK] Add vae encoder wrapper
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-03 08:42:25 +05:30
Anush Elangovan
56109fe09b Add one click installer
Build with pyinstaller web\shark_sd.spec
2022-12-02 14:07:10 -08:00
powderluv
e74046478b Update stable_diffusion_amd.md 2022-12-02 13:57:03 -08:00
Gaurav Shukla
aa5a60812f [SHARK] Fix space issues in download path
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-03 00:52:10 +05:30
Ean Garvey
ebb60019aa Minor formatting fix. (#538) 2022-12-03 00:17:31 +05:30
mariecwhite
6393dc5d14 Use correct TF device depending on configuration (#492) 2022-12-02 11:33:56 -06:00
Anush Elangovan
8c158f2452 Fix onedir pyinstall
Use relative paths for install

pyinstaller web/shark_sd.spec creates an exe
2022-12-02 07:28:22 -08:00
powderluv
8c3eabdcee Update stable_diffusion_amd.md 2022-12-02 07:13:10 -08:00
powderluv
8aa0ce6a24 Update stable_diffusion_amd.md 2022-12-02 07:10:31 -08:00
Gaurav Shukla
a27ee141b3 [WEB] Fix few warnings and generate seed faster
1. Fix gsutil warnings while copying multiple files.
2. Enhance random seed generation speed.
3. Add support for multiple schedulers.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-02 17:16:19 +05:30
Anush Elangovan
1106456651 Update cuda 11.7 nightly URL and add index.spec 2022-12-01 22:49:23 -08:00
Quinn Dawkins
8856878cbd Add flag for enabling rgp from the main.py SD script (#533) 2022-12-01 19:01:29 -05:00
Gaurav Shukla
a9bac0287d [WEB] Update to latest models.
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-01 22:55:31 +05:30
Gaurav Shukla
efbd3dc778 [WEB] Fix debug option and add random seed generation
Signed-Off-by: Gaurav Shukla
2022-12-01 21:08:34 +05:30
Phaneesh Barwaria
a0d0eaa408 add clip and vae timing (#527) 2022-12-01 16:17:40 +05:30
Prashant Kumar
e2bf734b67 Update f32 models. 2022-12-01 14:16:03 +05:30
Prashant Kumar
a333a90441 Update to the latest bytecode. 2022-12-01 12:44:54 +05:30
powderluv
6dc0057d3d Update README.md 2022-11-30 17:02:28 -08:00
powderluv
0f9e69d48c Update README.md 2022-11-30 17:01:23 -08:00
powderluv
e6a7c019ab Update README.md 2022-11-30 16:59:55 -08:00
powderluv
1d32eabd14 Update stable_diffusion_amd.md 2022-11-30 16:52:07 -08:00
powderluv
53d03f06a6 Update stable_diffusion_amd.md 2022-11-30 16:04:53 -08:00
powderluv
a2d8c40455 Update stable_diffusion_amd.md 2022-11-30 15:56:38 -08:00
powderluv
4f7d950c8d Update README.md 2022-11-30 15:54:50 -08:00
Harish Anand
cac54b8c26 Update stable_diffusion_amd.md (#525)
- Mention `git clone` after installing git in Windows
- Remove the extra . in powershell set-executionpolicy
2022-11-30 14:48:10 -08:00
powderluv
cd0e881d7d Update stable_diffusion_amd.md 2022-11-30 13:43:24 -08:00
powderluv
fee406e220 Update README.md 2022-11-30 13:43:02 -08:00
powderluv
128342f47f Update stable_diffusion_amd.md 2022-11-30 13:42:25 -08:00
powderluv
024487c5fe Update stable_diffusion_amd.md 2022-11-30 13:40:00 -08:00
powderluv
879ba27ccb Update stable_diffusion_amd.md 2022-11-30 13:33:04 -08:00
powderluv
6d6d9627e7 Update stable_diffusion_amd.md 2022-11-30 13:31:53 -08:00
powderluv
af4bc82543 Update stable_diffusion_amd.md 2022-11-30 13:30:15 -08:00
powderluv
439a18bcc3 Update README.md 2022-11-30 13:27:13 -08:00
powderluv
e12a1e0444 Update README.md 2022-11-30 13:01:19 -08:00
powderluv
4400b0d3c3 Update README.md 2022-11-30 12:38:02 -08:00
powderluv
5dff28ff99 streamline README.md 2022-11-30 12:23:36 -08:00
powderluv
d5ac841a1a Update requirements.txt
add transformers to base venv
2022-11-30 12:12:28 -08:00
powderluv
232ce12e9b Create stable_diffusion_amd.md 2022-11-30 12:10:34 -08:00
aldesilv
9a8638a6d0 dump all isas with amdllpc (#517)
SHARK/shark/examples/shark_inference/stable_diffusion$ python main.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa

Co-authored-by: alexander <alexander@nod-labs.com>
2022-11-30 11:33:30 -08:00
Gaurav Shukla
a5445866b8 [WEB] Update the iree flag
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-30 18:56:48 +05:30
powderluv
e8ded71a7b Default to 50 steps for SD 2022-11-29 16:45:23 -08:00
Prashant Kumar
a14c615def Update with the new flag. (#522) 2022-11-29 09:39:32 -08:00
Gaurav Shukla
3903b6ff0c [WEB] Enable Debug and disable live preview
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-29 22:39:53 +05:30
Ean Garvey
41bf262482 Update SD README.md (#516)
* Update README.md

* Create profiling_with_iree.md
2022-11-29 10:21:28 -06:00
Gaurav Shukla
645b658da0 [WEB] Update model wrappers and scheduler
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-29 21:22:33 +05:30
Prashant Kumar
6ee8f61fbe Add the stable diffusion v2 model.
The f16 version of stable diffusion v2 model is added.
--version="v2" will run the v2 model.
2022-11-29 18:18:04 +05:30
Prashant Kumar
3c4c4231ce Add new args. 2022-11-29 18:18:04 +05:30
Prashant Kumar
d0eef19eba Remove the lms versions as they were redundant.
Tested with the DPM scheduler.
2022-11-29 15:05:05 +05:30
Ean Garvey
6ca2eb3ad7 Update README.md (#515) 2022-11-28 14:09:30 -06:00
Prashant Kumar
74aeb55733 Add support for different schedulers.
Initial support for adding schedulers. This verifies the model running
with the PNDM scheduler too.
2022-11-28 22:12:09 +05:30
Gaurav Shukla
3eb7965ca0 [WEB] Pressing Enter at prompt triggers Image generation
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-28 20:56:20 +05:30
Phaneesh Barwaria
04f20070d1 xfail for cpu models with tensor shape inf error (#512) 2022-11-24 16:12:04 -06:00
Gaurav Shukla
88937fcb2f [WEB] Add vulkan-heap-block-size flag
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-24 16:58:27 +05:30
aldesilv
f80b85f10c dump spv for dispatches (#509) 2022-11-23 22:34:27 -06:00
Quinn Dawkins
32a2ec432d [Stable Diffusion] Revive the tuned model (#506) 2022-11-23 15:42:24 -05:00
Gaurav Shukla
f4821d0d39 [WEB] Update seed calculation and model versions.
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-23 19:21:48 +05:30
Prashant Kumar
fdf2aa54ef Update the sd models. 2022-11-22 23:09:04 +05:30
Gaurav Shukla
275c032264 [WEB] Fix set_param prototype
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-22 20:15:45 +05:30
Gaurav Shukla
d88979fe19 [WEB] Enable guidance scale and update seed calculation
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-22 19:32:21 +05:30
Phaneesh Barwaria
e67bcffea7 add vulkan-heap-block-size flag (#498) 2022-11-22 13:30:25 +05:30
Ean Garvey
005ded3c6f Update xfails. (#500)
* Update test_models.py

* Fix formatting.
2022-11-22 01:30:34 +05:30
Gaurav Shukla
d624940e12 Remove unnecessary torch_mlir import
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-21 21:47:21 +05:30
Gaurav Shukla
7763403b0e [WEB] Cache text-encoder and reorganize the codebase
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-21 17:21:12 +05:30
Prashant Kumar
88c58244b9 Update stable diffusion models to point to new location. 2022-11-18 19:39:21 +05:30
yzhang93
0754c6ea20 Update model annotation to take vulkan configs (#495)
Co-authored-by: vivian <vivian@nod-labs.com>
2022-11-17 14:34:17 -08:00
Prashant Kumar
7b1f04d121 Changes incorporating the recent torch_mlir compile api changes. 2022-11-15 15:25:37 +05:30
Phaneesh Barwaria
d8a9bee244 Add internet connection check for re-downloading models (#488) 2022-11-14 13:56:42 -06:00
Phaneesh Barwaria
ac0ea6bd3c xfail albert tf static cpu (#490) 2022-11-14 13:56:26 -06:00
Ean Garvey
45677c1e23 Install torch version required by torch-mlir when setting up importer venv. (#486) 2022-11-14 14:01:01 +05:30
Phaneesh Barwaria
d9f4a9954a modify to get correct target triple (#485) 2022-11-13 20:13:44 -08:00
mariecwhite
ec461a4456 Enable XLA compiler for TF models (#484) 2022-11-13 20:10:47 -08:00
Mehdi Amini
559928e93b Actually print the error message when SharkRunner can't initialize the driver (#482)
Right now it would just terminate the process silently
2022-11-13 19:08:46 -08:00
Mehdi Amini
a526f7d5b8 Fix dispatch saving code after 749a2c2d (#483)
In 749a2c2d iree_device_map and iree_target_map have been made functions
but not all of the uses were updated.
2022-11-14 05:39:01 +05:30
Phaneesh Barwaria
749a2c2dec add support for choosing vulkan device (#439) 2022-11-12 14:00:41 -08:00
Gaurav Shukla
29a317dbb6 [WEB] Update SD styling and prompt loading. (#479)
* [WEB] CSS changes to the web-ui (#465)

This commit updates UI with styling.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>

* [WEB] Update the title (#466)

* [WEB] Add support for long prompts (#467)

* [WEB] fix background color

Signed-Off-by: Gaurav Shukla

* [WEB] Remove long prompts support

It removes support to long prompts due to higher lag in loading long prompts.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs>

* [WEB] Update nod logo and enable debug feature.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Signed-off-by: Gaurav Shukla
Signed-off-by: Gaurav Shukla <gaurav@nod-labs>
2022-11-10 10:55:22 -08:00
Abhishek Varma
2f36de319a [SHARK_INFERENCE] Add ESRGAN model test file
-- This commit adds ESRGAN model test file to SHARK_INFERENCE.

Signed-off-by: Abhishek Varma <abhishek@nod-ai.com>
2022-11-10 17:12:42 +05:30
Quinn Dawkins
2005bce419 Fix flags for untuned Stable Diffusion FP16 model (#478) 2022-11-09 21:31:10 -05:00
Ean Garvey
8a02d7729d Add a few xfails. (#477) 2022-11-09 09:33:09 -08:00
Prashant Kumar
1cdf301c14 Update the guidance parameter argument and add the int8 version of the
stable diffusion model.
2022-11-08 23:14:44 +05:30
yzhang93
9a86e5c476 Fix dispatch benchmarking tool (#460) 2022-11-08 09:37:12 -08:00
Eliasj42
32d3f4bd5f added ordered benchmarks to dispatch benchmarking tool (#450)
* added ordered benchmarks to dispatch benchmarking tool

* saved changes

* updated readme

Co-authored-by: Elias Joseph <elias@nod-labs.com>
2022-11-07 09:36:21 -08:00
Prashant Kumar
18689afc1a Make separate function for each model. 2022-11-07 20:20:38 +05:30
PhaneeshB
64d6da75c7 Resolve Mac torch-mlir torch setup dependency. Enable MacOS CI 2022-11-07 15:38:37 +05:30
Ean Garvey
1e95e4b502 Change dependency installation order in venv setup script. (#470) 2022-11-04 20:53:54 -05:00
Ean Garvey
c63009a6db Update test_models.py (#464) 2022-11-04 16:59:01 -07:00
Gaurav Shukla
88f8718635 [WEB] Load prompts from json
The prompt examples will now be loaded from a json file `prompts.json`.

Signed-Off-by: Gaurav Shukla
2022-11-02 20:52:34 +05:30
Prashant Kumar
a081733a42 Add the clip text shark_model. (#458) 2022-11-02 00:08:33 -07:00
Gaurav Shukla
06ccfb0533 [WEB] Load vae and unet during server start up
The vae and unet models(both fp16 and fp32 variant) can be loaded at
server startup in order to reduce web response time.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-01 23:11:52 +05:30
Gaurav Shukla
b18d75e3f7 [WEB] Use tuned version of UNET fp16
This commit updates SD script in order to use the tuned version of Unet
fp16 model.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-11-01 19:00:21 +05:30
Quinn Dawkins
3e7efaa048 Switch stable diffusion to the new tuned model (#455) 2022-10-31 15:15:31 -07:00
Gaurav Shukla
a3fdfc81db [WEB] Minor changes in the shark web (#454)
1. Default steps = 50.
2. Live preview will yield intermediate image at every 5 steps.
3. Add logs to .gitignore

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-10-31 14:29:00 -07:00
Gaurav Shukla
f4c91df1df [WEB] Add pillow dependency (#453)
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-10-31 12:57:21 -07:00
Prashant Kumar
32e1ba8c0d Adding batch_size support for stable diffusion. 2022-11-01 00:57:52 +05:30
Gaurav Shukla
1939376d72 [WEB] Cache model parameters (#452)
This commit cache some of the model parameters to reduce the response
time of shark web.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-10-31 11:55:10 -07:00
Gaurav Shukla
25931d48a3 [WEB] Update stable diffusion UI and enable live preview (#447)
This commit enables live preview feature and also updates stable
diffusion web UI.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-10-31 04:10:15 -07:00
powderluv
024c5e153a Update Windows in README 2022-10-30 22:27:03 -07:00
powderluv
83f34b645d Add Windows instructions 2022-10-30 22:25:42 -07:00
powderluv
3f9f450e0d Add setup_venv.ps1 for windows (#448)
Powershell users can run ./setup_venv.ps1 to setup the env
2022-10-30 22:17:35 -07:00
powderluv
fd89b06641 Drop RDNA1 for now 2022-10-29 14:29:09 -07:00
Gaurav Shukla
f8dc996004 Update vulkan-target-triple for Radeon devices. (#446)
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-10-29 14:27:20 -07:00
Phaneesh Barwaria
e6a964088b Add os agnostic vulkan device name check (#445) 2022-10-29 13:19:14 -07:00
Gaurav Shukla
e3e767c7eb [WEB] Remove live preview and disable resnet|albert_maskfill
This commit removes live preview feature for now as it's not functional.
This feature will be added in the next patch.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-10-30 00:37:59 +05:30
Quinn Dawkins
239c19eb12 Update Stable diffusion script to enable use of tuned models (#443) 2022-10-29 01:42:49 -04:00
Eliasj42
7f37599a60 Added a dispatch benchmarking tool (#441)
To produce benchmarks of individual dispatches, you can add --dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir> to your command line argument.

Co-authored-by: Elias Joseph <elias@nod-labs.com>
2022-10-28 14:31:03 -07:00
Prashant Kumar
77c9a2c5ea Add profiling vulkan_device info and minor changes to reflect upstream
changes.
2022-10-28 18:02:07 +05:30
Ean Garvey
fd7baae548 Serialize torch-mlir CAPI module as bytecode instead of string. (#435)
* Serialize torch-mlir CAPI as bytecode instead of string.

* Minor fixes to MLIR data handling in SHARK python.
2022-10-27 14:37:15 -05:00
Stanley Winata
01fdf5ee16 [example][SD] compile fp16 with iree-spirv-unify-aliased-resources (#436) 2022-10-27 05:12:28 -07:00
Gaurav Shukla
e52f533c16 [WEB] Save vmfb and add live preview
This commit updates SD script to save the compiled module and also adds
live preview of generated images.

Signed-off-by: Gaurav Shukla<gaurav@nod-labs.com>
2022-10-26 23:20:53 +05:30
Quinn Dawkins
fbd77dc936 Enable iterator space fusion for SD (#432) 2022-10-26 01:08:26 -04:00
Quinn Dawkins
cdc6dd19e3 Force stable diffusion fp16 and fp32 to generate images with similar noise (#431) 2022-10-25 17:28:18 -04:00
PhaneeshB
fd578a48a9 add cli args for vulkan target triple 2022-10-25 21:47:26 +05:30
Ean Garvey
9956099516 Add pytest option for updating tank and fix save_mlir function. (#413)
* Use IREE tf tools to save .mlir modules when generating shark_tank.

* Add option to pytest for enabling auto-updates to local shark tank.

* xfail mobilenet torch on cpu, cuda and fix CI macos setup

* Update test-models.yml to disable macos vulkan CI.
2022-10-25 21:29:18 +05:30
powderluv
f97b8fffed Update README.md 2022-10-24 12:51:49 -07:00
Gaurav Shukla
7b9e309724 [WEB] Expose SD parameters in the web ui (#427) 2022-10-24 04:34:35 -07:00
Quinn Dawkins
1d33913d48 Add option to save and load precompiled flatbuffer (#425) 2022-10-23 16:24:09 -07:00
Prashant Kumar
a48eaaed20 Pass the flags to vae. 2022-10-23 23:57:48 +05:30
Prashant Kumar
2741b8be53 Pass the flags to vae. (#422) 2022-10-23 11:23:13 -07:00
Anush Elangovan
4f906a265c Fix lint 2022-10-22 12:43:52 -07:00
Anush Elangovan
0dff8d7af0 Simple download script to prime the hf model cache 2022-10-21 17:42:05 -07:00
Quinn Dawkins
4f0d0d8167 Update vulkan gui README for iree-vulkan-gui + Stable Diffusion (#399) 2022-10-21 14:02:40 -04:00
Vivek Khandelwal
d513060b21 Add params for Stable Diffusion (#420) 2022-10-21 23:11:09 +05:30
Prashant Kumar
d1a25ce4f3 Update stable_args.py 2022-10-21 17:26:31 +05:30
Gaurav Shukla
51c98695b2 [WEB] Update stable diffusion inference
This commit updates the stable diffusion web incorporating the latest
improvements.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-10-21 01:26:38 +05:30
Quinn Dawkins
b448770ec2 Add ms/iter timing for stable diffusion script (#414) 2022-10-20 13:32:37 -04:00
Prashant Kumar
5fe22a7980 Minor fix. 2022-10-20 22:57:22 +05:30
Prashant Kumar
38ae6b5af4 Add stable_diffusion fp16 and fp32 with args. 2022-10-20 21:47:11 +05:30
Ean Garvey
0bfe30d75d Fix issues with extra_args in benchmarks, pin tf==2.10 (#411) 2022-10-20 06:55:26 -07:00
Quinn Dawkins
7be1d7d0be Add option for extra arguments through SharkInference.compile (#408) 2022-10-19 15:32:48 -05:00
Prashant Kumar
0d74c873f0 Add stable_diff_f16 version. (#407) 2022-10-19 10:04:24 -07:00
powderluv
139aff2938 Update nightly.yml
fix links
2022-10-18 23:42:22 -07:00
anush elangovan
a3f733490c Force update of packages
Pickup tools from upstream IREE
2022-10-19 05:20:53 +00:00
anush elangovan
8a11f138d1 Update SHARK-Runtime releases page 2022-10-19 05:06:36 +00:00
Ean Garvey
3405607917 (TESTING) Fix .whl assets path (#404) 2022-10-14 12:13:14 -05:00
Ean Garvey
7c99a6bd33 Update README.md (#406) 2022-10-13 20:29:49 -05:00
Ean Garvey
3fba8ce0e6 Update README.md (#405) 2022-10-13 12:43:03 -07:00
Ean Garvey
f3bde3c7fc Cleanup tank directory and move instructions to tank/README.md (#401) 2022-10-13 12:20:02 -05:00
Phaneesh Barwaria
21fee8ef33 enable only one workflow job per branch (#402) 2022-10-13 12:15:30 -05:00
Vivek Khandelwal
0e217d6180 Add Stable Diffusion Img2Img model script 2022-10-13 21:56:46 +05:30
Phaneesh Barwaria
00a8ce75d1 Xfail vulkan tests and Enable MacOs test on CI (#383) 2022-10-13 11:14:41 -05:00
Quinn Dawkins
8f3f00cd99 Add iree-run-module like tool for running in a vulkan session (#398) 2022-10-12 20:46:26 -04:00
Ean Garvey
13bae2538a Update URL for IREE compiler/runtime install (#397)
* Update URL for IREE compiler/runtime install

* Update gh-pages-releases.yml

* Update test_models.py

* Update assets path
2022-10-12 15:47:11 -05:00
Ean Garvey
f508c80c23 Add workflow for GH pages releases and release scraping script. (#394)
* Add workflow for GH pages releases and release scraping script.

* Update test_models.py and change tokens for gh pages.
2022-10-11 22:03:33 -05:00
gpetters94
53df0620e3 Add OPT to tank (#214) 2022-10-11 11:03:56 -05:00
powderluv
a63755bc24 Correct spelling 2022-10-11 01:53:55 -07:00
Quinn Dawkins
d93d0783a8 Add script for tensorflow stable diffusion (#391) 2022-10-10 12:01:49 -04:00
Daniel Garvey
d38e37bd99 seperate importer and benchmark deps (#393) 2022-10-08 23:31:20 -05:00
Ean Garvey
3618fb3ada Move old test scripts out of base tank directory and add xfails. (#389) 2022-10-07 16:02:46 -07:00
Vivek Khandelwal
70a29b03e0 Add FP16 Resnet50 script 2022-10-06 21:56:43 +05:30
Ean Garvey
006adf8746 Fix issue with FASTAPI pip install. (#382) 2022-10-01 14:55:24 -05:00
Quinn Dawkins
33b53e7caf Add flag for specifying the vae mlir file location in stable diffusion (#381) 2022-09-30 00:37:58 -04:00
Daniel Garvey
c54815de17 edit assets path (#376) 2022-09-28 16:42:36 -05:00
Gaurav Shukla
0013fb0753 [WEB] Add shark-web logging
1. This commit adds support to display logs in the shark-web.
2. It also adds nod logo in the home page.
3. Stable-diffusion outputs are being saved now.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-09-29 01:20:42 +05:30
Ean Garvey
56f8a0d85a Update torch-mlir releases page in setup_venv.sh (#374)
* Update README.md

* Update setup_venv.sh
2022-09-28 11:07:44 -07:00
Ean Garvey
9035a2eed3 Add --local_tank_cache flag and update requirements. (#368)
* Add --local_tank_cache flag and update requirements.

* Update requirements-importer.txt
2022-09-28 03:02:59 -05:00
Vivek Khandelwal
28daf410b6 Add instructions to use locally build Torch-MLIR with SHARK 2022-09-28 10:16:38 +05:30
Ean Garvey
cbf3f784aa Add pytest option to specify a URL for shark tank artifacts. (#363)
* Xfail updates.

* Generalize tank SHA option to bucket address and add pytest option.
2022-09-27 02:40:40 -05:00
Anush Elangovan
ef4b306c7b Add diffusers and scipy 2022-09-26 13:35:23 -07:00
powderluv
5316c1e0bf Use latest transformers (#346) 2022-09-26 13:11:41 -07:00
Gaurav Shukla
0228973eef [WEB] Fix the mlir location of stable-diffusion model (#367)
Update the location of stable-diffusion mlir file since there is some
problem with iree-compile.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-09-26 10:56:36 -07:00
Gaurav Shukla
d4eeff0a5d [WEB] Add Stable-Diffusion in the SHARK web (#366)
1. This commit adds stable-diffusion as a part of shark web.
2. The V-diffusion model has been disabled for now as it's not
   working(will raise a different patch with fix).
3. Add standard output in the web ui.
4. Add instructions to launch the shark-web.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-09-26 10:42:02 -07:00
Prashant Kumar
c7b2d39ab2 Update stable_diff to contain vae. 2022-09-26 20:11:43 +05:30
Gaurav Shukla
21958cc02a [WEB] Remove unused parameters in the v-diffuison model (#314)
This commit removes unused parameters in the v-diffusion model. It also
updated the server parameters in order to make multiple requests to be
handled sequentially.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-09-25 10:57:06 -07:00
Ean Garvey
de23e5d9d7 update xfails for PyTorch DistilBERT (#355) 2022-09-24 14:53:20 -05:00
Quinn Dawkins
6438bce023 Add a script to convert a jpg to the correct input for resnet50 with the vulkan gui (#362) 2022-09-23 16:32:52 -07:00
yzhang93
587d74b449 Update model annotation tool (#361)
Usage:
with create_context() as ctx:
  module = model_annotation(ctx, input_contents=..., config_path=..., search_op=...)

Example:
The example is to annotate the minilm model with GPU config files.
python model_annotation.py /nodclouddata/vivian/minilm_model/model.mlir /nodclouddata/vivian/minilm_model/model_config.json
2022-09-23 15:44:51 -07:00
Prashant Kumar
b9c8985047 Add sharkdynamo which combines shark with torchdynamo.
-- Adds graph breaks when necessary.
-- Even for loops are supported.
2022-09-23 22:40:02 +05:30
Vivek Khandelwal
93ebe07d2b Add bert_tosa script 2022-09-23 10:52:06 +05:30
Ean Garvey
d82b305781 Fix issues with loading .vmfb into SharkInference 2022-09-23 09:53:13 +05:30
Quinn Dawkins
1df20fac95 [Lockstep] Hack to avoid aten._reshape_alias (#332)
This enforces the decomposition for aten._reshape_alias used in AOTAutograd to essentially avoid having to deal with problems with strides when running in eager mode.
2022-09-22 18:02:09 -04:00
Prashant Kumar
991e7043d1 Add stable diffusion model. 2022-09-22 13:40:51 +05:30
powderluv
1c4d6c23fa Update CMakeLists.txt 2022-09-21 22:48:56 -07:00
Anush Elangovan
87895446a5 Roll SHARK-Runtime 2022-09-22 00:09:04 -07:00
Ean Garvey
c0f3a09a40 Include SHA in path to failure reproducers. Add --save_fails option. (#352) 2022-09-21 17:55:06 -05:00
Anush Elangovan
e9ad4b9fc4 Update SHARK Runtime 2022-09-21 06:31:48 -07:00
Ean Garvey
c061a8897d Add pytest options to save reproducers. (#350)
* Add pytest options to save and/or upload reproducers.

* pass shark_module to benchmark method.
2022-09-20 20:29:46 -05:00
Ean Garvey
4253551b67 Update README with new testing instructions and filter test cases. (#349) 2022-09-20 15:55:46 -05:00
Vivek Khandelwal
e4991c049e Add Readme file for the bloom model 2022-09-20 20:27:52 +05:30
Daniel Garvey
5df582e7e8 creates abstract test case class (#333) 2022-09-20 07:06:38 -07:00
Ean Garvey
814a6f8295 Modify vulkan target triple substring searches. (#318) 2022-09-20 01:20:20 -05:00
Vivek Khandelwal
7013c3cd4a Add bloom e2e script 2022-09-20 10:56:04 +05:30
powderluv
0ddd65b6f1 Create LICENSE 2022-09-19 15:07:59 -07:00
powderluv
44d8f08bfc Fix Torch-MLIR release page 2022-09-17 00:50:39 -07:00
erman-gurses
fc8aa6ae63 Add ROCM parameters (#335) 2022-09-16 09:12:19 -07:00
Quinn Dawkins
9bd951b083 Clean up the v-diffusion install pipeline (#327) 2022-09-16 11:47:07 -04:00
Vivek Khandelwal
c43448a826 Update compile_utils.py 2022-09-15 18:28:10 +05:30
Vivek Khandelwal
864723a473 add bloom model example 2022-09-15 18:23:09 +05:30
Anush Elangovan
3b0ec8ce4e Update resnet paths 2022-09-14 16:56:20 -07:00
Anush Elangovan
174b171913 Clean up SDL linking 2022-09-14 13:18:55 -07:00
powderluv
cfd9733c2b Delete shark_web directory 2022-09-14 06:38:30 -07:00
Anush Elangovan
8d4d543a49 Update shark runtime 2022-09-14 06:14:02 -07:00
Anush Elangovan
1b9c88a052 Update vulkan gui readme 2022-09-13 19:35:47 -07:00
Anush Elangovan
e212ff2071 Fix resnet50 vulkan_gui to work with tank models 2022-09-13 19:22:41 -07:00
Quinn Dawkins
8d21292d34 Fix input tensors with non-floating point dtype in the lockstep tracer (#328) 2022-09-13 21:14:38 -04:00
Anush Elangovan
e304041574 Remove redundant {} 2022-09-13 16:12:35 -07:00
Anush Elangovan
1776c55e73 Fix torch-mlir download URL 2022-09-13 16:07:25 -07:00
Anush Elangovan
4e4c34c717 fix release downloads 2022-09-13 15:00:47 -07:00
Anush Elangovan
23378b6be8 Add resnet to vulkan-gui 2022-09-13 07:06:47 -07:00
Ean Garvey
6cf5564c84 Remove "gpu" device alias and migrate to using "cuda" for NVIDIA GPU. (#325)
* Replace instances of "gpu" alias for devices with "cuda"
2022-09-13 01:16:56 -05:00
Ean Garvey
7143902a90 Update test-models.yml (#323) 2022-09-12 22:47:40 -05:00
Anush Elangovan
15186db73f Hardcode SDL2 for now (works on linux) 2022-09-12 10:17:41 -07:00
powderluv
ccd7a01ce2 Update README.md 2022-09-12 07:12:57 -07:00
Anush Elangovan
1d7035117d Add cpp inference examples and vulkan_gui 2022-09-12 07:07:33 -07:00
Ean Garvey
1710abd366 Update mobilenet_v3_small_torch_test.py (#322) 2022-09-10 15:22:57 -05:00
Ean Garvey
6aeda3670f Split nightly workflow by backend (IREE / SHARK) (#313)
* Fix validation for nightly builds.

* Add option to generate shark_tank inside SHARK project
Add shark_arg for updating tank on mismatched hash (downloader)

* Fixup CI tank dir option.

* Fixup work directory variable
2022-09-09 22:51:30 +05:30
Prashant Kumar
bb52b224d0 Add sparse architecture and test with torchrec SparseArch.
Features that don't work with current implementation:
    -- embeddingbag config with multiple features.
2022-09-09 21:49:30 +05:30
Stanley Winata
95ec3d7216 [tank][v-diffusion] Polish up v-diffusion UX (#315) 2022-09-08 12:55:51 -07:00
powderluv
18872222d3 Update README.md 2022-09-07 01:14:30 -07:00
Ean Garvey
d453f2e49d Enable CPU benchmarks on test-models workflows. (#299)
* Update test-models.yml

* Update README.md
2022-09-07 01:22:58 -05:00
Ean Garvey
3824d37d27 Add metadata to benchmark results. (#297) 2022-09-06 13:03:48 -05:00
Ean Garvey
d946287723 Update xfails for torchvision models. (#310) 2022-09-01 13:06:12 -05:00
Prashant Kumar
885b0969f5 [WEB] Cache the compiled module.
-- Don't compile the module again and again.
2022-09-01 23:08:08 +05:30
Gaurav Shukla
a886cba655 [WEB] Add v_diffusion model in the shark web (#306)
This commit adds adds `v_diffusion` model web visualization as a part of
shark web.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-09-01 06:34:51 -07:00
Vivek Khandelwal
4afe2e3adb Add func to save intermediate images in v_diffusion_pytorch 2022-09-01 18:36:58 +05:30
Gaurav Shukla
fe080eaee6 [WEB] Introduce web interface for the SHARK models (#298)
This commit introduces web application for SHARK using gradio platform.
This adds web visualization of `Resnet50` and `Albert_Maskfill` models
as a start.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-08-31 23:17:52 -07:00
Quinn Dawkins
3703f014d9 Add scripts for generating images on ats-m (#305) 2022-08-31 23:07:02 -07:00
Daniel Garvey
d45a496030 adds a flag to enable directory choice (#303)
individual tests will require implementation of the flag
alternatively, simply passing shark_default_sha in your
individual app's download function will allow for this behavior
2022-08-31 22:17:40 -07:00
powderluv
4ee164c66f remove a100 cpu 2022-08-31 12:59:47 -07:00
powderluv
bf84c033bb add icelake 2022-08-31 12:58:40 -07:00
Prashant Kumar
5105f62551 Add the dlrm_model in shark example. (#301)
-- DLRM model is added in the shark example.
-- The model is verified on cpu, gpu and vulkan.

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2022-08-31 12:54:21 -07:00
Quinn Dawkins
99be837d84 Add lockstep tracer based on TorchMLIR eager mode + examples (#243) 2022-08-31 15:50:24 -04:00
Quinn Dawkins
b7766898ee Add cfg sampling from tank model for v-diffusion and move compilation outside of the sampling loop (#302) 2022-08-31 11:35:04 -07:00
powderluv
57f73dfbc9 Update nightly.yml 2022-08-28 23:59:03 -07:00
powderluv
50b2b9638d Update nightly.yml 2022-08-28 23:43:32 -07:00
Daniel Garvey
1bfd00e2f8 fixes an install issue (#295) 2022-08-25 18:52:00 -05:00
Daniel Garvey
64424877ac No iree instal (#294)
* adds support to default to tuned model

currently setup for tf bert/resnet50
going to refactor test class to avoid having to
add an argument to 50+ files

* adds an option to avoid installing iree

useful when building iree from source
specify env variable NO_BACKEND=1
2022-08-25 15:02:28 -05:00
Phaneesh Barwaria
02d857260c Update ReadMe
-Add gsutil installation for resnet50 example
2022-08-25 20:28:50 +05:30
Phaneesh Barwaria
1322ec5935 Simplified Testing Interface (#289) 2022-08-24 23:54:56 -05:00
Daniel Garvey
48e9818f7e adds support to default to tuned model (#287)
currently setup for tf bert/resnet50
going to refactor test class to avoid having to
add an argument to 50+ files
2022-08-24 16:30:02 -05:00
Ean Garvey
14857770dc Fix local artifact recognition and usage by SHARK downloader. (#286)
* Fix local artifact recognition and usage by SHARK downloader.

* Update generate_sharktank.py

* Update generate_sharktank.py
2022-08-24 14:37:16 -05:00
Vivek Khandelwal
f79a6bf5aa Update setup_v_diffusion_pytorch.sh (#291)
Fix minor issue with v-diffusion PyTorch version
2022-08-24 22:00:02 +05:30
Prashant Kumar
7dc27a7477 Don't remove the latest .whl package from CI. (#290)
Previously, the CI was removing the latest package and pointing to the
stale package.
2022-08-24 09:03:48 -07:00
Chi_Liu
17dba601c8 Add huggingface top5 image classification automodel (#268) 2022-08-22 15:05:38 -07:00
Chi_Liu
064aa3b1f4 Fix tmp dir bug (#285) 2022-08-22 15:00:35 -07:00
Ean Garvey
4960efc686 Update requirements-importer.txt (#284) 2022-08-19 23:21:41 -05:00
Ean Garvey
a3654f33da Fix sourcing for canonical MiniLM shark_tank model artifacts. (#278)
* Fix generation of MiniLM artifacts.

* Fix miniLM output for validation. Xfail numerics failure on mpnet.

* Update distilbert-base-uncased_tf_test.py

* try-except for transition of minilm model
2022-08-17 23:03:47 -05:00
Daniel Garvey
82c541dfb8 fix missing model download path (#281) 2022-08-17 23:02:50 -05:00
Stanley Winata
55bcb2eb3c Level Zero Backend (#280) 2022-08-17 19:19:27 -07:00
Daniel Garvey
1a85550879 fix nightly upload check (#277) 2022-08-17 14:31:15 -05:00
Ean Garvey
334f2f76c4 Update README.md (#273)
* Update README.md

* Update README.md

* Update README.md

* Update README.md

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2022-08-17 10:38:27 -07:00
Vivek Khandelwal
03601ccdd6 Add v_diffusion_pytorch model in shark/tank (#271) 2022-08-17 22:53:31 +05:30
Prashant Kumar
88b0dec0ee Update unet_model to run on shark.
-- Verified unet_model runs on the cpu/gpu/vulkan backend.
2022-08-17 13:16:02 +05:30
Ean Garvey
3514822cac Improvements to pytest benchmarks. (#267)
* Add ONNX env var flags for venv setup.

* Setup arguments for ONNX benchmarking via pytest.

* Enable ONNX benchmarking on MiniLM via pytest (experimental)

* Fix sequence lengths to 128 for TF model creation and fix issue with benchmarks.

* Disable CI CPU benchmarks on A100, change some default args.

* add xfails for roberta TF model tests on GPU.
2022-08-17 02:29:48 -05:00
Ean Garvey
a8b021dc8d Add benchmarks to MHLO miniLM and resnet50 and add dialect, num_iterations (#264) 2022-08-16 13:55:40 -05:00
Daniel Garvey
5e931debd5 Sharktank-ci (#262) 2022-08-15 13:32:24 -05:00
Ean Garvey
22ff92c48b Add config.VmModule argument to from_flatbuffer call. (#266) 2022-08-14 15:11:19 -07:00
powderluv
7f5aaa3477 Update nightly.yml 2022-08-14 12:22:50 -07:00
powderluv
904e0e1444 Update nightly.yml 2022-08-14 09:57:10 -07:00
powderluv
db6e2207ed Update _common.py 2022-08-13 13:49:01 -07:00
Daniel Garvey
7975087ee2 change backend name (#265) 2022-08-13 12:01:12 -07:00
Daniel Garvey
e8482d47f5 split nightly pytest commands (#259)
prevents oom
2022-08-12 16:11:46 -07:00
Ean Garvey
3e900d2b25 Change Resnet50 directory names. (#263) 2022-08-12 16:10:59 -07:00
Ean Garvey
4b5d09fc6c Add TF ResNet50 to tank tests. (#261)
* Add TensorFlow Resnet50 test to shark tank.
2022-08-12 09:20:43 -07:00
Prashant Kumar
02b1e7ac36 Update torch_mlir.compile API.
torch_mlir.compile API is updated and verified by compiling all the
torch models via generate_sharktank script.
2022-08-10 22:50:15 +05:30
Ean Garvey
23619068eb Disable passing of sm_arch to iree-compile CL args by default. (#253)
* Disable passing of sm_arch to iree-compile CL args by default.

* Fix formatting.
2022-08-10 01:19:24 -07:00
powderluv
f7f24dc4d9 Revert "Add Debug log of torch_model_blacklist.txt (#242)" (#249)
This reverts commit 7023d556b5.
2022-08-09 10:23:14 -07:00
powderluv
c2aa451767 Update test-models.yml 2022-08-09 10:12:59 -07:00
Chi_Liu
7023d556b5 Add Debug log of torch_model_blacklist.txt (#242)
* Add debug log of torch_model_blacklist.txt

* Add make_fx for torch model

* Update torch_model_blacklists.txt

* Add some Xfails
2022-08-09 17:54:02 +05:30
powderluv
274650fd43 Update nightly.yml
Add tests for USE_IREE=1
2022-08-07 00:06:11 -07:00
Prashant Kumar
d934765b1d Add mobilenet_v3_small torch model to the test_suite.
-- The model doesn't validate with the correct results on the GPU.
-- The model passes on CPU and levelzero.
-- The static version of the model gets stuck for vulkan.
2022-08-05 14:10:43 +05:30
Ean Garvey
6f5ceb4e61 Update test-models.yml (#244) 2022-08-04 21:56:08 -07:00
Ean Garvey
6c22139ac9 Upload benchmark results for every test-models workflow (excl. Vulkan) (#241)
* Upload benchmark results for every test-models workflow (excl. Vulkan)
2022-08-04 14:43:07 -07:00
powderluv
1c4f5e0c34 Add M1 Max and Pro variants 2022-08-04 13:45:34 -07:00
Daniel Garvey
7dc0a4f74d fine tune with shark (#211) 2022-08-04 13:14:57 -05:00
Chi_Liu
90fddc6cb0 Add more torch hg model tests (#238) 2022-08-03 18:00:04 -07:00
Quinn Dawkins
934f15ebb7 Fix IREE eager backend device string (#237) 2022-08-03 12:09:52 -07:00
369 changed files with 71013 additions and 5824 deletions

5
.flake8 Normal file
View File

@@ -0,0 +1,5 @@
[flake8]
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py

37
.github/workflows/gh-pages-releases.yml vendored Normal file
View File

@@ -0,0 +1,37 @@
# See: https://github.com/llvm/torch-mlir/issues/1374
name: Publish releases page
on:
workflow_dispatch:
jobs:
scrape_and_publish_releases:
name: "Scrape and publish releases"
runs-on: ubuntu-latest
# Don't run this in everyone's forks.
if: github.repository == 'nod-ai/SHARK'
steps:
- name: Checking out repository
uses: actions/checkout@v2
with:
token: ${{ secrets.NODAI_INVOCATION_TOKEN }}
- name: Run scrape releases script
run: python ./build_tools/scrape_releases.py nod-ai SHARK > /tmp/index.html
shell: bash
- run: git fetch --all
- run: git switch github-pages
- run: git config --global user.email "none@none.com"
- run: git config --global user.name "nod-ai"
- run: mv /tmp/index.html package-index/index.html
- run: git add package-index/index.html
# Only try to make a commit if the file has changed.
- run: git diff --cached --exit-code || git commit -m "Update releases."
- name: GitHub Push
uses: ad-m/github-push-action@v0.6.0
with:
github_token: ${{ secrets.NODAI_INVOCATION_TOKEN }}
branch: github-pages

View File

@@ -9,13 +9,80 @@ on:
workflow_dispatch:
jobs:
build:
windows-build:
runs-on: 7950X
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Compute version
shell: powershell
run: |
$package_version = $(Get-Date -UFormat "%Y%m%d")+"."+${{ github.run_number }}
$package_version_ = $(Get-Date -UFormat "%Y%m%d")+"_"+${{ github.run_number }}
$tag_name=$package_version
echo "package_version=$package_version" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
echo "package_version_=$package_version_" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
echo "tag_name=$tag_name" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
tag_name: ${{ env.tag_name }}
release_name: nod.ai SHARK ${{ env.tag_name }}
body: |
Automatic snapshot release of nod.ai SHARK.
draft: true
prerelease: true
- name: Build Package
shell: powershell
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./dist/nodai*
#asset_content_type: application/vnd.microsoft.portable-executable
- name: Publish Release
id: publish_release
uses: eregon/publish-release@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
linux-build:
runs-on: a100
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.11"]
backend: [IREE, SHARK]
steps:
- uses: actions/checkout@v3
@@ -31,63 +98,56 @@ jobs:
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Compute version
run: |
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
tag_name="${package_version}"
echo "package_version=${package_version}" >> $GITHUB_ENV
echo "tag_name=${tag_name}" >> $GITHUB_ENV
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
tag_name: ${{ env.tag_name }}
release_name: nod.ai SHARK ${{ env.tag_name }}
body: |
Automatic snapshot release of nod.ai SHARK.
draft: true
prerelease: false
- name: Install dependencies
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/SHARK-Runtime/releases; fi
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
- name: Build and validate the IREE package
if: ${{ matrix.backend == 'IREE' }}
continue-on-error: true
run: |
cd $GITHUB_WORKSPACE
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
if !(grep -Fxq " failed" pytest_results.txt)
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
rm -rf ./wheelhouse/nodai*
- name: Build and validate the package
- name: Build and validate the SHARK Runtime package
if: ${{ matrix.backend == 'SHARK' }}
run: |
cd $GITHUB_WORKSPACE
./setup_venv.sh
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/SHARK-Runtime/releases
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pytest -k 'not benchmark' --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py --ignore=shark/tests/test_shark_importer.py --ignore=tank/tf/
- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./wheelhouse/nodai_*.whl
- name: Publish Release
id: publish_release
uses: eregon/publish-release@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt

View File

@@ -6,53 +6,80 @@ name: Validate Models on Shark Runtime
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [a100, MI100, MacStudio, ubuntu-latest]
suite: [cpu,gpu,vulkan]
python-version: ["3.10"]
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
suite: [cpu,cuda,vulkan]
python-version: ["3.11"]
include:
- os: ubuntu-latest
suite: lint
- os: MI100
suite: rocm
- os: MacStudio
suite: metal
exclude:
- os: ubuntu-latest
suite: vulkan
- os: ubuntu-latest
suite: gpu
suite: cuda
- os: ubuntu-latest
suite: cpu
- os: MacStudio
suite: gpu
suite: cuda
- os: MacStudio
suite: cpu
- os: MacStudio
suite: vulkan
- os: MI100
suite: gpu
- os: MI100
- os: icelake
suite: vulkan
- os: icelake
suite: cuda
- os: a100
suite: cpu
- os: 7950x
suite: cpu
- os: 7950x
suite: cuda
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set Environment Variables
if: matrix.os != '7950x'
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest'
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
run: |
# See https://github.com/actions/setup-python/issues/433
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest'
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
@@ -72,40 +99,65 @@ jobs:
run: |
# black format check
black --version
black --line-length 79 --check .
black --check .
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude lit.cfg.py
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude lit.cfg.py
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate CPU Models
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest -k 'cpu' --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
python build_tools/vicuna_testing.py
- name: Validate GPU/CUDA Models
if: matrix.suite == 'gpu'
- name: Validate Models on NVIDIA GPU
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest -k "gpu" --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models
if: matrix.suite == 'vulkan'
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest -k 'vulkan' --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
- name: Validate GPU/ROCM Models
if: matrix.suite == 'rocm'
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest -k 'rocm' --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest -k vulkan -s --ci
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
python build_tools/stable_diffusion_testing.py --device=vulkan

36
.gitignore vendored
View File

@@ -2,6 +2,8 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb
# C extensions
*.so
@@ -31,7 +33,6 @@ MANIFEST
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
@@ -158,12 +159,43 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# vscode related
.vscode
# Shark related artefacts
*venv/
shark_tmp/
*.vmfb
.use-iree
tank/dict_configs.py
*.csv
reproducers/
# ORT related artefacts
cache_models/
onnx_models/
# Generated images
generated_imgs/
# Custom model related artefacts
variants.json
models/
# models folder
apps/stable_diffusion/web/models/
# Stencil annotators.
stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData
# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/

2
.gitmodules vendored
View File

@@ -1,4 +1,4 @@
[submodule "inference/thirdparty/shark-runtime"]
path = inference/thirdparty/shark-runtime
url =https://github.com/nod-ai/SHARK-Runtime.git
url =https://github.com/nod-ai/SRT.git
branch = shark-06032022

View File

@@ -1,3 +0,0 @@
[style]
based_on_style = google
column_limit = 80

218
LICENSE Normal file
View File

@@ -0,0 +1,218 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---- LLVM Exceptions to the Apache 2.0 License ----
As an exception, if, as a result of your compiling your source code, portions
of this Software are embedded into an Object form of such source code, you
may redistribute such embedded portions in such Object form without complying
with the conditions of Sections 4(a), 4(b) and 4(d) of the License.
In addition, if you combine or link compiled forms of this Software with
software that is licensed under the GPLv2 ("Combined Software") and if a
court of competent jurisdiction determines that the patent provision (Section
3), the indemnity provision (Section 9) or other Section of the License
conflicts with the conditions of the GPLv2, you may retroactively and
prospectively choose to deem waived or otherwise exclude such Section(s) of
the License, but only in their entirety and only with respect to the Combined
Software.

454
README.md
View File

@@ -1,29 +1,161 @@
# SHARK
High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerators and Heterogeneous Clusters
High Performance Machine Learning Distribution
[![Nightly Release](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
[![Validate torch-models on Shark Runtime](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
## Communication Channels
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
## Installation
<details>
<summary>Installation (Linux and macOS)</summary>
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
#### Linux Drivers
* MESA / RADV drivers wont work with FP16. Please use the latest AMGPU-PRO drivers (non-pro OSS drivers also wont work) or the latest NVidia Linux Drivers.
Other users please ensure you have your latest vendor drivers and Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home) and if you are using vulkan check `vulkaninfo` works in a terminal window
</details>
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Download the [stable release](https://github.com/nod-ai/shark/releases/latest)
Double click the .exe and you should have the [UI](http://localhost:8080/) in the browser.
If you have custom models put them in a `models/` directory where the .exe is.
Enjoy.
<details>
<summary>More installation notes</summary>
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
## Running
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE)
* The first run may take few minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
</details>
<details>
<summary>Advanced Installation (Only for developers)</summary>
## Advanced Installation (Windows, Linux and macOS) for developers
## Check out the code
```shell
git clone https://github.com/nod-ai/SHARK.git
cd SHARK
```
## Setup your Python VirtualEnvironment and Dependencies
### Windows 10/11 Users
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)
* Install Git for Windows from [here](https://git-scm.com/download/win)
#### Allow the install script to run in Powershell
```powershell
set-executionpolicy remotesigned
```
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
```powershell
./setup_venv.ps1 #You can re-run this script to get the latest version
```
### Linux / macOS Users
```shell
./setup_venv.sh
source shark.venv/bin/activate
```
### Run Stable Diffusion on your device - WebUI
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
```
#### Linux / macOS Users
```shell
(shark.venv) > cd apps/stable_diffusion/web
(shark.venv) > python index.py
```
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
### Run Stable Diffusion on your device - Commandline
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
```shell
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
</details>
The output on a AMD 7900XTX would look something like:
```shell
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
```
Here are some samples generated:
![tajmahal, snow, sunflowers, oil on canvas_0](https://user-images.githubusercontent.com/74956/204934186-141f7e43-6eb2-4e89-a99c-4704d20444b3.jpg)
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
<details>
<summary>Binary Installation</summary>
### Setup a new pip Virtual Environment
This step sets up a new VirtualEnv for Python
```shell
python --version #Check you have 3.7->3.10 on Linux or 3.10 on macOS
python --version #Check you have 3.11 on Linux, macOS or Windows Powershell
python -m venv shark_venv
source shark_venv/bin/activate
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
# If you are using conda create and activate a new conda env
@@ -35,19 +167,24 @@ python -m pip install --upgrade pip
### Install SHARK
This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
```shell
pip install nodai-shark -f https://github.com/nod-ai/SHARK/releases -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/shark-runtime/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
If you are on an Intel macOS machine you need this [workaround](https://github.com/nod-ai/SHARK/issues/102) for an upstream issue.
### Run shark tank model tests.
```shell
pytest tank/test_models.py
```
See tank/README.md for a more detailed walkthrough of our pytest suite and CLI.
### Download and run Resnet50 sample
```shell
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/resnet50_script.py
#Install deps for test script
pip install --pre torch torchvision torchaudio tqdm pillow --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install --pre torch torchvision torchaudio tqdm pillow gsutil --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python ./resnet50_script.py --device="cpu" #use cuda or vulkan or metal
```
@@ -61,77 +198,84 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
</details>
<details>
<summary>Source Installation</summary>
<summary>Development, Testing and Benchmarks</summary>
## Check out the code
```shell
git clone https://github.com/nod-ai/SHARK.git
If you want to use Python3.11 and with TF Import tools you can use the environment variables like:
Set `USE_IREE=1` to use upstream IREE
```
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
```
## Setup your Python VirtualEnvironment and Dependencies
```shell
# Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
source shark.venv/bin/activate
```
For example if you want to use Python3.10 and upstream IREE with TF Import tools you can use the environment variables like:
```
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 USE_IREE=1 ./setup_venv.sh
```
If you are a Torch-mlir developer or an IREE developer and want to test local changes you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
with Python bindings and set your PYTHONPATH as mentioned [here](https://google.github.io/iree/bindings/python/)
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
for Torch-MLIR.
### Run a demo script
### Run any of the hundreds of SHARK tank models via the test framework
```shell
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
# Or a pytest
pytest tank/tf/hf_masked_lm/albert-base-v2_test.py::AlbertBaseModuleTest::test_module_static_cpu
pytest tank/test_models.py -k "MiniLM"
```
### How to use your locally built IREE / Torch-MLIR with SHARK
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
for Torch-MLIR.
How to use your locally built Torch-MLIR with SHARK:
```shell
1.) Run `./setup_venv.sh in SHARK` and activate `shark.venv` virtual env.
2.) Run `pip uninstall torch-mlir`.
3.) Go to your local Torch-MLIR directory.
4.) Activate mlir_venv virtual envirnoment.
5.) Run `pip uninstall -r requirements.txt`.
6.) Run `pip install -r requirements.txt`.
7.) Build Torch-MLIR.
8.) Activate shark.venv virtual environment from the Torch-MLIR directory.
8.) Run `export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples` in the Torch-MLIR directory.
9.) Go to the SHARK directory.
```
Now the SHARK will use your locally build Torch-MLIR repo.
## Benchmarking Dispatches
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
```
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
```
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
if you want to instead incorporate this into a python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `SharkInference`, and the benchmarks will be generated when compiled. E.G:
```
shark_module = SharkInference(
mlir_model,
func_name,
device=args.device,
mlir_dialect="tm_tensor",
dispatch_benchmarks="all",
dispatch_benchmarks_dir="results"
)
```
Output will include:
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
- An .mlir file containing the dispatch benchmark
- A compiled .vmfb file containing the dispatch benchmark
- An .mlir file containing just the hal executable
- A compiled .vmfb file of the hal executable
- A .txt file containing benchmark output
See tank/README.md for further instructions on how to run model tests and benchmarks from the SHARK tank.
</details>
<details>
<summary>Testing</summary>
### Run all model tests on CPU/GPU/VULKAN/Metal
```shell
pytest tank
# If on Linux for multithreading on CPU (faster results):
pytest tank -n auto
```
### Running specific tests
```shell
# Run tests for a specific model:
pytest tank/<MODEL_NAME> #i.e., pytest tank/bert-base-uncased
# Run tests for a specific case:
pytest tank/<MODEL_NAME> -k "keyword"
# i.e., pytest tank/bert-base-uncased/bert-base-uncased_test.py -k "static_gpu"
```
### Run benchmarks on SHARK tank pytests and generate bench_results.csv with results.
(requires source installation with `IMPORTER=1 ./setup_venv.sh`)
```shell
pytest --benchmark tank
# Just do static GPU benchmarks for PyTorch tests:
pytest --benchmark tank --ignore-glob="_tf*" -k "static_gpu"
```
</details>
<details>
<summary>API Reference</summary>
@@ -182,160 +326,26 @@ result = shark_module.forward((arg0, arg1))
```
</details>
## Supported and Validated Models
<details>
<summary>PyTorch Models</summary>
SHARK is maintained to support the latest innovations in ML Models:
### Huggingface PyTorch Models
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------|----------|-------------|
| BERT | :green_heart: | :green_heart: | :green_heart: |
| DistilBERT | :green_heart: | :green_heart: | :green_heart: |
| GPT2 | :green_heart: | :green_heart: | :green_heart: |
| BLOOM | :green_heart: | :green_heart: | :green_heart: |
| Stable Diffusion | :green_heart: | :green_heart: | :green_heart: |
| Vision Transformer | :green_heart: | :green_heart: | :green_heart: |
| ResNet50 | :green_heart: | :green_heart: | :green_heart: |
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
| Albert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
| BigBird | :green_heart: (AOT) | | | |
| DistilBERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
| GPT2 | :broken_heart: (AOT) | | | |
| MobileBert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
For a complete list of the models supported in SHARK, please refer to [tank/README.md](https://github.com/nod-ai/SHARK/blob/main/tank/README.md).
### Torchvision Models
## Communication Channels
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|--------------------|----------------------|----------|----------|-------------|
| AlexNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| DenseNet121 | :green_heart: (Script) | | | |
| MNasNet1_0 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| MobileNetV2 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| MobileNetV3 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| Unet | :broken_heart: (Script) | | | |
| Resnet18 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| Resnet50 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| Resnet101 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| Resnext50_32x4d | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| ShuffleNet_v2 | :broken_heart: (Script) | | | |
| SqueezeNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| EfficientNet | :green_heart: (Script) | | | |
| Regnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| Resnest | :broken_heart: (Script) | | | |
| Vision Transformer | :green_heart: (Script) | | | |
| VGG 16 | :green_heart: (Script) | :green_heart: | :green_heart: | |
| Wide Resnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| RAFT | :broken_heart: (JIT) | | | |
For more information refer to [MODEL TRACKING SHEET](https://docs.google.com/spreadsheets/d/15PcjKeHZIrB5LfDyuw7DGEEE8XnQEX2aX8lm8qbxV8A/edit#gid=0)
### PyTorch Training Models
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :broken_heart: | :broken_heart: | | |
| FullyConnected | :green_heart: | :green_heart: | | |
</details>
<details>
<summary>JAX Models</summary>
### JAX Models
| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| DALL-E | :broken_heart: | :broken_heart: | | |
| FullyConnected | :green_heart: | :green_heart: | | |
</details>
<details>
<summary>TFLite Models</summary>
### TFLite Models
| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :broken_heart: | :broken_heart: | | |
| FullyConnected | :green_heart: | :green_heart: | | |
| albert | :green_heart: | :green_heart: | | |
| asr_conformer | :green_heart: | :green_heart: | | |
| bird_classifier | :green_heart: | :green_heart: | | |
| cartoon_gan | :green_heart: | :green_heart: | | |
| craft_text | :green_heart: | :green_heart: | | |
| deeplab_v3 | :green_heart: | :green_heart: | | |
| densenet | :green_heart: | :green_heart: | | |
| east_text_detector | :green_heart: | :green_heart: | | |
| efficientnet_lite0_int8 | :green_heart: | :green_heart: | | |
| efficientnet | :green_heart: | :green_heart: | | |
| gpt2 | :green_heart: | :green_heart: | | |
| image_stylization | :green_heart: | :green_heart: | | |
| inception_v4 | :green_heart: | :green_heart: | | |
| inception_v4_uint8 | :green_heart: | :green_heart: | | |
| lightning_fp16 | :green_heart: | :green_heart: | | |
| lightning_i8 | :green_heart: | :green_heart: | | |
| lightning | :green_heart: | :green_heart: | | |
| magenta | :green_heart: | :green_heart: | | |
| midas | :green_heart: | :green_heart: | | |
| mirnet | :green_heart: | :green_heart: | | |
| mnasnet | :green_heart: | :green_heart: | | |
| mobilebert_edgetpu_s_float | :green_heart: | :green_heart: | | |
| mobilebert_edgetpu_s_quant | :green_heart: | :green_heart: | | |
| mobilebert | :green_heart: | :green_heart: | | |
| mobilebert_tf2_float | :green_heart: | :green_heart: | | |
| mobilebert_tf2_quant | :green_heart: | :green_heart: | | |
| mobilenet_ssd_quant | :green_heart: | :green_heart: | | |
| mobilenet_v1 | :green_heart: | :green_heart: | | |
| mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
| mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
| mobilenet_v2 | :green_heart: | :green_heart: | | |
| mobilenet_v2_uint8 | :green_heart: | :green_heart: | | |
| mobilenet_v3-large | :green_heart: | :green_heart: | | |
| mobilenet_v3-large_uint8 | :green_heart: | :green_heart: | | |
| mobilenet_v35-int8 | :green_heart: | :green_heart: | | |
| nasnet | :green_heart: | :green_heart: | | |
| person_detect | :green_heart: | :green_heart: | | |
| posenet | :green_heart: | :green_heart: | | |
| resnet_50_int8 | :green_heart: | :green_heart: | | |
| rosetta | :green_heart: | :green_heart: | | |
| spice | :green_heart: | :green_heart: | | |
| squeezenet | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v1 | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v2_fpnlite | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v2_fpnlite_uint8 | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v2 | :green_heart: | :green_heart: | | |
| ssd_spaghettinet_large | :green_heart: | :green_heart: | | |
| ssd_spaghettinet_large_uint8 | :green_heart: | :green_heart: | | |
| visual_wake_words_i8 | :green_heart: | :green_heart: | | |
</details>
<details>
<summary>TF Models</summary>
### Tensorflow Models (Inference)
| Hugging Face Models | tf-mhlo lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| albert-base-v2 | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| DistilBERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| CamemBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| ConvBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| Deberta | | | | |
| electra | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| funnel | | | | |
| layoutlm | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| longformer | | | | |
| mobile-bert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| remembert | | | | |
| tapas | | | | |
| flaubert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| xlm-roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| mpnet | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
</details>
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
## Related Projects

0
apps/__init__.py Normal file
View File

View File

@@ -0,0 +1,16 @@
## CodeGen Setup using SHARK-server
### Setup Server
- clone SHARK and setup the venv
- host the server using `python apps/stable_diffusion/web/index.py --api --server_port=<PORT>`
- default server address is `http://0.0.0.0:8080`
### Setup Client
1. fauxpilot-vscode (VSCode Extension):
- Code for the extension can be found [here](https://github.com/Venthe/vscode-fauxpilot)
- PreReq: VSCode extension (will need [`nodejs` and `npm`](https://nodejs.org/en/download) to compile and run the extension)
- Compile and Run the extension on VSCode (press F5 on VSCode), this opens a new VSCode window with the extension running
- Open VSCode settings, search for fauxpilot in settings and modify `server : http://<IP>:<PORT>`, `Model : codegen` , `Max Lines : 30`
2. Others (REST API curl, OpenAI Python bindings) as shown [here](https://github.com/fauxpilot/fauxpilot/blob/main/documentation/client.md)
- using Github Copilot VSCode extension with SHARK-server needs more work to be functional.

View File

@@ -0,0 +1,18 @@
# Langchain
## How to run the model
1.) Install all the dependencies by running:
```shell
pip install -r apps/language_models/langchain/langchain_requirements.txt
sudo apt-get install -y libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
```
2.) Create a folder named `user_path` in `apps/language_models/langchain/` directory.
Now, you are ready to use the model.
3.) To run the model, run the following command:
```shell
python apps/language_models/langchain/gen.py --cli=True
```

View File

@@ -0,0 +1,186 @@
import copy
import torch
from evaluate_params import eval_func_param_names
from gen import Langchain
from prompter import non_hf_types
from utils import clear_torch_cache, NullContext, get_kwargs
def run_cli( # for local function:
base_model=None,
lora_weights=None,
inference_server=None,
debug=None,
chat_context=None,
examples=None,
memory_restriction_level=None,
# for get_model:
score_model=None,
load_8bit=None,
load_4bit=None,
load_half=None,
load_gptq=None,
use_safetensors=None,
infer_devices=None,
tokenizer_base_model=None,
gpu_id=None,
local_files_only=None,
resume_download=None,
use_auth_token=None,
trust_remote_code=None,
offload_folder=None,
compile_model=None,
# for some evaluate args
stream_output=None,
prompt_type=None,
prompt_dict=None,
temperature=None,
top_p=None,
top_k=None,
num_beams=None,
max_new_tokens=None,
min_new_tokens=None,
early_stopping=None,
max_time=None,
repetition_penalty=None,
num_return_sequences=None,
do_sample=None,
chat=None,
langchain_mode=None,
langchain_action=None,
document_choice=None,
top_k_docs=None,
chunk=None,
chunk_size=None,
# for evaluate kwargs
src_lang=None,
tgt_lang=None,
concurrency_count=None,
save_dir=None,
sanitize_bot_response=None,
model_state0=None,
max_max_new_tokens=None,
is_public=None,
max_max_time=None,
raise_generate_gpu_exceptions=None,
load_db_if_exists=None,
dbs=None,
user_path=None,
detect_user_path_changes_every_query=None,
use_openai_embedding=None,
use_openai_model=None,
hf_embedding_model=None,
db_type=None,
n_jobs=None,
first_para=None,
text_limit=None,
verbose=None,
cli=None,
reverse_docs=None,
use_cache=None,
auto_reduce_chunks=None,
max_chunks=None,
model_lock=None,
force_langchain_evaluate=None,
model_state_none=None,
# unique to this function:
cli_loop=None,
):
Langchain.check_locals(**locals())
score_model = "" # FIXME: For now, so user doesn't have to pass
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
device = "cpu" if n_gpus == 0 else "cuda"
context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device
with context_class(device):
from functools import partial
# get score model
smodel, stokenizer, sdevice = Langchain.get_score_model(
reward_type=True,
**get_kwargs(
Langchain.get_score_model,
exclude_names=["reward_type"],
**locals()
)
)
model, tokenizer, device = Langchain.get_model(
reward_type=False,
**get_kwargs(
Langchain.get_model, exclude_names=["reward_type"], **locals()
)
)
model_dict = dict(
base_model=base_model,
tokenizer_base_model=tokenizer_base_model,
lora_weights=lora_weights,
inference_server=inference_server,
prompt_type=prompt_type,
prompt_dict=prompt_dict,
)
model_state = dict(model=model, tokenizer=tokenizer, device=device)
model_state.update(model_dict)
my_db_state = [None]
fun = partial(
Langchain.evaluate,
model_state,
my_db_state,
**get_kwargs(
Langchain.evaluate,
exclude_names=["model_state", "my_db_state"]
+ eval_func_param_names,
**locals()
)
)
example1 = examples[-1] # pick reference example
all_generations = []
while True:
clear_torch_cache()
instruction = input("\nEnter an instruction: ")
if instruction == "exit":
break
eval_vars = copy.deepcopy(example1)
eval_vars[eval_func_param_names.index("instruction")] = eval_vars[
eval_func_param_names.index("instruction_nochat")
] = instruction
eval_vars[eval_func_param_names.index("iinput")] = eval_vars[
eval_func_param_names.index("iinput_nochat")
] = "" # no input yet
eval_vars[
eval_func_param_names.index("context")
] = "" # no context yet
# grab other parameters, like langchain_mode
for k in eval_func_param_names:
if k in locals():
eval_vars[eval_func_param_names.index(k)] = locals()[k]
gener = fun(*tuple(eval_vars))
outr = ""
res_old = ""
for gen_output in gener:
res = gen_output["response"]
extra = gen_output["sources"]
if base_model not in non_hf_types or base_model in ["llama"]:
if not stream_output:
print(res)
else:
# then stream output for gradio that has full output each generation, so need here to show only new chars
diff = res[len(res_old) :]
print(diff, end="", flush=True)
res_old = res
outr = res # don't accumulate
else:
outr += res # just is one thing
if extra:
# show sources at end after model itself had streamed to std rest of response
print(extra, flush=True)
all_generations.append(outr + "\n")
if not cli_loop:
break
return all_generations

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,103 @@
from enum import Enum
class PromptType(Enum):
custom = -1
plain = 0
instruct = 1
quality = 2
human_bot = 3
dai_faq = 4
summarize = 5
simple_instruct = 6
instruct_vicuna = 7
instruct_with_end = 8
human_bot_orig = 9
prompt_answer = 10
open_assistant = 11
wizard_lm = 12
wizard_mega = 13
instruct_vicuna2 = 14
instruct_vicuna3 = 15
wizard2 = 16
wizard3 = 17
instruct_simple = 18
wizard_vicuna = 19
openai = 20
openai_chat = 21
gptj = 22
prompt_answer_openllama = 23
vicuna11 = 24
mptinstruct = 25
mptchat = 26
falcon = 27
class DocumentChoices(Enum):
All_Relevant = 0
All_Relevant_Only_Sources = 1
Only_All_Sources = 2
Just_LLM = 3
non_query_commands = [
DocumentChoices.All_Relevant_Only_Sources.name,
DocumentChoices.Only_All_Sources.name,
]
class LangChainMode(Enum):
"""LangChain mode"""
DISABLED = "Disabled"
CHAT_LLM = "ChatLLM"
LLM = "LLM"
ALL = "All"
WIKI = "wiki"
WIKI_FULL = "wiki_full"
USER_DATA = "UserData"
MY_DATA = "MyData"
GITHUB_H2OGPT = "github h2oGPT"
H2O_DAI_DOCS = "DriverlessAI docs"
class LangChainAction(Enum):
"""LangChain action"""
QUERY = "Query"
# WIP:
# SUMMARIZE_MAP = "Summarize_map_reduce"
SUMMARIZE_MAP = "Summarize"
SUMMARIZE_ALL = "Summarize_all"
SUMMARIZE_REFINE = "Summarize_refine"
no_server_str = no_lora_str = no_model_str = "[None/Remove]"
# from site-packages/langchain/llms/openai.py
# but needed since ChatOpenAI doesn't have this information
model_token_mapping = {
"gpt-4": 8192,
"gpt-4-0314": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768,
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16 * 1024,
"gpt-3.5-turbo-0301": 4096,
"text-ada-001": 2049,
"ada": 2049,
"text-babbage-001": 2040,
"babbage": 2049,
"text-curie-001": 2049,
"curie": 2049,
"davinci": 2049,
"text-davinci-003": 4097,
"text-davinci-002": 4097,
"code-davinci-002": 8001,
"code-davinci-001": 8001,
"code-cushman-002": 2048,
"code-cushman-001": 2048,
}
source_prefix = "Sources [Score | Link]:"
source_postfix = "End Sources<p>"

View File

@@ -0,0 +1,53 @@
no_default_param_names = [
"instruction",
"iinput",
"context",
"instruction_nochat",
"iinput_nochat",
]
gen_hyper = [
"temperature",
"top_p",
"top_k",
"num_beams",
"max_new_tokens",
"min_new_tokens",
"early_stopping",
"max_time",
"repetition_penalty",
"num_return_sequences",
"do_sample",
]
eval_func_param_names = (
[
"instruction",
"iinput",
"context",
"stream_output",
"prompt_type",
"prompt_dict",
]
+ gen_hyper
+ [
"chat",
"instruction_nochat",
"iinput_nochat",
"langchain_mode",
"langchain_action",
"top_k_docs",
"chunk",
"chunk_size",
"document_choice",
]
)
# form evaluate defaults for submit_nochat_api
eval_func_param_names_defaults = eval_func_param_names.copy()
for k in no_default_param_names:
if k in eval_func_param_names_defaults:
eval_func_param_names_defaults.remove(k)
eval_extra_columns = ["prompt", "response", "score"]

View File

@@ -0,0 +1,846 @@
from __future__ import annotations
from typing import (
Any,
Mapping,
Optional,
Dict,
List,
Sequence,
Tuple,
Union,
Protocol,
)
import inspect
import json
import warnings
from pathlib import Path
import yaml
from abc import ABC, abstractmethod
import langchain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.question_answering import stuff_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.docstore.document import Document
from langchain.callbacks.manager import (
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
from pydantic import Extra, Field, root_validator, validator
def _get_verbosity() -> bool:
return langchain.verbose
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template."""
base_info = {"page_content": doc.page_content}
base_info.update(doc.metadata)
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)
class Chain(Serializable, ABC):
"""Base interface that all chains should implement."""
memory: Optional[BaseMemory] = None
callbacks: Callbacks = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(
default=None, exclude=True
)
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
tags: Optional[List[str]] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Input keys this chain expects."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
"""Output keys this chain expects."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@abstractmethod
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
input_docs = inputs["input_documents"]
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
if "is_first" in inputs.keys() and not inputs["is_first"]:
run_manager_ = run_manager
input_list = [inputs]
stop = None
prompts = []
for inputs in input_list:
selected_inputs = {
k: inputs[k] for k in self.prompt.input_variables
}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager_:
run_manager_.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
prompt_strings = [p.to_string() for p in prompts]
prompts = prompt_strings
callbacks = run_manager_.get_child() if run_manager_ else None
tags = None
"""Run the LLM on the given prompt and input."""
# If string is passed in directly no errors will be raised but outputs will
# not make sense.
if not isinstance(prompts, list):
raise ValueError(
"Argument 'prompts' is expected to be of type List[str], received"
f" argument of type {type(prompts)}."
)
params = self.llm.dict()
params["stop"] = stop
options = {"stop": stop}
disregard_cache = self.llm.cache is not None and not self.llm.cache
callback_manager = CallbackManager.configure(
callbacks,
self.llm.callbacks,
self.llm.verbose,
tags,
self.llm.tags,
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.llm.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager_ = callback_manager.on_llm_start(
dumpd(self),
prompts,
invocation_params=params,
options=options,
)
generations = []
for prompt in prompts:
inputs_ = prompt
num_workers = None
batch_size = None
if num_workers is None:
if self.llm.pipeline._num_workers is None:
num_workers = 0
else:
num_workers = self.llm.pipeline._num_workers
if batch_size is None:
if self.llm.pipeline._batch_size is None:
batch_size = 1
else:
batch_size = self.llm.pipeline._batch_size
preprocess_params = {}
generate_kwargs = {}
preprocess_params.update(generate_kwargs)
forward_params = generate_kwargs
postprocess_params = {}
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
preprocess_params = {
**self.llm.pipeline._preprocess_params,
**preprocess_params,
}
forward_params = {
**self.llm.pipeline._forward_params,
**forward_params,
}
postprocess_params = {
**self.llm.pipeline._postprocess_params,
**postprocess_params,
}
self.llm.pipeline.call_count += 1
if (
self.llm.pipeline.call_count > 10
and self.llm.pipeline.framework == "pt"
and self.llm.pipeline.device.type == "cuda"
):
warnings.warn(
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
" dataset",
UserWarning,
)
model_inputs = self.llm.pipeline.preprocess(
inputs_, **preprocess_params
)
model_outputs = self.llm.pipeline.forward(
model_inputs, **forward_params
)
model_outputs["process"] = False
return model_outputs
output = LLMResult(generations=generations)
run_manager_.on_llm_end(output)
if run_manager_:
output.run = RunInfo(run_id=run_manager_.run_id)
response = output
outputs = [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
][0]
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
else:
_run_manager = (
run_manager or CallbackManagerForChainRun.get_noop_manager()
)
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {
k: v for k, v in inputs.items() if k != self.input_key
}
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in other_keys.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
inputs["is_first"] = False
inputs["input_documents"] = input_docs
# Call predict on the LLM.
output = self.llm_chain(inputs, callbacks=_run_manager.get_child())
if "process" in output.keys() and not output["process"]:
return output
output = output[self.llm_chain.output_key]
extra_return_dict = {}
extra_return_dict[self.output_key] = output
outputs = extra_return_dict
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(
self, inputs: Union[Dict[str, Any], Any]
) -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(
self.memory.memory_variables
)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
def run(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
if args and not kwargs:
if len(args) != 1:
raise ValueError(
"`run` supports only one positional argument."
)
return self(args[0], callbacks=callbacks, tags=tags)[
self.output_keys[0]
]
if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags)[
self.output_keys[0]
]
if not kwargs and not args:
raise ValueError(
"`run` supported with either positional arguments or keyword arguments,"
" but none were provided."
)
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of chain."""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict()
_dict["_type"] = self._chain_type
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the chain.
Args:
file_path: Path to file to save the chain to.
Example:
.. code-block:: python
chain.save(file_path="path/chain.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
chain_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(chain_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents."""
input_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def prompt_length(
self, docs: List[Document], **kwargs: Any
) -> Optional[int]:
"""Return the prompt length given the documents passed in.
Returns None if the method does not depend on the prompt length.
"""
return None
def _call(
self,
inputs: Dict[str, List[Document]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = (
run_manager or CallbackManagerForChainRun.get_noop_manager()
)
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in other_keys.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
# Call predict on the LLM.
output, extra_return_dict = (
self.llm_chain(inputs, callbacks=_run_manager.get_child())[
self.llm_chain.output_key
],
{},
)
extra_return_dict[self.output_key] = output
return extra_return_dict
from pydantic import BaseModel
class Generation(Serializable):
"""Output of a single generation."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw generation info response from the provider"""
"""May include things like reason for finishing (e.g. in OpenAI)"""
# TODO: add log probs
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
class LLMChain(Chain):
"""Chain to run queries against LLMs.
Example:
.. code-block:: python
from langchain import LLMChain, OpenAI, PromptTemplate
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(
input_variables=["adjective"], template=prompt_template
)
llm = LLMChain(llm=OpenAI(), prompt=prompt)
"""
@property
def lc_serializable(self) -> bool:
return True
prompt: BasePromptTemplate
"""Prompt object to use."""
llm: BaseLanguageModel
output_key: str = "text" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
prompts, stop = self.prep_prompts([inputs], run_manager=run_manager)
response = self.llm.generate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
)
return self.create_outputs(response)[0]
def prep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {
k: inputs[k] for k in self.prompt.input_variables
}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input_list": input_list},
)
try:
response = self.generate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
run_manager.on_chain_end({"outputs": outputs})
return outputs
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
"""Create outputs from response."""
return [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
]
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
"""Call predict and then parse the results."""
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = self.apply(input_list, callbacks=callbacks)
return self._parse_result(result)
def _parse_result(
self, result: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key])
for res in result
]
else:
return result
@property
def _chain_type(self) -> str:
return "llm_chain"
@classmethod
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
"""Create LLMChain from LLM and template."""
prompt_template = PromptTemplate.from_template(template)
return cls(llm=llm, prompt=prompt_template)
def _get_default_document_prompt() -> PromptTemplate:
return PromptTemplate(
input_variables=["page_content"], template="{page_content}"
)
class StuffDocumentsChain(BaseCombineDocumentsChain):
"""Chain that combines documents by stuffing into context."""
llm_chain: LLMChain
"""LLM wrapper to use after formatting documents."""
document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt
)
"""Prompt to use to format each document."""
document_variable_name: str
"""The variable name in the llm_chain to put the documents in.
If only one variable in the llm_chain, this need not be provided."""
document_separator: str = "\n\n"
"""The string with which to join the formatted documents"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def get_default_document_variable_name(cls, values: Dict) -> Dict:
"""Get default document variable name, if not provided."""
llm_chain_variables = values["llm_chain"].prompt.input_variables
if "document_variable_name" not in values:
if len(llm_chain_variables) == 1:
values["document_variable_name"] = llm_chain_variables[0]
else:
raise ValueError(
"document_variable_name must be provided if there are "
"multiple llm_chain_variables"
)
else:
if values["document_variable_name"] not in llm_chain_variables:
raise ValueError(
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
return values
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
# Format each document according to the prompt
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in kwargs.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
return inputs
def prompt_length(
self, docs: List[Document], **kwargs: Any
) -> Optional[int]:
"""Get the prompt length by formatting the prompt."""
inputs = self._get_inputs(docs, **kwargs)
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt)
@property
def _chain_type(self) -> str:
return "stuff_documents_chain"
class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain."""
def __call__(
self, llm: BaseLanguageModel, **kwargs: Any
) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain."""
def _load_stuff_chain(
llm: BaseLanguageModel,
prompt: Optional[BasePromptTemplate] = None,
document_variable_name: str = "context",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> StuffDocumentsChain:
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(
llm=llm,
prompt=_prompt,
verbose=verbose,
callback_manager=callback_manager,
callbacks=callbacks,
)
# TODO: document prompt
return StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
**kwargs,
)
def load_qa_chain(
llm: BaseLanguageModel,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load question answering chain.
Args:
llm: Language Model to use in the chain.
chain_type: Type of document combining chain to use. Should be one of "stuff",
"map_reduce", "map_rerank", and "refine".
verbose: Whether chains should be run in verbose mode or not. Note that this
applies to all chains that make up the final chain.
callback_manager: Callback manager to use for the chain.
Returns:
A chain to use for question answering.
"""
loader_mapping: Mapping[str, LoadingCallable] = {
"stuff": _load_stuff_chain,
}
if chain_type not in loader_mapping:
raise ValueError(
f"Got unsupported chain type: {chain_type}. "
f"Should be one of {loader_mapping.keys()}"
)
return loader_mapping[chain_type](
llm, verbose=verbose, callback_manager=callback_manager, **kwargs
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,380 @@
import inspect
import os
from functools import partial
from typing import Dict, Any, Optional, List
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from langchain.llms import gpt4all
from dotenv import dotenv_values
from utils import FakeTokenizer
def get_model_tokenizer_gpt4all(base_model, **kwargs):
# defaults (some of these are generation parameters, so need to be passed in at generation time)
model_kwargs = dict(
n_threads=os.cpu_count() // 2,
temp=kwargs.get("temperature", 0.2),
top_p=kwargs.get("top_p", 0.75),
top_k=kwargs.get("top_k", 40),
n_ctx=2048 - 256,
)
env_gpt4all_file = ".env_gpt4all"
model_kwargs.update(dotenv_values(env_gpt4all_file))
# make int or float if can to satisfy types for class
for k, v in model_kwargs.items():
try:
if float(v) == int(v):
model_kwargs[k] = int(v)
else:
model_kwargs[k] = float(v)
except:
pass
if base_model == "llama":
if "model_path_llama" not in model_kwargs:
raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
model_path = model_kwargs.pop("model_path_llama")
# FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
from llama_cpp import Llama
# llama sets some things at init model time, not generation time
func_names = list(inspect.signature(Llama.__init__).parameters)
model_kwargs = {
k: v for k, v in model_kwargs.items() if k in func_names
}
model_kwargs["n_ctx"] = int(model_kwargs["n_ctx"])
model = Llama(model_path=model_path, **model_kwargs)
elif base_model in "gpt4all_llama":
if (
"model_name_gpt4all_llama" not in model_kwargs
and "model_path_gpt4all_llama" not in model_kwargs
):
raise ValueError(
"No model_name_gpt4all_llama or model_path_gpt4all_llama in %s"
% env_gpt4all_file
)
model_name = model_kwargs.pop("model_name_gpt4all_llama")
model_type = "llama"
from gpt4all import GPT4All as GPT4AllModel
model = GPT4AllModel(model_name=model_name, model_type=model_type)
elif base_model in "gptj":
if (
"model_name_gptj" not in model_kwargs
and "model_path_gptj" not in model_kwargs
):
raise ValueError(
"No model_name_gpt4j or model_path_gpt4j in %s"
% env_gpt4all_file
)
model_name = model_kwargs.pop("model_name_gptj")
model_type = "gptj"
from gpt4all import GPT4All as GPT4AllModel
model = GPT4AllModel(model_name=model_name, model_type=model_type)
else:
raise ValueError("No such base_model %s" % base_model)
return model, FakeTokenizer(), "cpu"
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
# streaming to std already occurs without this
# sys.stdout.write(token)
# sys.stdout.flush()
pass
def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
# default from class
model_kwargs = {
k: v.default
for k, v in dict(inspect.signature(cls).parameters).items()
if k not in exclude_list
}
# from our defaults
model_kwargs.update(default_kwargs)
# from user defaults
model_kwargs.update(env_kwargs)
# ensure only valid keys
func_names = list(inspect.signature(cls).parameters)
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
return model_kwargs
def get_llm_gpt4all(
model_name,
model=None,
max_new_tokens=256,
temperature=0.1,
repetition_penalty=1.0,
top_k=40,
top_p=0.7,
streaming=False,
callbacks=None,
prompter=None,
verbose=False,
):
assert prompter is not None
env_gpt4all_file = ".env_gpt4all"
env_kwargs = dotenv_values(env_gpt4all_file)
n_ctx = env_kwargs.pop("n_ctx", 2048 - max_new_tokens)
default_kwargs = dict(
context_erase=0.5,
n_batch=1,
n_ctx=n_ctx,
n_predict=max_new_tokens,
repeat_last_n=64 if repetition_penalty != 1.0 else 0,
repeat_penalty=repetition_penalty,
temp=temperature,
temperature=temperature,
top_k=top_k,
top_p=top_p,
use_mlock=True,
verbose=verbose,
)
if model_name == "llama":
cls = H2OLlamaCpp
model_path = (
env_kwargs.pop("model_path_llama") if model is None else model
)
model_kwargs = get_model_kwargs(
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
)
model_kwargs.update(
dict(
model_path=model_path,
callbacks=callbacks,
streaming=streaming,
prompter=prompter,
)
)
llm = cls(**model_kwargs)
llm.client.verbose = verbose
elif model_name == "gpt4all_llama":
cls = H2OGPT4All
model_path = (
env_kwargs.pop("model_path_gpt4all_llama")
if model is None
else model
)
model_kwargs = get_model_kwargs(
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
)
model_kwargs.update(
dict(
model=model_path,
backend="llama",
callbacks=callbacks,
streaming=streaming,
prompter=prompter,
)
)
llm = cls(**model_kwargs)
elif model_name == "gptj":
cls = H2OGPT4All
model_path = (
env_kwargs.pop("model_path_gptj") if model is None else model
)
model_kwargs = get_model_kwargs(
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
)
model_kwargs.update(
dict(
model=model_path,
backend="gptj",
callbacks=callbacks,
streaming=streaming,
prompter=prompter,
)
)
llm = cls(**model_kwargs)
else:
raise RuntimeError("No such model_name %s" % model_name)
return llm
class H2OGPT4All(gpt4all.GPT4All):
model: Any
prompter: Any
"""Path to the pre-trained GPT4All model file."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in the environment."""
try:
if isinstance(values["model"], str):
from gpt4all import GPT4All as GPT4AllModel
full_path = values["model"]
model_path, delimiter, model_name = full_path.rpartition("/")
model_path += delimiter
values["client"] = GPT4AllModel(
model_name=model_name,
model_path=model_path or None,
model_type=values["backend"],
allow_download=False,
)
if values["n_threads"] is not None:
# set n_threads
values["client"].model.set_thread_count(
values["n_threads"]
)
else:
values["client"] = values["model"]
try:
values["backend"] = values["client"].model_type
except AttributeError:
# The below is for compatibility with GPT4All Python bindings <= 0.2.3.
values["backend"] = values["client"].model.model_type
except ImportError:
raise ValueError(
"Could not import gpt4all python package. "
"Please install it with `pip install gpt4all`."
)
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs,
) -> str:
# Roughly 4 chars per token if natural language
prompt = prompt[-self.n_ctx * 4 :]
# use instruct prompting
data_point = dict(context="", instruction=prompt, input="")
prompt = self.prompter.generate_prompt(data_point)
verbose = False
if verbose:
print("_call prompt: %s" % prompt, flush=True)
# FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
return super()._call(prompt, stop=stop, run_manager=run_manager)
from langchain.llms import LlamaCpp
class H2OLlamaCpp(LlamaCpp):
model_path: Any
prompter: Any
"""Path to the pre-trained GPT4All model file."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
if isinstance(values["model_path"], str):
model_path = values["model_path"]
model_param_names = [
"lora_path",
"lora_base",
"n_ctx",
"n_parts",
"seed",
"f16_kv",
"logits_all",
"vocab_only",
"use_mlock",
"n_threads",
"n_batch",
"use_mmap",
"last_n_tokens_size",
]
model_params = {k: values[k] for k in model_param_names}
# For backwards compatibility, only include if non-null.
if values["n_gpu_layers"] is not None:
model_params["n_gpu_layers"] = values["n_gpu_layers"]
try:
from llama_cpp import Llama
values["client"] = Llama(model_path, **model_params)
except ImportError:
raise ModuleNotFoundError(
"Could not import llama-cpp-python library. "
"Please install the llama-cpp-python library to "
"use this embedding model: pip install llama-cpp-python"
)
except Exception as e:
raise ValueError(
f"Could not load Llama model from path: {model_path}. "
f"Received error {e}"
)
else:
values["client"] = values["model_path"]
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs,
) -> str:
verbose = False
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
# still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
prompt = prompt[-self.n_ctx * 4 :]
prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
num_prompt_tokens = len(prompt_tokens)
if num_prompt_tokens > self.n_ctx:
# conservative by using int()
chars_per_token = int(len(prompt) / num_prompt_tokens)
prompt = prompt[-self.n_ctx * chars_per_token :]
if verbose:
print(
"reducing tokens, assuming average of %s chars/token: %s"
% chars_per_token,
flush=True,
)
prompt_tokens2 = self.client.tokenize(
b" " + prompt.encode("utf-8")
)
num_prompt_tokens2 = len(prompt_tokens2)
print(
"reduced tokens from %d -> %d"
% (num_prompt_tokens, num_prompt_tokens2),
flush=True,
)
# use instruct prompting
data_point = dict(context="", instruction=prompt, input="")
prompt = self.prompter.generate_prompt(data_point)
if verbose:
print("_call prompt: %s" % prompt, flush=True)
if self.streaming:
text_callback = None
if run_manager:
text_callback = partial(
run_manager.on_llm_new_token, verbose=self.verbose
)
# parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
if text_callback:
text_callback(prompt)
text = ""
for token in self.stream(
prompt=prompt, stop=stop, run_manager=run_manager
):
text_chunk = token["choices"][0]["text"]
# self.stream already calls text_callback
# if text_callback:
# text_callback(text_chunk)
text += text_chunk
return text
else:
params = self._get_parameters(stop)
params = {**params, **kwargs}
result = self.client(prompt=prompt, **params)
return result["choices"][0]["text"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,93 @@
import traceback
from typing import Callable
import os
from gradio_client.client import Job
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
from gradio_client import Client
class GradioClient(Client):
"""
Parent class of gradio client
To handle automatically refreshing client if detect gradio server changed
"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
super().__init__(*args, **kwargs)
self.server_hash = self.get_server_hash()
def get_server_hash(self):
"""
Get server hash using super without any refresh action triggered
Returns: git hash of gradio server
"""
return super().submit(api_name="/system_hash").result()
def refresh_client_if_should(self):
# get current hash in order to update api_name -> fn_index map in case gradio server changed
# FIXME: Could add cli api as hash
server_hash = self.get_server_hash()
if self.server_hash != server_hash:
self.refresh_client()
self.server_hash = server_hash
else:
self.reset_session()
def refresh_client(self):
"""
Ensure every client call is independent
Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
Returns:
"""
# need session hash to be new every time, to avoid "generator already executing"
self.reset_session()
client = Client(*self.args, **self.kwargs)
for k, v in client.__dict__.items():
setattr(self, k, v)
def submit(
self,
*args,
api_name: str | None = None,
fn_index: int | None = None,
result_callbacks: Callable | list[Callable] | None = None,
) -> Job:
# Note predict calls submit
try:
self.refresh_client_if_should()
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
except Exception as e:
print("Hit e=%s" % str(e), flush=True)
# force reconfig in case only that
self.refresh_client()
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
# see if immediately failed
e = job.future._exception
if e is not None:
print(
"GR job failed: %s %s"
% (str(e), "".join(traceback.format_tb(e.__traceback__))),
flush=True,
)
# force reconfig in case only that
self.refresh_client()
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
e2 = job.future._exception
if e2 is not None:
print(
"GR job failed again: %s\n%s"
% (
str(e2),
"".join(traceback.format_tb(e2.__traceback__)),
),
flush=True,
)
return job

View File

@@ -0,0 +1,765 @@
import os
from apps.stable_diffusion.src.utils.utils import _compile_module
from io import BytesIO
import torch_mlir
from stopping import get_stopping
from prompter import Prompter, PromptType
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
StoppingCriteriaList,
)
import copy
import torch
from transformers import AutoConfig, AutoModelForCausalLM
import gc
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx, save_mlir
from apps.stable_diffusion.src import args
# Brevitas
from typing import List, Tuple
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
return [lhs[0], rhs[0]]
else:
raise ValueError("Input shapes not supported.")
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
global_device = "cuda"
global_precision = "fp16"
if not args.run_docuchat_web:
args.device = global_device
args.precision = global_precision
tensor_device = "cpu" if args.device == "cpu" else "cuda"
class H2OGPTModel(torch.nn.Module):
def __init__(self, device, precision):
super().__init__()
torch_dtype = (
torch.float32
if precision == "fp32" or device == "cpu"
else torch.float16
)
device_map = {"": "cpu"} if device == "cpu" else {"": 0}
model_kwargs = {
"local_files_only": False,
"torch_dtype": torch_dtype,
"resume_download": True,
"use_auth_token": False,
"trust_remote_code": True,
"offload_folder": "offline_folder",
"device_map": device_map,
}
config = AutoConfig.from_pretrained(
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
use_auth_token=False,
trust_remote_code=True,
offload_folder="offline_folder",
)
self.model = AutoModelForCausalLM.from_pretrained(
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
config=config,
**model_kwargs,
)
if precision in ["int4", "int8"]:
print("Applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model.transformer.h,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=128,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, input_ids, attention_mask):
input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": None,
"use_cache": True,
}
output = self.model(
**input_dict,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return output.logits[:, -1, :]
class H2OGPTSHARKModel(torch.nn.Module):
def __init__(self):
super().__init__()
model_name = "h2ogpt_falcon_7b"
extended_model_name = (
model_name + "_" + args.precision + "_" + args.device
)
vmfb_path = Path(extended_model_name + ".vmfb")
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
shark_module = None
need_to_compile = False
if not vmfb_path.exists():
need_to_compile = True
# Downloading VMFB from shark_tank
print("Trying to download pre-compiled vmfb from shark tank.")
download_public_file(
"gs://shark_tank/langchain/" + str(vmfb_path),
vmfb_path.absolute(),
single_file=True,
)
if vmfb_path.exists():
print(
"Pre-compiled vmfb downloaded from shark tank successfully."
)
need_to_compile = False
if need_to_compile:
if not mlir_path.exists():
print("Trying to download pre-generated mlir from shark tank.")
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/langchain/" + str(mlir_path),
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
# Generating the mlir
bytecode = self.get_bytecode(tensor_device, args.precision)
shark_module = SharkInference(
mlir_module=bytecode,
device=args.device,
mlir_dialect="linalg",
)
print(f"[DEBUG] generating vmfb.")
shark_module = _compile_module(
shark_module, extended_model_name, []
)
print("Saved newly generated vmfb.")
if shark_module is None:
if vmfb_path.exists():
print("Compiled vmfb found. Loading it from: ", vmfb_path)
shark_module = SharkInference(
None, device=args.device, mlir_dialect="linalg"
)
shark_module.load_module(str(vmfb_path))
print("Compiled vmfb loaded successfully.")
else:
raise ValueError("Unable to download/generate a vmfb.")
self.model = shark_module
def get_bytecode(self, device, precision):
h2ogpt_model = H2OGPTModel(device, precision)
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 400)
).to(device=device)
compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to(
device=device
)
h2ogptCompileInput = (
compilation_input_ids,
compilation_attention_mask,
)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
h2ogpt_model,
h2ogptCompileInput,
is_f16=False,
precision=precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del h2ogpt_model
del self.src_model
print(f"[DEBUG] generating torch mlir")
if precision in ["int4", "int8"]:
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
)
module = torch_mlir.compile(
ts_graph,
[*h2ogptCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
module = torch_mlir.compile(
ts_graph,
[*h2ogptCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
bytecode = save_mlir(
bytecode,
model_name=f"h2ogpt_{precision}",
frontend="torch",
)
return bytecode
def forward(self, input_ids, attention_mask):
result = torch.from_numpy(
self.model(
"forward",
(input_ids.to(device="cpu"), attention_mask.to(device="cpu")),
)
).to(device=tensor_device)
return result
def decode_tokens(tokenizer, res_tokens):
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
res_tokens[i] = int(res_tokens[i][0])
res_str = tokenizer.decode(res_tokens, skip_special_tokens=True)
return res_str
def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs):
del generate_kwargs["max_time"]
generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to(
device=tensor_device
)
generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to(
device=tensor_device
)
truncated_input_ids = []
stopping_criteria = generate_kwargs["stopping_criteria"]
generation_config_ = GenerationConfig.from_model_config(model.config)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
input_ids_seq_length = input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
logits_warper = model._get_logits_warper(generation_config)
(
input_ids,
model_kwargs,
) = model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=tensor_device)
if eos_token_id is not None
else None
)
pad_token_id = generation_config.pad_token_id
eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(
input_ids.shape[0],
dtype=torch.long,
device=input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
res_tokens = []
while True:
model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs
)
outputs = h2ogpt_shark_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = next_token * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
model_kwargs["past_key_values"] = None
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
truncated_input_ids.append(input_ids[:, 0])
input_ids = input_ids[:, 1:]
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 1:]
new_word = tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
res_tokens.append(next_token)
if new_word == "<0x0A>":
print("\n", end="", flush=True)
else:
print(f"{new_word}", end=" ", flush=True)
part_str = decode_tokens(tokenizer, res_tokens)
yield part_str
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_token.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if unfinished_sequences.max() == 0 or stopping_criteria(
input_ids, scores
):
break
timesRan = timesRan + 1
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
torch.cuda.empty_cache()
gc.collect()
res_str = decode_tokens(tokenizer, res_tokens)
yield res_str
def pad_or_truncate_inputs(
input_ids, attention_mask, max_padding_length=400, do_truncation=False
):
inp_shape = input_ids.shape
if inp_shape[1] < max_padding_length:
# do padding
num_add_token = max_padding_length - inp_shape[1]
padded_input_ids = torch.cat(
[
torch.tensor([[11] * num_add_token]).to(device=tensor_device),
input_ids,
],
dim=1,
)
padded_attention_mask = torch.cat(
[
torch.tensor([[0] * num_add_token]).to(device=tensor_device),
attention_mask,
],
dim=1,
)
return padded_input_ids, padded_attention_mask
elif inp_shape[1] > max_padding_length or do_truncation:
# do truncation
num_remove_token = inp_shape[1] - max_padding_length
truncated_input_ids = input_ids[:, num_remove_token:]
truncated_attention_mask = attention_mask[:, num_remove_token:]
return truncated_input_ids, truncated_attention_mask
else:
return input_ids, attention_mask
class H2OTextGenerationPipeline(TextGenerationPipeline):
def __init__(
self,
*args,
debug=False,
chat=False,
stream_output=False,
sanitize_bot_response=False,
use_prompter=True,
prompter=None,
prompt_type=None,
prompt_dict=None,
max_input_tokens=2048 - 256,
**kwargs,
):
"""
HF-like pipeline, but handle instruction prompting and stopping (for some models)
:param args:
:param debug:
:param chat:
:param stream_output:
:param sanitize_bot_response:
:param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
:param prompter: prompter, can pass if have already
:param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
If use_prompter, then will make prompter and use it.
:param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
:param max_input_tokens:
:param kwargs:
"""
super().__init__(*args, **kwargs)
self.prompt_text = None
self.use_prompter = use_prompter
self.prompt_type = prompt_type
self.prompt_dict = prompt_dict
self.prompter = prompter
if self.use_prompter:
if self.prompter is not None:
assert self.prompter.prompt_type is not None
else:
self.prompter = Prompter(
self.prompt_type,
self.prompt_dict,
debug=debug,
chat=chat,
stream_output=stream_output,
)
self.human = self.prompter.humanstr
self.bot = self.prompter.botstr
self.can_stop = True
else:
self.prompter = None
self.human = None
self.bot = None
self.can_stop = False
self.sanitize_bot_response = sanitize_bot_response
self.max_input_tokens = (
max_input_tokens # not for generate, so ok that not kwargs
)
@staticmethod
def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
verbose = bool(int(os.getenv("VERBOSE_PIPELINE", "0")))
if hasattr(tokenizer, "model_max_length"):
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
model_max_length = tokenizer.model_max_length
if max_prompt_length is not None:
model_max_length = min(model_max_length, max_prompt_length)
# cut at some upper likely limit to avoid excessive tokenization etc
# upper bound of 10 chars/token, e.g. special chars sometimes are long
if len(prompt_text) > model_max_length * 10:
len0 = len(prompt_text)
prompt_text = prompt_text[-model_max_length * 10 :]
if verbose:
print(
"Cut of input: %s -> %s" % (len0, len(prompt_text)),
flush=True,
)
else:
# unknown
model_max_length = None
num_prompt_tokens = None
if model_max_length is not None:
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
# For https://github.com/h2oai/h2ogpt/issues/192
for trial in range(0, 3):
prompt_tokens = tokenizer(prompt_text)["input_ids"]
num_prompt_tokens = len(prompt_tokens)
if num_prompt_tokens > model_max_length:
# conservative by using int()
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
# keep tail, where question is if using langchain
prompt_text = prompt_text[
-model_max_length * chars_per_token :
]
if verbose:
print(
"reducing %s tokens, assuming average of %s chars/token for %s characters"
% (
num_prompt_tokens,
chars_per_token,
len(prompt_text),
),
flush=True,
)
else:
if verbose:
print(
"using %s tokens with %s chars"
% (num_prompt_tokens, len(prompt_text)),
flush=True,
)
break
return prompt_text, num_prompt_tokens
def preprocess(
self,
prompt_text,
prefix="",
handle_long_generation=None,
**generate_kwargs,
):
(
prompt_text,
num_prompt_tokens,
) = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
data_point = dict(context="", instruction=prompt_text, input="")
if self.prompter is not None:
prompt_text = self.prompter.generate_prompt(data_point)
self.prompt_text = prompt_text
if handle_long_generation is None:
# forces truncation of inputs to avoid critical failure
handle_long_generation = None # disable with new approaches
return super().preprocess(
prompt_text,
prefix=prefix,
handle_long_generation=handle_long_generation,
**generate_kwargs,
)
def postprocess(
self,
model_outputs,
return_type=ReturnType.FULL_TEXT,
clean_up_tokenization_spaces=True,
):
records = super().postprocess(
model_outputs,
return_type=return_type,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
for rec in records:
if self.use_prompter:
outputs = rec["generated_text"]
outputs = self.prompter.get_response(
outputs,
prompt=self.prompt_text,
sanitize_bot_response=self.sanitize_bot_response,
)
elif self.bot and self.human:
outputs = (
rec["generated_text"]
.split(self.bot)[1]
.split(self.human)[0]
)
else:
outputs = rec["generated_text"]
rec["generated_text"] = outputs
print(
"prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs),
flush=True,
)
return records
def _forward(self, model_inputs, **generate_kwargs):
if self.can_stop:
stopping_criteria = get_stopping(
self.prompt_type,
self.prompt_dict,
self.tokenizer,
self.device,
human=self.human,
bot=self.bot,
model_max_length=self.tokenizer.model_max_length,
)
generate_kwargs["stopping_criteria"] = stopping_criteria
# return super()._forward(model_inputs, **generate_kwargs)
return self.__forward(model_inputs, **generate_kwargs)
# FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
# FIXME: https://github.com/h2oai/h2ogpt/issues/172
def __forward(self, model_inputs, **generate_kwargs):
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs.get("attention_mask", None)
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
attention_mask = None
in_b = 1
else:
in_b = input_ids.shape[0]
prompt_text = model_inputs.pop("prompt_text")
## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
# generate_kwargs = copy.deepcopy(generate_kwargs)
prefix_length = generate_kwargs.pop("prefix_length", 0)
if prefix_length > 0:
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs
and generate_kwargs["generation_config"].max_new_tokens
is not None
)
if not has_max_new_tokens:
generate_kwargs["max_length"] = (
generate_kwargs.get("max_length")
or self.model.config.max_length
)
generate_kwargs["max_length"] += prefix_length
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs
and generate_kwargs["generation_config"].min_new_tokens
is not None
)
if not has_min_new_tokens and "min_length" in generate_kwargs:
generate_kwargs["min_length"] += prefix_length
# BS x SL
# pad or truncate the input_ids and attention_mask
max_padding_length = 400
input_ids, attention_mask = pad_or_truncate_inputs(
input_ids, attention_mask, max_padding_length=max_padding_length
)
return_dict = {
"model": self.model,
"tokenizer": self.tokenizer,
"input_ids": input_ids,
"attention_mask": attention_mask,
"attention_mask": attention_mask,
}
return_dict = {**return_dict, **generate_kwargs}
return return_dict

View File

@@ -0,0 +1,247 @@
"""
Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py
But accepts preloaded model to avoid slowness in use and CUDA forking issues
Loader that loads image captions
By default, the loader utilizes the pre-trained BLIP image captioning model.
https://huggingface.co/Salesforce/blip-image-captioning-base
"""
from typing import List, Union, Any, Tuple
import requests
from langchain.docstore.document import Document
from langchain.document_loaders import ImageCaptionLoader
from utils import get_device, NullContext
import pkg_resources
try:
assert pkg_resources.get_distribution("bitsandbytes") is not None
have_bitsandbytes = True
except (pkg_resources.DistributionNotFound, AssertionError):
have_bitsandbytes = False
class H2OImageCaptionLoader(ImageCaptionLoader):
"""Loader that loads the captions of an image"""
def __init__(
self,
path_images: Union[str, List[str]] = None,
blip_processor: str = None,
blip_model: str = None,
caption_gpu=True,
load_in_8bit=True,
# True doesn't seem to work, even though https://huggingface.co/Salesforce/blip2-flan-t5-xxl#in-8-bit-precision-int8
load_half=False,
load_gptq="",
use_safetensors=False,
min_new_tokens=20,
max_tokens=50,
):
if blip_model is None or blip_model is None:
blip_processor = "Salesforce/blip-image-captioning-base"
blip_model = "Salesforce/blip-image-captioning-base"
super().__init__(path_images, blip_processor, blip_model)
self.blip_processor = blip_processor
self.blip_model = blip_model
self.processor = None
self.model = None
self.caption_gpu = caption_gpu
self.context_class = NullContext
self.device = "cpu"
self.load_in_8bit = (
load_in_8bit and have_bitsandbytes
) # only for blip2
self.load_half = load_half
self.load_gptq = load_gptq
self.use_safetensors = use_safetensors
self.gpu_id = "auto"
# default prompt
self.prompt = "image of"
self.min_new_tokens = min_new_tokens
self.max_tokens = max_tokens
def set_context(self):
if get_device() == "cuda" and self.caption_gpu:
import torch
n_gpus = (
torch.cuda.device_count() if torch.cuda.is_available else 0
)
if n_gpus > 0:
self.context_class = torch.device
self.device = "cuda"
def load_model(self):
try:
import transformers
except ImportError:
raise ValueError(
"`transformers` package not found, please install with "
"`pip install transformers`."
)
self.set_context()
if self.caption_gpu:
if self.gpu_id == "auto":
# blip2 has issues with multi-GPU. Error says need to somehow set language model in device map
# device_map = 'auto'
device_map = {"": 0}
else:
if self.device == "cuda":
device_map = {"": self.gpu_id}
else:
device_map = {"": "cpu"}
else:
device_map = {"": "cpu"}
import torch
with torch.no_grad():
with self.context_class(self.device):
context_class_cast = (
NullContext if self.device == "cpu" else torch.autocast
)
with context_class_cast(self.device):
if "blip2" in self.blip_processor.lower():
from transformers import (
Blip2Processor,
Blip2ForConditionalGeneration,
)
if self.load_half and not self.load_in_8bit:
self.processor = Blip2Processor.from_pretrained(
self.blip_processor, device_map=device_map
).half()
self.model = (
Blip2ForConditionalGeneration.from_pretrained(
self.blip_model, device_map=device_map
).half()
)
else:
self.processor = Blip2Processor.from_pretrained(
self.blip_processor,
load_in_8bit=self.load_in_8bit,
device_map=device_map,
)
self.model = (
Blip2ForConditionalGeneration.from_pretrained(
self.blip_model,
load_in_8bit=self.load_in_8bit,
device_map=device_map,
)
)
else:
from transformers import (
BlipForConditionalGeneration,
BlipProcessor,
)
self.load_half = False # not supported
if self.caption_gpu:
if device_map == "auto":
# Blip doesn't support device_map='auto'
if self.device == "cuda":
if self.gpu_id == "auto":
device_map = {"": 0}
else:
device_map = {"": self.gpu_id}
else:
device_map = {"": "cpu"}
else:
device_map = {"": "cpu"}
self.processor = BlipProcessor.from_pretrained(
self.blip_processor, device_map=device_map
)
self.model = (
BlipForConditionalGeneration.from_pretrained(
self.blip_model, device_map=device_map
)
)
return self
def set_image_paths(self, path_images: Union[str, List[str]]):
"""
Load from a list of image files
"""
if isinstance(path_images, str):
self.image_paths = [path_images]
else:
self.image_paths = path_images
def load(self, prompt=None) -> List[Document]:
if self.processor is None or self.model is None:
self.load_model()
results = []
for path_image in self.image_paths:
caption, metadata = self._get_captions_and_metadata(
model=self.model,
processor=self.processor,
path_image=path_image,
prompt=prompt,
)
doc = Document(page_content=caption, metadata=metadata)
results.append(doc)
return results
def _get_captions_and_metadata(
self, model: Any, processor: Any, path_image: str, prompt=None
) -> Tuple[str, dict]:
"""
Helper function for getting the captions and metadata of an image
"""
if prompt is None:
prompt = self.prompt
try:
from PIL import Image
except ImportError:
raise ValueError(
"`PIL` package not found, please install with `pip install pillow`"
)
try:
if path_image.startswith("http://") or path_image.startswith(
"https://"
):
image = Image.open(
requests.get(path_image, stream=True).raw
).convert("RGB")
else:
image = Image.open(path_image).convert("RGB")
except Exception:
raise ValueError(f"Could not get image data for {path_image}")
import torch
with torch.no_grad():
with self.context_class(self.device):
context_class_cast = (
NullContext if self.device == "cpu" else torch.autocast
)
with context_class_cast(self.device):
if self.load_half:
inputs = processor(
image, prompt, return_tensors="pt"
).half()
else:
inputs = processor(image, prompt, return_tensors="pt")
min_length = len(prompt) // 4 + self.min_new_tokens
self.max_tokens = max(self.max_tokens, min_length)
output = model.generate(
**inputs,
min_length=min_length,
max_length=self.max_tokens,
)
caption: str = processor.decode(
output[0], skip_special_tokens=True
)
prompti = caption.find(prompt)
if prompti >= 0:
caption = caption[prompti + len(prompt) :]
metadata: dict = {"image_path": path_image}
return caption, metadata

View File

@@ -0,0 +1,120 @@
# for generate (gradio server) and finetune
datasets==2.13.0
sentencepiece==0.1.99
huggingface_hub==0.16.4
appdirs==1.4.4
fire==0.5.0
docutils==0.20.1
evaluate==0.4.0
rouge_score==0.1.2
sacrebleu==2.3.1
scikit-learn==1.2.2
alt-profanity-check==1.2.2
better-profanity==0.7.0
numpy==1.24.3
pandas==2.0.2
matplotlib==3.7.1
loralib==0.1.1
bitsandbytes==0.39.0
accelerate==0.20.3
peft==0.4.0
# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
transformers==4.30.2
tokenizers==0.13.3
APScheduler==3.10.1
# optional for generate
pynvml==11.5.0
psutil==5.9.5
boto3==1.26.101
botocore==1.29.101
# optional for finetune
tensorboard==2.13.0
neptune==1.2.0
# for gradio client
gradio_client==0.2.10
beautifulsoup4==4.12.2
markdown==3.4.3
# data and testing
pytest==7.2.2
pytest-xdist==3.2.1
nltk==3.8.1
textstat==0.7.3
# pandoc==2.3
pypandoc==1.11; sys_platform == "darwin" and platform_machine == "arm64"
pypandoc_binary==1.11; platform_machine == "x86_64"
pypandoc_binary==1.11; sys_platform == "win32"
openpyxl==3.1.2
lm_dataformat==0.0.20
bioc==2.0
# falcon
einops==0.6.1
instructorembedding==1.0.1
# for gpt4all .env file, but avoid worrying about imports
python-dotenv==1.0.0
text-generation==0.6.0
# for tokenization when don't have HF tokenizer
tiktoken==0.4.0
# optional: for OpenAI endpoint or embeddings (requires key)
openai==0.27.8
# optional for chat with PDF
langchain==0.0.202
pypdf==3.12.2
# avoid textract, requires old six
#textract==1.6.5
# for HF embeddings
sentence_transformers==2.2.2
# local vector db
chromadb==0.3.25
# server vector db
#pymilvus==2.2.8
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
# unstructured==0.8.1
# strong support for images
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
unstructured[local-inference]==0.7.4
#pdf2image==1.16.3
#pytesseract==0.3.10
pillow
pdfminer.six==20221105
urllib3
requests_file
#pdf2image==1.16.3
#pytesseract==0.3.10
tabulate==0.9.0
# FYI pandoc already part of requirements.txt
# JSONLoader, but makes some trouble for some users
# jq==1.4.1
# to check licenses
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
pip-licenses==4.3.0
# weaviate vector db
weaviate-client==3.22.1
gpt4all==1.0.5
llama-cpp-python==0.1.73
arxiv==1.4.8
pymupdf==1.22.5 # AGPL license
# extract-msg==0.41.1 # GPL3
# sometimes unstructured fails, these work in those cases. See https://github.com/h2oai/h2ogpt/issues/320
playwright==1.36.0
# requires Chrome binary to be in path
selenium==4.10.0

View File

@@ -0,0 +1,124 @@
from typing import List, Optional, Tuple
import torch
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[
torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]
]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0,
(bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device,
)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
print(
"Replacing original LLaMa attention with flash attention", flush=True
)
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

View File

@@ -0,0 +1,109 @@
import functools
def get_loaders(model_name, reward_type, llama_type=None, load_gptq=""):
# NOTE: Some models need specific new prompt_type
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
if load_gptq:
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM
use_triton = False
functools.partial(
AutoGPTQForCausalLM.from_quantized,
quantize_config=None,
use_triton=use_triton,
)
return AutoGPTQForCausalLM.from_quantized, AutoTokenizer
if llama_type is None:
llama_type = "llama" in model_name.lower()
if llama_type:
from transformers import LlamaForCausalLM, LlamaTokenizer
return LlamaForCausalLM.from_pretrained, LlamaTokenizer
elif "distilgpt2" in model_name.lower():
from transformers import AutoModelForCausalLM, AutoTokenizer
return AutoModelForCausalLM.from_pretrained, AutoTokenizer
elif "gpt2" in model_name.lower():
from transformers import GPT2LMHeadModel, GPT2Tokenizer
return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer
elif "mbart-" in model_name.lower():
from transformers import (
MBartForConditionalGeneration,
MBart50TokenizerFast,
)
return (
MBartForConditionalGeneration.from_pretrained,
MBart50TokenizerFast,
)
elif (
"t5" == model_name.lower()
or "t5-" in model_name.lower()
or "flan-" in model_name.lower()
):
from transformers import AutoTokenizer, T5ForConditionalGeneration
return T5ForConditionalGeneration.from_pretrained, AutoTokenizer
elif "bigbird" in model_name:
from transformers import (
BigBirdPegasusForConditionalGeneration,
AutoTokenizer,
)
return (
BigBirdPegasusForConditionalGeneration.from_pretrained,
AutoTokenizer,
)
elif (
"bart-large-cnn-samsum" in model_name
or "flan-t5-base-samsum" in model_name
):
from transformers import pipeline
return pipeline, "summarization"
elif (
reward_type
or "OpenAssistant/reward-model".lower() in model_name.lower()
):
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
)
return (
AutoModelForSequenceClassification.from_pretrained,
AutoTokenizer,
)
else:
from transformers import AutoTokenizer, AutoModelForCausalLM
model_loader = AutoModelForCausalLM
tokenizer_loader = AutoTokenizer
return model_loader.from_pretrained, tokenizer_loader
def get_tokenizer(
tokenizer_loader,
tokenizer_base_model,
local_files_only,
resume_download,
use_auth_token,
):
tokenizer = tokenizer_loader.from_pretrained(
tokenizer_base_model,
local_files_only=local_files_only,
resume_download=resume_download,
use_auth_token=use_auth_token,
padding_side="left",
)
tokenizer.pad_token_id = 0 # different from the eos token
# when generating, we will use the logits of right-most token to predict the next token
# so the padding should be on the left,
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
tokenizer.padding_side = "left" # Allow batched inference
return tokenizer

View File

@@ -0,0 +1,203 @@
import os
from gpt_langchain import (
path_to_docs,
get_some_dbs_from_hf,
all_db_zips,
some_db_zips,
create_or_update_db,
)
from utils import get_ngpus_vis
def glob_to_db(
user_path,
chunk=True,
chunk_size=512,
verbose=False,
fail_any_exception=False,
n_jobs=-1,
url=None,
enable_captions=True,
captions_model=None,
caption_loader=None,
enable_ocr=False,
):
sources1 = path_to_docs(
user_path,
verbose=verbose,
fail_any_exception=fail_any_exception,
n_jobs=n_jobs,
chunk=chunk,
chunk_size=chunk_size,
url=url,
enable_captions=enable_captions,
captions_model=captions_model,
caption_loader=caption_loader,
enable_ocr=enable_ocr,
)
return sources1
def make_db_main(
use_openai_embedding: bool = False,
hf_embedding_model: str = None,
persist_directory: str = "db_dir_UserData",
user_path: str = "user_path",
url: str = None,
add_if_exists: bool = True,
collection_name: str = "UserData",
verbose: bool = False,
chunk: bool = True,
chunk_size: int = 512,
fail_any_exception: bool = False,
download_all: bool = False,
download_some: bool = False,
download_one: str = None,
download_dest: str = "./",
n_jobs: int = -1,
enable_captions: bool = True,
captions_model: str = "Salesforce/blip-image-captioning-base",
pre_load_caption_model: bool = False,
caption_gpu: bool = True,
enable_ocr: bool = False,
db_type: str = "chroma",
):
"""
# To make UserData db for generate.py, put pdfs, etc. into path user_path and run:
python make_db.py
# once db is made, can use in generate.py like:
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b --langchain_mode=UserData
or zip-up the db_dir_UserData and share:
zip -r db_dir_UserData.zip db_dir_UserData
# To get all db files (except large wiki_full) do:
python make_db.py --download_some=True
# To get a single db file from HF:
python make_db.py --download_one=db_dir_DriverlessAI_docs.zip
:param use_openai_embedding: Whether to use OpenAI embedding
:param hf_embedding_model: HF embedding model to use. Like generate.py, uses 'hkunlp/instructor-large' if have GPUs, else "sentence-transformers/all-MiniLM-L6-v2"
:param persist_directory: where to persist db
:param user_path: where to pull documents from (None means url is not None. If url is not None, this is ignored.)
:param url: url to generate documents from (None means user_path is not None)
:param add_if_exists: Add to db if already exists, but will not add duplicate sources
:param collection_name: Collection name for new db if not adding
:param verbose: whether to show verbose messages
:param chunk: whether to chunk data
:param chunk_size: chunk size for chunking
:param fail_any_exception: whether to fail if any exception hit during ingestion of files
:param download_all: whether to download all (including 23GB Wikipedia) example databases from h2o.ai HF
:param download_some: whether to download some small example databases from h2o.ai HF
:param download_one: whether to download one chosen example databases from h2o.ai HF
:param download_dest: Destination for downloads
:param n_jobs: Number of cores to use for ingesting multiple files
:param enable_captions: Whether to enable captions on images
:param captions_model: See generate.py
:param pre_load_caption_model: See generate.py
:param caption_gpu: Caption images on GPU if present
:param enable_ocr: Whether to enable OCR on images
:param db_type: Type of db to create. Currently only 'chroma' and 'weaviate' is supported.
:return: None
"""
db = None
# match behavior of main() in generate.py for non-HF case
n_gpus = get_ngpus_vis()
if n_gpus == 0:
if hf_embedding_model is None:
# if no GPUs, use simpler embedding model to avoid cost in time
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
else:
if hf_embedding_model is None:
# if still None, then set default
hf_embedding_model = "hkunlp/instructor-large"
if download_all:
print("Downloading all (and unzipping): %s" % all_db_zips, flush=True)
get_some_dbs_from_hf(download_dest, db_zips=all_db_zips)
if verbose:
print("DONE", flush=True)
return db, collection_name
elif download_some:
print(
"Downloading some (and unzipping): %s" % some_db_zips, flush=True
)
get_some_dbs_from_hf(download_dest, db_zips=some_db_zips)
if verbose:
print("DONE", flush=True)
return db, collection_name
elif download_one:
print("Downloading %s (and unzipping)" % download_one, flush=True)
get_some_dbs_from_hf(
download_dest, db_zips=[[download_one, "", "Unknown License"]]
)
if verbose:
print("DONE", flush=True)
return db, collection_name
if enable_captions and pre_load_caption_model:
# preload, else can be too slow or if on GPU have cuda context issues
# Inside ingestion, this will disable parallel loading of multiple other kinds of docs
# However, if have many images, all those images will be handled more quickly by preloaded model on GPU
from image_captions import H2OImageCaptionLoader
caption_loader = H2OImageCaptionLoader(
None,
blip_model=captions_model,
blip_processor=captions_model,
caption_gpu=caption_gpu,
).load_model()
else:
if enable_captions:
caption_loader = "gpu" if caption_gpu else "cpu"
else:
caption_loader = False
if verbose:
print("Getting sources", flush=True)
assert (
user_path is not None or url is not None
), "Can't have both user_path and url as None"
if not url:
assert os.path.isdir(user_path), (
"user_path=%s does not exist" % user_path
)
sources = glob_to_db(
user_path,
chunk=chunk,
chunk_size=chunk_size,
verbose=verbose,
fail_any_exception=fail_any_exception,
n_jobs=n_jobs,
url=url,
enable_captions=enable_captions,
captions_model=captions_model,
caption_loader=caption_loader,
enable_ocr=enable_ocr,
)
exceptions = [x for x in sources if x.metadata.get("exception")]
print("Exceptions: %s" % exceptions, flush=True)
sources = [x for x in sources if "exception" not in x.metadata]
assert len(sources) > 0, "No sources found"
db = create_or_update_db(
db_type,
persist_directory,
collection_name,
sources,
use_openai_embedding,
add_if_exists,
verbose,
hf_embedding_model,
)
assert db is not None
if verbose:
print("DONE", flush=True)
return db, collection_name

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,403 @@
"""Load Data from a MediaWiki dump xml."""
import ast
import glob
import pickle
import uuid
from typing import List, Optional
import os
import bz2
import csv
import numpy as np
import pandas as pd
import pytest
from matplotlib import pyplot as plt
from langchain.docstore.document import Document
from langchain.document_loaders import MWDumpLoader
# path where downloaded wiki files exist, to be processed
root_path = "/data/jon/h2o-llm"
def unescape(x):
try:
x = ast.literal_eval(x)
except:
try:
x = x.encode("ascii", "ignore").decode("unicode_escape")
except:
pass
return x
def get_views():
# views = pd.read_csv('wiki_page_views_more_1000month.csv')
views = pd.read_csv("wiki_page_views_more_5000month.csv")
views.index = views["title"]
views = views["views"]
views = views.to_dict()
views = {str(unescape(str(k))): v for k, v in views.items()}
views2 = {k.replace("_", " "): v for k, v in views.items()}
# views has _ but pages has " "
views.update(views2)
return views
class MWDumpDirectLoader(MWDumpLoader):
def __init__(
self,
data: str,
encoding: Optional[str] = "utf8",
title_words_limit=None,
use_views=True,
verbose=True,
):
"""Initialize with file path."""
self.data = data
self.encoding = encoding
self.title_words_limit = title_words_limit
self.verbose = verbose
if use_views:
# self.views = get_views()
# faster to use global shared values
self.views = global_views
else:
self.views = None
def load(self) -> List[Document]:
"""Load from file path."""
import mwparserfromhell
import mwxml
dump = mwxml.Dump.from_page_xml(self.data)
docs = []
for page in dump.pages:
if self.views is not None and page.title not in self.views:
if self.verbose:
print("Skipped %s low views" % page.title, flush=True)
continue
for revision in page:
if self.title_words_limit is not None:
num_words = len(" ".join(page.title.split("_")).split(" "))
if num_words > self.title_words_limit:
if self.verbose:
print("Skipped %s" % page.title, flush=True)
continue
if self.verbose:
if self.views is not None:
print(
"Kept %s views: %s"
% (page.title, self.views[page.title]),
flush=True,
)
else:
print("Kept %s" % page.title, flush=True)
code = mwparserfromhell.parse(revision.text)
text = code.strip_code(
normalize=True, collapse=True, keep_template_params=False
)
title_url = str(page.title).replace(" ", "_")
metadata = dict(
title=page.title,
source="https://en.wikipedia.org/wiki/" + title_url,
id=page.id,
redirect=page.redirect,
views=self.views[page.title]
if self.views is not None
else -1,
)
metadata = {k: v for k, v in metadata.items() if v is not None}
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_index(search_term, index_filename):
byte_flag = False
data_length = start_byte = 0
index_file = open(index_filename, "r")
csv_reader = csv.reader(index_file, delimiter=":")
for line in csv_reader:
if not byte_flag and search_term == line[2]:
start_byte = int(line[0])
byte_flag = True
elif byte_flag and int(line[0]) != start_byte:
data_length = int(line[0]) - start_byte
break
index_file.close()
return start_byte, data_length
def get_start_bytes(index_filename):
index_file = open(index_filename, "r")
csv_reader = csv.reader(index_file, delimiter=":")
start_bytes = set()
for line in csv_reader:
start_bytes.add(int(line[0]))
index_file.close()
return sorted(start_bytes)
def get_wiki_filenames():
# requires
# wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2
base_path = os.path.join(
root_path, "enwiki-20230401-pages-articles-multistream"
)
index_file = "enwiki-20230401-pages-articles-multistream-index.txt"
index_filename = os.path.join(base_path, index_file)
wiki_filename = os.path.join(
base_path, "enwiki-20230401-pages-articles-multistream.xml.bz2"
)
return index_filename, wiki_filename
def get_documents_by_search_term(search_term):
index_filename, wiki_filename = get_wiki_filenames()
start_byte, data_length = search_index(search_term, index_filename)
with open(wiki_filename, "rb") as wiki_file:
wiki_file.seek(start_byte)
data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))
loader = MWDumpDirectLoader(data.decode())
documents = loader.load()
return documents
def get_one_chunk(
wiki_filename,
start_byte,
end_byte,
return_file=True,
title_words_limit=None,
use_views=True,
):
data_length = end_byte - start_byte
with open(wiki_filename, "rb") as wiki_file:
wiki_file.seek(start_byte)
data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))
loader = MWDumpDirectLoader(
data.decode(), title_words_limit=title_words_limit, use_views=use_views
)
documents1 = loader.load()
if return_file:
base_tmp = "temp_wiki"
if not os.path.isdir(base_tmp):
os.makedirs(base_tmp, exist_ok=True)
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
with open(filename, "wb") as f:
pickle.dump(documents1, f)
return filename
return documents1
from joblib import Parallel, delayed
global_views = get_views()
def get_all_documents(small_test=2, n_jobs=None, use_views=True):
print("DO get all wiki docs: %s" % small_test, flush=True)
index_filename, wiki_filename = get_wiki_filenames()
start_bytes = get_start_bytes(index_filename)
end_bytes = start_bytes[1:]
start_bytes = start_bytes[:-1]
if small_test:
start_bytes = start_bytes[:small_test]
end_bytes = end_bytes[:small_test]
if n_jobs is None:
n_jobs = 5
else:
if n_jobs is None:
n_jobs = os.cpu_count() // 4
# default loky backend leads to name space conflict problems
return_file = True # large return from joblib hangs
documents = Parallel(n_jobs=n_jobs, verbose=10, backend="multiprocessing")(
delayed(get_one_chunk)(
wiki_filename,
start_byte,
end_byte,
return_file=return_file,
use_views=use_views,
)
for start_byte, end_byte in zip(start_bytes, end_bytes)
)
if return_file:
# then documents really are files
files = documents.copy()
documents = []
for fil in files:
with open(fil, "rb") as f:
documents.extend(pickle.load(f))
os.remove(fil)
else:
from functools import reduce
from operator import concat
documents = reduce(concat, documents)
assert isinstance(documents, list)
print("DONE get all wiki docs", flush=True)
return documents
def test_by_search_term():
search_term = "Apollo"
assert len(get_documents_by_search_term(search_term)) == 100
search_term = "Abstract (law)"
assert len(get_documents_by_search_term(search_term)) == 100
search_term = "Artificial languages"
assert len(get_documents_by_search_term(search_term)) == 100
def test_start_bytes():
index_filename, wiki_filename = get_wiki_filenames()
assert len(get_start_bytes(index_filename)) == 227850
def test_get_all_documents():
small_test = 20 # 227850
n_jobs = os.cpu_count() // 4
assert (
len(
get_all_documents(
small_test=small_test, n_jobs=n_jobs, use_views=False
)
)
== small_test * 100
)
assert (
len(
get_all_documents(
small_test=small_test, n_jobs=n_jobs, use_views=True
)
)
== 429
)
def get_one_pageviews(fil):
df1 = pd.read_csv(
fil,
sep=" ",
header=None,
names=["region", "title", "views", "foo"],
quoting=csv.QUOTE_NONE,
)
df1.index = df1["title"]
df1 = df1[df1["region"] == "en"]
df1 = df1.drop("region", axis=1)
df1 = df1.drop("foo", axis=1)
df1 = df1.drop("title", axis=1) # already index
base_tmp = "temp_wiki_pageviews"
if not os.path.isdir(base_tmp):
os.makedirs(base_tmp, exist_ok=True)
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.csv")
df1.to_csv(filename, index=True)
return filename
def test_agg_pageviews(gen_files=False):
if gen_files:
path = os.path.join(
root_path,
"wiki_pageviews/dumps.wikimedia.org/other/pageviews/2023/2023-04",
)
files = glob.glob(os.path.join(path, "pageviews*.gz"))
# files = files[:2] # test
n_jobs = os.cpu_count() // 2
csv_files = Parallel(
n_jobs=n_jobs, verbose=10, backend="multiprocessing"
)(delayed(get_one_pageviews)(fil) for fil in files)
else:
# to continue without redoing above
csv_files = glob.glob(
os.path.join(root_path, "temp_wiki_pageviews/*.csv")
)
df_list = []
for csv_file in csv_files:
print(csv_file)
df1 = pd.read_csv(csv_file)
df_list.append(df1)
df = pd.concat(df_list, axis=0)
df = df.groupby("title")["views"].sum().reset_index()
df.to_csv("wiki_page_views.csv", index=True)
def test_reduce_pageview():
filename = "wiki_page_views.csv"
df = pd.read_csv(filename)
df = df[df["views"] < 1e7]
#
plt.hist(df["views"], bins=100, log=True)
views_avg = np.mean(df["views"])
views_median = np.median(df["views"])
plt.title("Views avg: %s median: %s" % (views_avg, views_median))
plt.savefig(filename.replace(".csv", ".png"))
plt.close()
#
views_limit = 5000
df = df[df["views"] > views_limit]
filename = "wiki_page_views_more_5000month.csv"
df.to_csv(filename, index=True)
#
plt.hist(df["views"], bins=100, log=True)
views_avg = np.mean(df["views"])
views_median = np.median(df["views"])
plt.title("Views avg: %s median: %s" % (views_avg, views_median))
plt.savefig(filename.replace(".csv", ".png"))
plt.close()
@pytest.mark.skip("Only if doing full processing again, some manual steps")
def test_do_wiki_full_all():
# Install other requirements for wiki specific conversion:
# pip install -r reqs_optional/requirements_optional_wikiprocessing.txt
# Use "Transmission" in Ubuntu to get wiki dump using torrent:
# See: https://meta.wikimedia.org/wiki/Data_dump_torrents
# E.g. magnet:?xt=urn:btih:b2c74af2b1531d0b63f1166d2011116f44a8fed0&dn=enwiki-20230401-pages-articles-multistream.xml.bz2&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337
# Get index
os.system(
"wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2"
)
# Test that can use LangChain to get docs from subset of wiki as sampled out of full wiki directly using bzip multistream
test_get_all_documents()
# Check can search wiki multistream
test_by_search_term()
# Test can get all start bytes in index
test_start_bytes()
# Get page views, e.g. for entire month of April 2023
os.system(
"wget -b -m -k -o wget.log -e robots=off https://dumps.wikimedia.org/other/pageviews/2023/2023-04/"
)
# Aggregate page views from many files into single file
test_agg_pageviews(gen_files=True)
# Reduce page views to some limit, so processing of full wiki is not too large
test_reduce_pageview()
# Start generate.py with requesting wiki_full in prep. This will use page views as referenced in get_views.
# Note get_views as global() function done once is required to avoid very slow processing
# WARNING: Requires alot of memory to handle, used up to 300GB system RAM at peak
"""
python generate.py --langchain_mode='wiki_full' --visible_langchain_modes="['wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']" &> lc_out.log
"""

View File

@@ -0,0 +1,121 @@
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from enums import PromptType
class StoppingCriteriaSub(StoppingCriteria):
def __init__(
self, stops=[], encounters=[], device="cuda", model_max_length=None
):
super().__init__()
assert (
len(stops) % len(encounters) == 0
), "Number of stops and encounters must match"
self.encounters = encounters
self.stops = [stop.to(device) for stop in stops]
self.num_stops = [0] * len(stops)
self.model_max_length = model_max_length
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for stopi, stop in enumerate(self.stops):
if torch.all((stop == input_ids[0][-len(stop) :])).item():
self.num_stops[stopi] += 1
if (
self.num_stops[stopi]
>= self.encounters[stopi % len(self.encounters)]
):
# print("Stopped", flush=True)
return True
if (
self.model_max_length is not None
and input_ids[0].shape[0] >= self.model_max_length
):
# critical limit
return True
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
return False
def get_stopping(
prompt_type,
prompt_dict,
tokenizer,
device,
human="<human>:",
bot="<bot>:",
model_max_length=None,
):
# FIXME: prompt_dict unused currently
if prompt_type in [
PromptType.human_bot.name,
PromptType.instruct_vicuna.name,
PromptType.instruct_with_end.name,
]:
if prompt_type == PromptType.human_bot.name:
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
# stopping only starts once output is beyond prompt
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
stop_words = [human, bot, "\n" + human, "\n" + bot]
encounters = [1, 2]
elif prompt_type == PromptType.instruct_vicuna.name:
# even below is not enough, generic strings and many ways to encode
stop_words = [
"### Human:",
"""
### Human:""",
"""
### Human:
""",
"### Assistant:",
"""
### Assistant:""",
"""
### Assistant:
""",
]
encounters = [1, 2]
else:
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
stop_words = ["### End"]
encounters = [1]
stop_words_ids = [
tokenizer(stop_word, return_tensors="pt")["input_ids"].squeeze()
for stop_word in stop_words
]
# handle single token case
stop_words_ids = [
x if len(x.shape) > 0 else torch.tensor([x])
for x in stop_words_ids
]
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
# avoid padding in front of tokens
if (
tokenizer._pad_token
): # use hidden variable to avoid annoying properly logger bug
stop_words_ids = [
x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x
for x in stop_words_ids
]
# handle fake \n added
stop_words_ids = [
x[1:] if y[0] == "\n" else x
for x, y in zip(stop_words_ids, stop_words)
]
# build stopper
stopping_criteria = StoppingCriteriaList(
[
StoppingCriteriaSub(
stops=stop_words_ids,
encounters=encounters,
device=device,
model_max_length=model_max_length,
)
]
)
else:
stopping_criteria = StoppingCriteriaList()
return stopping_criteria

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,69 @@
from typing import Any, Dict, List, Union, Optional
import time
import queue
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
class StreamingGradioCallbackHandler(BaseCallbackHandler):
"""
Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend
"""
def __init__(self, timeout: Optional[float] = None, block=True):
super().__init__()
self.text_queue = queue.SimpleQueue()
self.stop_signal = None
self.do_stop = False
self.timeout = timeout
self.block = block
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running. Clean the queue."""
while not self.text_queue.empty():
try:
self.text_queue.get(block=False)
except queue.Empty:
continue
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
self.text_queue.put(token)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.text_queue.put(self.stop_signal)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
self.text_queue.put(self.stop_signal)
def __iter__(self):
return self
def __next__(self):
while True:
try:
value = (
self.stop_signal
) # value looks unused in pycharm, not true
if self.do_stop:
print("hit stop", flush=True)
# could raise or break, maybe best to raise and make parent see if any exception in thread
raise StopIteration()
# break
value = self.text_queue.get(
block=self.block, timeout=self.timeout
)
break
except queue.Empty:
time.sleep(0.01)
if value == self.stop_signal:
raise StopIteration()
else:
return value

View File

@@ -0,0 +1,442 @@
from pathlib import Path
import argparse
from argparse import RawTextHelpFormatter
import re, gc
"""
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
Following are the various ways this script can be used :-
a. To convert a single Linalg IR to dynamic IR:
--dynamic --first_ir_path=<PATH TO FIRST IR>
b. To convert two Linalg IRs to dynamic IR:
--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>
c. To combine two Linalg IRs into one:
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
d. To convert both IRs into dynamic as well as combine the IRs:
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
NOTE: For dynamic you'll also need to provide the following set of flags:-
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
--precision (DEFAULT: 'int4')
You may use --save_dynamic to also save the dynamic IR in option d above.
Else for option a. and b. the dynamic IR(s) will get saved by default.
"""
def combine_mlir_scripts(
first_vicuna_mlir,
second_vicuna_mlir,
output_name,
return_ir=True,
):
print(f"[DEBUG] combining first and second mlir")
print(f"[DEBUG] output_name = {output_name}")
maps1 = []
maps2 = []
constants = set()
f1 = []
f2 = []
print(f"[DEBUG] processing first vicuna mlir")
first_vicuna_mlir = first_vicuna_mlir.splitlines()
while first_vicuna_mlir:
line = first_vicuna_mlir.pop(0)
if re.search("#map\d*\s*=", line):
maps1.append(line)
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "first_vicuna_forward", line)
f1.append(line)
f1 = f1[:-1]
del first_vicuna_mlir
gc.collect()
for i, map_line in enumerate(maps1):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
maps1[i] = map_line
f1 = [
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
for func_line in f1
]
print(f"[DEBUG] processing second vicuna mlir")
second_vicuna_mlir = second_vicuna_mlir.splitlines()
while second_vicuna_mlir:
line = second_vicuna_mlir.pop(0)
if re.search("#map\d*\s*=", line):
maps2.append(line)
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
f2 = f2[:-1]
del second_vicuna_mlir
gc.collect()
for i, map_line in enumerate(maps2):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
maps2[i] = map_line
f2 = [
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
for func_line in f2
]
module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
module_end = "}"
global_vars = []
vnames = []
global_var_loading1 = []
global_var_loading2 = []
print(f"[DEBUG] processing constants")
counter = 0
constants = list(constants)
while constants:
constant = constants.pop(0)
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
vbody = re.sub("arith.constant", "", vbody)
vbody = vbody.strip()
if len(vbody.split(":")) < 2:
print(constant)
vdtype = vbody.split(":")[-1].strip()
fixed_vdtype = vdtype
if "c1_i64" in vname:
print(constant)
counter += 1
if counter == 2:
counter = 0
print("detected duplicate")
continue
vnames.append(vname)
if "true" not in vname:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
else:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : i1"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
new_f1, new_f2 = [], []
print(f"[DEBUG] processing f1")
for line in f1:
if "func.func" in line:
new_f1.append(line)
for global_var in global_var_loading1:
new_f1.append(global_var)
else:
new_f1.append(line)
print(f"[DEBUG] processing f2")
for line in f2:
if "func.func" in line:
new_f2.append(line)
for global_var in global_var_loading2:
if (
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
in global_var
):
print(global_var)
new_f2.append(global_var)
else:
new_f2.append(line)
f1 = new_f1
f2 = new_f2
del new_f1
del new_f2
gc.collect()
print(
[
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
for x in [maps1, maps2, global_vars, f1, f2]
]
)
# doing it this way rather than assembling the whole string
# to prevent OOM with 64GiB RAM when encoding the file.
print(f"[DEBUG] Saving mlir to {output_name}")
with open(output_name, "w+") as f_:
f_.writelines(line + "\n" for line in maps1)
f_.writelines(line + "\n" for line in maps2)
f_.writelines(line + "\n" for line in [module_start])
f_.writelines(line + "\n" for line in global_vars)
f_.writelines(line + "\n" for line in f1)
f_.writelines(line + "\n" for line in f2)
f_.writelines(line + "\n" for line in [module_end])
del maps1
del maps2
del module_start
del global_vars
del f1
del f2
del module_end
gc.collect()
if return_ir:
print(f"[DEBUG] Reading combined mlir back in")
with open(output_name, "rb") as f:
return f.read()
def write_in_dynamic_inputs0(module, dynamic_input_size):
print("[DEBUG] writing dynamic inputs to first vicuna")
# Current solution for ensuring mlir files support dynamic inputs
# TODO: find a more elegant way to implement this
new_lines = []
module = module.splitlines()
while module:
line = module.pop(0)
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
continue
new_lines.append(line)
return "\n".join(new_lines)
def write_in_dynamic_inputs1(module, model_name, precision):
print("[DEBUG] writing dynamic inputs to second vicuna")
def remove_constant_dim(line):
if "c19_i64" in line:
line = re.sub("c19_i64", "dim_i64", line)
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "x20x" in line or "<20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
if " 20," in line:
line = re.sub(" 20,", " %dimp1,", line)
return line
module = module.splitlines()
new_lines = []
# Using a while loop and the pop method to avoid creating a copy of module
if "llama2_13b" in model_name:
pkv_tensor_shape = "tensor<1x40x?x128x"
elif "llama2_70b" in model_name:
pkv_tensor_shape = "tensor<1x8x?x128x"
else:
pkv_tensor_shape = "tensor<1x32x?x128x"
if precision in ["fp16", "int4", "int8"]:
pkv_tensor_shape += "f16>"
else:
pkv_tensor_shape += "f32>"
while module:
line = module.pop(0)
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
)
continue
if "%c2 = arith.constant 2 : index" in line:
continue
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
)
continue
line = remove_constant_dim(line)
new_lines.append(line)
return "\n".join(new_lines)
def save_dynamic_ir(ir_to_save, output_file):
if not ir_to_save:
return
# We only get string output from the dynamic conversion utility.
from contextlib import redirect_stdout
with open(output_file, "w") as f:
with redirect_stdout(f):
print(ir_to_save)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="llama ir utility",
description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
+ "\tFollowing are the various ways this script can be used :-\n"
+ "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
+ "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
+ "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
+ "\t\t\t--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>\n"
+ "\t\tc. To combine two Linalg IRs into one:\n"
+ "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
+ "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
+ "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
+ "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
+ "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
+ "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
+ "\t\t\t--precision (DEFAULT: 'int4')\n"
+ "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
+ "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--precision",
"-p",
default="int4",
choices=["fp32", "fp16", "int8", "int4"],
help="Precision of the concerned IR",
)
parser.add_argument(
"--model_name",
type=str,
default="llama2_7b",
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
help="Specify which model to run.",
)
parser.add_argument(
"--first_ir_path",
default=None,
help="path to first llama mlir file",
)
parser.add_argument(
"--second_ir_path",
default=None,
help="path to second llama mlir file",
)
parser.add_argument(
"--dynamic_input_size",
type=int,
default=19,
help="Specify the static input size to replace with dynamic dim.",
)
parser.add_argument(
"--dynamic",
default=False,
action=argparse.BooleanOptionalAction,
help="Converts the IR(s) to dynamic",
)
parser.add_argument(
"--save_dynamic",
default=False,
action=argparse.BooleanOptionalAction,
help="Save the individual IR(s) after converting to dynamic",
)
parser.add_argument(
"--combine",
default=False,
action=argparse.BooleanOptionalAction,
help="Converts the IR(s) to dynamic",
)
args, unknown = parser.parse_known_args()
dynamic = args.dynamic
combine = args.combine
assert (
dynamic or combine
), "neither `dynamic` nor `combine` flag is turned on"
first_ir_path = args.first_ir_path
second_ir_path = args.second_ir_path
assert first_ir_path or second_ir_path, "no input ir has been provided"
if combine:
assert (
first_ir_path and second_ir_path
), "you will need to provide both IRs to combine"
precision = args.precision
model_name = args.model_name
dynamic_input_size = args.dynamic_input_size
save_dynamic = args.save_dynamic
print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")
if dynamic and not combine:
save_dynamic = True
first_ir = None
first_dynamic_ir_name = None
second_ir = None
second_dynamic_ir_name = None
if first_ir_path:
first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
with open(first_ir_path, "r") as f:
first_ir = f.read()
if second_ir_path:
second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
with open(second_ir_path, "r") as f:
second_ir = f.read()
if dynamic:
first_ir = (
write_in_dynamic_inputs0(first_ir, dynamic_input_size)
if first_ir
else None
)
second_ir = (
write_in_dynamic_inputs1(second_ir, model_name, precision)
if second_ir
else None
)
if save_dynamic:
save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")
if combine:
combine_mlir_scripts(
first_ir,
second_ir,
f"{model_name}_{precision}.mlir",
return_ir=False,
)

View File

@@ -0,0 +1,211 @@
import torch
import torch_mlir
from transformers import (
AutoTokenizer,
StoppingCriteria,
)
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
def shouldStop(tokens):
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if tokens[0][-1] == stop_id:
return True
return False
MAX_SEQUENCE_LENGTH = 256
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
def compile_stableLM(
model,
model_inputs,
model_name,
model_vmfb_name,
device="cuda",
precision="fp32",
debug=False,
):
from shark.shark_inference import SharkInference
# device = "cuda" # "cpu"
# TODO: vmfb and mlir name should include precision and device
vmfb_path = (
Path(model_name + f"_{device}.vmfb")
if model_vmfb_name is None
else Path(model_vmfb_name)
)
shark_module = get_vmfb_from_path(
vmfb_path, device, mlir_dialect="tm_tensor"
)
if shark_module is not None:
return shark_module
mlir_path = Path(model_name + ".mlir")
print(
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
)
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
)
print("Saved vmfb at ", str(path))
return shark_module
class StableLMModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
combine_input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
output = self.model(**combine_input_dict)
return output.logits
# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
def get_tokenizer():
model_path = "stabilityai/stablelm-tuned-alpha-3b"
tok = AutoTokenizer.from_pretrained(model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
print("Sucessfully loaded the tokenizer to the memory")
return tok
# sharkStableLM = compile_stableLM
# (
# None,
# tuple([input_ids, attention_mask]),
# "stableLM_linalg_f32_seqLen256",
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
# )
def generate(
new_text,
max_new_tokens,
sharkStableLM,
tokenizer=None,
):
if tokenizer is None:
tokenizer = get_tokenizer()
# Construct the input message string for the model by
# concatenating the current system message and conversation history
# Tokenize the messages string
# sharkStableLM = compile_stableLM
# (
# None,
# tuple([input_ids, attention_mask]),
# "stableLM_linalg_f32_seqLen256",
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
# )
words_list = []
for i in range(max_new_tokens):
# numWords = len(new_text.split())
# if(numWords>220):
# break
params = {
"new_text": new_text,
}
generated_token_op = generate_new_token(
sharkStableLM, tokenizer, params
)
detok = generated_token_op["detok"]
stop_generation = generated_token_op["stop_generation"]
if stop_generation:
break
print(detok, end="", flush=True)
words_list.append(detok)
if detok == "":
break
new_text = new_text + detok
return words_list
def generate_new_token(shark_model, tokenizer, params):
new_text = params["new_text"]
model_inputs = tokenizer(
[new_text],
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
return_tensors="pt",
)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
output = shark_model(
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
)
output = torch.from_numpy(output)
next_toks = torch.topk(output, 1)
stop_generation = False
if shouldStop(next_toks.indices):
stop_generation = True
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
detok = tokenizer.decode(
new_token,
skip_special_tokens=True,
)
ret_dict = {
"new_token": new_token,
"detok": detok,
"stop_generation": stop_generation,
}
return ret_dict

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,94 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('huggingface-hub')
datas += copy_metadata('sentencepiece')
datas += copy_metadata("pyyaml")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('py-cpuinfo')
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['scripts/vicuna.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_llama_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -0,0 +1,22 @@
import torch
class FalconModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": None,
"use_cache": True,
}
output = self.model(
**input_dict,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)[0]
return output[:, -1, :]

View File

@@ -0,0 +1,503 @@
import torch
import dataclasses
from enum import auto, Enum
from typing import List, Any
from transformers import StoppingCriteria
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class LayerNorm(torch.nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class VisionModel(torch.nn.Module):
def __init__(
self,
ln_vision,
visual_encoder,
precision="fp32",
weight_group_size=128,
):
super().__init__()
self.ln_vision = ln_vision
self.visual_encoder = visual_encoder
if precision in ["int4", "int8"]:
print("Vision Model applying weight quantization to ln_vision")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.ln_vision,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
print(
"Vision Model applying weight quantization to visual_encoder"
)
quantize_model(
self.visual_encoder,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, image):
image_embeds = self.ln_vision(self.visual_encoder(image))
return image_embeds
class QformerBertModel(torch.nn.Module):
def __init__(self, qformer_bert):
super().__init__()
self.qformer_bert = qformer_bert
def forward(self, query_tokens, image_embeds, image_atts):
query_output = self.qformer_bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
return query_output.last_hidden_state
class FirstLlamaModel(torch.nn.Module):
def __init__(self, model, precision="fp32", weight_group_size=128):
super().__init__()
self.model = model
print("SHARK: Loading LLAMA Done")
if precision in ["int4", "int8"]:
print("First Llama applying weight quantization")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, inputs_embeds, position_ids, attention_mask):
print("************************************")
print(
"inputs_embeds: ",
inputs_embeds.shape,
" dtype: ",
inputs_embeds.dtype,
)
print(
"position_ids: ",
position_ids.shape,
" dtype: ",
position_ids.dtype,
)
print(
"attention_mask: ",
attention_mask.shape,
" dtype: ",
attention_mask.dtype,
)
print("************************************")
config = {
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"past_key_values": None,
"use_cache": True,
"attention_mask": attention_mask,
}
output = self.model(
**config,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return_vals = []
return_vals.append(output.logits)
temp_past_key_values = output.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class SecondLlamaModel(torch.nn.Module):
def __init__(self, model, precision="fp32", weight_group_size=128):
super().__init__()
self.model = model
print("SHARK: Loading LLAMA Done")
if precision in ["int4", "int8"]:
print("Second Llama applying weight quantization")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
input_ids,
position_ids,
attention_mask,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
):
print("************************************")
print("input_ids: ", input_ids.shape, " dtype: ", input_ids.dtype)
print(
"position_ids: ",
position_ids.shape,
" dtype: ",
position_ids.dtype,
)
print(
"attention_mask: ",
attention_mask.shape,
" dtype: ",
attention_mask.dtype,
)
print("past_key_values: ", i1.shape, i2.shape, i63.shape, i64.shape)
print("past_key_values dtype: ", i1.dtype)
print("************************************")
config = {
"input_ids": input_ids,
"position_ids": position_ids,
"past_key_values": (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
),
"use_cache": True,
"attention_mask": attention_mask,
}
output = self.model(
**config,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return_vals = []
return_vals.append(output.logits)
temp_past_key_values = output.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
skip_next: bool = False
conv_id: Any = None
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
conv_id=self.conv_id,
)
def dict(self):
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
"conv_id": self.conv_id,
}
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops=[], encounters=1):
super().__init__()
self.stops = stops
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
for stop in self.stops:
if torch.all((stop == input_ids[0][-len(stop) :])).item():
return True
return False
CONV_VISION = Conversation(
system="Give the following image: <Img>ImageContent</Img>. "
"You will be able to see the image once I provide it to you. Please answer my questions.",
roles=("Human", "Assistant"),
messages=[],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)

View File

@@ -0,0 +1,15 @@
import torch
class StableLMModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
combine_input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
output = self.model(**combine_input_dict)
return output.logits

View File

@@ -0,0 +1,876 @@
import argparse
import json
import re
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from typing import List, Optional, Tuple, Union
import numpy as np
import iree.runtime
import itertools
import subprocess
import torch
import torch_mlir
from torch_mlir import TensorPlaceholder
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LlamaPreTrainedModel,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
CompiledVicunaLayer,
ShardedVicunaModel,
LMHead,
LMHeadCompiled,
VicunaEmbedding,
VicunaEmbeddingCompiled,
VicunaNorm,
VicunaNormCompiled,
)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna7B,
)
from apps.language_models.utils import (
get_vmfb_from_path,
)
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_inference import SharkInference
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaRMSNorm,
_make_causal_mask,
_expand_mask,
)
from torch import nn
from time import time
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(config)
for _ in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
t1 = time()
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = (
use_cache if use_cache is not None else self.config.use_cache
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = (
seq_length_with_past + past_key_values_length
)
if position_ids is None:
device = (
input_ids.device
if input_ids is not None
else inputs_embeds.device
)
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.compressedlayers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[8 * idx : 8 * (idx + 1)]
if past_key_values is not None
else None
)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer.forward(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[1:],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
try:
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except:
_ = 10
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = tuple(itertools.chain.from_iterable(next_cache))
print(f"Token generated in {time() - t1} seconds")
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class EightLayerLayerSV(torch.nn.Module):
def __init__(self, layers):
super().__init__()
assert len(layers) == 8
self.layers = layers
def forward(
self,
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
):
pkvs = [
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
]
new_pkvs = []
for layer, pkv in zip(self.layers, pkvs):
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
pkv[0],
pkv[1],
),
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
return (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
class EightLayerLayerFV(torch.nn.Module):
def __init__(self, layers):
super().__init__()
assert len(layers) == 8
self.layers = layers
def forward(self, hidden_states, attention_mask, position_ids):
new_pkvs = []
for layer in self.layers:
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=None,
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
return (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
class CompiledEightLayerLayerSV(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_value
pkv00 = pkv00.detatch()
pkv01 = pkv01.detatch()
pkv10 = pkv10.detatch()
pkv11 = pkv11.detatch()
pkv20 = pkv20.detatch()
pkv21 = pkv21.detatch()
pkv30 = pkv30.detatch()
pkv31 = pkv31.detatch()
pkv40 = pkv40.detatch()
pkv41 = pkv41.detatch()
pkv50 = pkv50.detatch()
pkv51 = pkv51.detatch()
pkv60 = pkv60.detatch()
pkv61 = pkv61.detatch()
pkv70 = pkv70.detatch()
pkv71 = pkv71.detatch()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
),
send_to_host=False,
)
return (
output[0],
(output[1][0], output[1][1]),
(output[2][0], output[2][1]),
(output[3][0], output[3][1]),
(output[4][0], output[4][1]),
(output[5][0], output[5][1]),
(output[6][0], output[6][1]),
(output[7][0], output[7][1]),
(output[8][0], output[8][1]),
)
def forward_compressed(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = (
input_ids.device if input_ids is not None else inputs_embeds.device
)
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.compressedlayers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[8 * idx : 8 * (idx + 1)]
if past_key_values is not None
else None
)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1],
)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class CompiledEightLayerLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
t2 = time()
if past_key_value is None:
try:
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except:
pass
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
t1 = time()
output = self.model(
"first_vicuna_forward",
(hidden_states, attention_mask, position_ids),
send_to_host=False,
)
output2 = (
output[0],
(
output[1],
output[2],
),
(
output[3],
output[4],
),
(
output[5],
output[6],
),
(
output[7],
output[8],
),
(
output[9],
output[10],
),
(
output[11],
output[12],
),
(
output[13],
output[14],
),
(
output[15],
output[16],
),
)
return output2
else:
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_value
try:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv00 = pkv00.detach()
pkv01 = pkv01.detach()
pkv10 = pkv10.detach()
pkv11 = pkv11.detach()
pkv20 = pkv20.detach()
pkv21 = pkv21.detach()
pkv30 = pkv30.detach()
pkv31 = pkv31.detach()
pkv40 = pkv40.detach()
pkv41 = pkv41.detach()
pkv50 = pkv50.detach()
pkv51 = pkv51.detach()
pkv60 = pkv60.detach()
pkv61 = pkv61.detach()
pkv70 = pkv70.detach()
pkv71 = pkv71.detach()
except:
x = 10
t1 = time()
if type(hidden_states) == iree.runtime.array_interop.DeviceArray:
hidden_states = np.array(hidden_states, hidden_states.dtype)
hidden_states = torch.tensor(hidden_states)
hidden_states = hidden_states.detach()
output = self.model(
"second_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
),
send_to_host=False,
)
print(f"{time() - t1}")
del pkv00
del pkv01
del pkv10
del pkv11
del pkv20
del pkv21
del pkv30
del pkv31
del pkv40
del pkv41
del pkv50
del pkv51
del pkv60
del pkv61
del pkv70
del pkv71
output2 = (
output[0],
(
output[1],
output[2],
),
(
output[3],
output[4],
),
(
output[5],
output[6],
),
(
output[7],
output[8],
),
(
output[9],
output[10],
),
(
output[11],
output[12],
),
(
output[13],
output[14],
),
(
output[15],
output[16],
),
)
return output2

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,231 @@
import torch
class FirstVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states, attention_mask, position_ids):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class SecondVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
past_key_value0,
past_key_value1,
),
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
# assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
self.norm = norm
self.embedding = embedding
self.lmhead = lmhead
self.model.model.norm = self.norm
self.model.model.embed_tokens = self.embedding
self.model.lm_head = self.lmhead
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers
)
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
class LMHead(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states):
output = self.model(hidden_states)
return output
class LMHeadCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, hidden_states):
hidden_states = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
class VicunaNorm(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states):
output = self.model(hidden_states)
return output
class VicunaNormCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, hidden_states):
try:
hidden_states.detach()
except:
pass
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
class VicunaEmbedding(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids):
output = self.model(input_ids)
return output
class VicunaEmbeddingCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = torch.tensor(output)
return output
class CompiledVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
if past_key_value is None:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"first_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
else:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"second_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)

View File

@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
class SharkLLMBase(ABC):
def __init__(
self,
model_name,
hf_model_path=None,
max_num_tokens=512,
) -> None:
self.model_name = model_name
self.hf_model_path = hf_model_path
self.max_num_tokens = max_num_tokens
self.shark_model = None
self.device = "cpu"
self.precision = "fp32"
@classmethod
@abstractmethod
def compile(self):
pass
@classmethod
@abstractmethod
def generate(self, prompt):
pass
@classmethod
@abstractmethod
def generate_new_token(self, params):
pass
@classmethod
@abstractmethod
def get_tokenizer(self):
pass
@classmethod
@abstractmethod
def get_src_model(self):
pass
def load_init_from_config(self):
pass

View File

@@ -0,0 +1,567 @@
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
)
from io import BytesIO
from pathlib import Path
from contextlib import redirect_stdout
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
StoppingCriteriaList,
)
import copy
import re
import torch
import torch_mlir
import os
import argparse
parser = argparse.ArgumentParser(
prog="falcon runner",
description="runs a falcon model",
)
parser.add_argument(
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
"--falcon_vmfb_path", default=None, help="path to falcon's vmfb"
)
parser.add_argument(
"--falcon_mlir_path",
default=None,
help="path to falcon's mlir file",
)
parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=True,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
parser.add_argument(
"--cli",
default=True,
action=argparse.BooleanOptionalAction,
help="Run model in cli mode",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for falcon-180B model.",
)
class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_auth_token: str = None,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)
if "180b" in self.model_name and hf_auth_token == None:
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.src_model = self.get_src_model()
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
trust_remote_code=True,
token=self.hf_auth_token,
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
return tokenizer
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile(self):
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ "_"
+ self.device
+ ".vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path, self.device, "linalg"
)
if vmfb is not None:
return vmfb
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
print(
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
)
if args.load_mlir_from_shark_tank:
# Downloading MLIR from shark_tank
print(f"[DEBUG] Trying to download mlir from shark_tank")
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
mlir_generated = True
if not mlir_generated:
print(f"[DEBUG] generating MLIR locally")
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 100)
)
compilation_attention_mask = torch.ones(
1, 100, dtype=torch.int64
)
falconCompileInput = (
compilation_input_ids,
compilation_attention_mask,
)
model = FalconModel(self.src_model)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
falconCompileInput,
is_f16=self.precision in ["fp16", "int4"],
f16_input_mask=[False, False],
mlir_type="torchscript",
is_gptq=self.precision == "int4",
)
del model
print(f"[DEBUG] generating torch mlir")
module = torch_mlir.compile(
ts_graph,
[*falconCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
f_ = open(self.falcon_mlir_path, "wb")
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
f_.close()
del bytecode
shark_module = SharkInference(
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ [
"--iree-llvmcpu-use-fast-min-max-ops",
]
if self.precision == "int4"
else [],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.max_padding_length,
add_special_tokens=False,
return_tensors="pt",
)
model_inputs["prompt_text"] = prompt
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs.get("attention_mask", None)
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
attention_mask = None
in_b = 1
else:
in_b = input_ids.shape[0]
generate_kwargs = {
"max_length": self.max_num_tokens,
"do_sample": True,
"top_k": 10,
"num_return_sequences": 1,
"eos_token_id": 11,
}
generate_kwargs["input_ids"] = input_ids
generate_kwargs["attention_mask"] = attention_mask
generation_config_ = GenerationConfig.from_model_config(
self.src_model.config
)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = StoppingCriteriaList()
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = self.src_model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
self.logits_processor = self.src_model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids.shape[-1],
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.src_model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
self.logits_warper = self.src_model._get_logits_warper(
generation_config
)
(
self.input_ids,
self.model_kwargs,
) = self.src_model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id) if eos_token_id is not None else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
input_ids.shape[0], dtype=torch.long, device=input_ids.device
)
all_text = prompt
for i in range(self.max_num_tokens - 1):
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
all_text = all_text + new_word
print(f"{new_word}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(input_ids, self.scores)
):
break
torch.cuda.empty_cache()
gc.collect()
return all_text
def generate_new_token(self):
model_inputs = self.src_model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = torch.from_numpy(
self.shark_model(
"forward",
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision in ["fp16", "int4"]:
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
if __name__ == "__main__":
args = parser.parse_args()
falcon_mlir_path = (
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ ".mlir"
)
if args.falcon_mlir_path is None
else Path(args.falcon_mlir_path)
)
falcon_vmfb_path = (
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ "_"
+ args.device
+ ".vmfb"
)
if args.falcon_vmfb_path is None
else Path(args.falcon_vmfb_path)
)
if args.precision == "int4":
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
else:
hf_model_path_value = (
"TheBloke/falcon-"
+ args.falcon_variant_to_use
+ "-instruct-GPTQ"
)
else:
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "tiiuae/falcon-180B-chat"
else:
hf_model_path_value = (
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
)
falcon = Falcon(
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
falcon_vmfb_path=falcon_vmfb_path,
)
import gc
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
continue_execution = True
print("\n-----\nScript executing for the following config: \n")
print("Falcon Model: ", falcon.model_name)
print("Precision: ", args.precision)
print("Device: ", args.device)
while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? Y/N ?: "
)
if use_default_prompt in ["Y", "y"]:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
User: {prompt}
Assistant:"""
res_str = falcon.generate(prompt_template)
torch.cuda.empty_cache()
gc.collect()
print(
"\n\n-----\nHere's the complete formatted result: \n\n",
res_str,
)
continue_execution = input(
"\nDo you wish to run script one more time? Y/N ?: "
)
continue_execution = (
True if continue_execution in ["Y", "y"] else False
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,68 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
class BaseProcessor:
def __init__(self):
self.transform = lambda x: x
return
def __call__(self, item):
return self.transform(item)
@classmethod
def from_config(cls, cfg=None):
return cls()
def build(self, **kwargs):
cfg = OmegaConf.create(kwargs)
return self.from_config(cfg)
class BlipImageBaseProcessor(BaseProcessor):
def __init__(self, mean=None, std=None):
if mean is None:
mean = (0.48145466, 0.4578275, 0.40821073)
if std is None:
std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = transforms.Normalize(mean, std)
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
def __init__(self, image_size=224, mean=None, std=None):
super().__init__(mean=mean, std=std)
self.transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
@classmethod
def from_config(cls, cfg=None):
if cfg is None:
cfg = OmegaConf.create()
image_size = cfg.get("image_size", 224)
mean = cfg.get("mean", None)
std = cfg.get("std", None)
return cls(image_size=image_size, mean=mean, std=std)

View File

@@ -0,0 +1,5 @@
datasets:
cc_sbu_align:
data_type: images
build_info:
storage: /path/to/cc_sbu_align/

View File

@@ -0,0 +1,33 @@
model:
arch: mini_gpt4
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
freeze_vit: True
freeze_qformer: True
# Q-Former
num_query_token: 32
# Vicuna
llama_model: "lmsys/vicuna-7b-v1.3"
# generation configs
prompt: ""
preprocess:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"

View File

@@ -0,0 +1,25 @@
model:
arch: mini_gpt4
model_type: pretrain_vicuna
freeze_vit: True
freeze_qformer: True
max_txt_len: 160
end_sym: "###"
low_resource: False
prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: '
ckpt: 'prerained_minigpt4_7b.pth'
datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
run:
task: image_text_pretrain

View File

@@ -0,0 +1,629 @@
# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import math
import requests
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
def _cfg(url="", **kwargs):
return {
"url": url,
"num_classes": 1000,
"input_size": (3, 224, 224),
"pool_size": None,
"crop_pct": 0.9,
"interpolation": "bicubic",
"mean": (0.5, 0.5, 0.5),
"std": (0.5, 0.5, 0.5),
**kwargs,
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commit this for the orignal BERT implement
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
window_size=None,
attn_head_dim=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(
torch.meshgrid([coords_h, coords_w])
) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += (
window_size[0] - 1
) # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2,
dtype=relative_coords.dtype,
)
relative_position_index[1:, 1:] = relative_coords.sum(
-1
) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer(
"relative_position_index", relative_position_index
)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat(
(
self.q_bias,
torch.zeros_like(self.v_bias, requires_grad=False),
self.v_bias,
)
)
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.relative_position_bias_table is not None:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
init_values=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
window_size=None,
attn_head_dim=None,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
attn_head_dim=attn_head_dim,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = (
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True
)
self.gamma_2 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True
)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None):
if self.gamma_1 is None:
x = x + self.drop_path(
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
)
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(
self.gamma_1
* self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
)
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (
img_size[0] // patch_size[0]
)
self.patch_shape = (
img_size[0] // patch_size[0],
img_size[1] // patch_size[1],
)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2,
dtype=relative_coords.dtype,
)
relative_position_index[1:, 1:] = relative_coords.sum(
-1
) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer(
"relative_position_index", relative_position_index
)
# trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self):
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
class VisionTransformer(nn.Module):
"""Vision Transformer with support for patch or hybrid CNN input stage"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
norm_layer=nn.LayerNorm,
init_values=None,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
use_mean_pooling=True,
init_scale=0.001,
use_checkpoint=False,
):
super().__init__()
self.image_size = img_size
self.num_classes = num_classes
self.num_features = (
self.embed_dim
) = embed_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim)
)
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(
window_size=self.patch_embed.patch_shape, num_heads=num_heads
)
else:
self.rel_pos_bias = None
self.use_checkpoint = use_checkpoint
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList(
[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
window_size=self.patch_embed.patch_shape
if use_rel_pos_bias
else None,
)
for i in range(depth)
]
)
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)
trunc_normal_(self.cls_token, std=0.02)
# trunc_normal_(self.mask_token, std=.02)
# if isinstance(self.head, nn.Linear):
# trunc_normal_(self.head.weight, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
# if isinstance(self.head, nn.Linear):
# self.head.weight.data.mul_(init_scale)
# self.head.bias.data.mul_(init_scale)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=""):
self.num_classes = num_classes
self.head = (
nn.Linear(self.embed_dim, num_classes)
if num_classes > 0
else nn.Identity()
)
def forward_features(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(
batch_size, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = (
self.rel_pos_bias() if self.rel_pos_bias is not None else None
)
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
else:
x = blk(x, rel_pos_bias)
return x
# x = self.norm(x)
# if self.fc_norm is not None:
# t = x[:, 1:, :]
# return self.fc_norm(t.mean(1))
# else:
# return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
# x = self.head(x)
return x
def get_intermediate_layers(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(
batch_size, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
features = []
rel_pos_bias = (
self.rel_pos_bias() if self.rel_pos_bias is not None else None
)
for blk in self.blocks:
x = blk(x, rel_pos_bias)
features.append(x)
return features
def interpolate_pos_embed(model, checkpoint_model):
if "pos_embed" in checkpoint_model:
pos_embed_checkpoint = checkpoint_model["pos_embed"].float()
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int(
(pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5
)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print(
"Position interpolate from %dx%d to %dx%d"
% (orig_size, orig_size, new_size, new_size)
)
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(
-1, orig_size, orig_size, embedding_size
).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode="bicubic",
align_corners=False,
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model["pos_embed"] = new_pos_embed
def convert_weights_to_fp16(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
# l.weight.data = l.weight.data.half()
l.weight.data = l.weight.data
if l.bias is not None:
# l.bias.data = l.bias.data.half()
l.bias.data = l.bias.data
# if isinstance(l, (nn.MultiheadAttention, Attention)):
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
# tensor = getattr(l, attr)
# if tensor is not None:
# tensor.data = tensor.data.half()
model.apply(_convert_weights_to_fp16)
def create_eva_vit_g(
img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"
):
model = VisionTransformer(
img_size=img_size,
patch_size=14,
use_mean_pooling=False,
embed_dim=1408,
depth=39,
num_heads=1408 // 88,
mlp_ratio=4.3637,
qkv_bias=True,
drop_path_rate=drop_path_rate,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
use_checkpoint=use_checkpoint,
)
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
local_filename = "eva_vit_g.pth"
response = requests.get(url)
if response.status_code == 200:
with open(local_filename, "wb") as f:
f.write(response.content)
print("File downloaded successfully.")
state_dict = torch.load(local_filename, map_location="cpu")
interpolate_pos_embed(model, state_dict)
incompatible_keys = model.load_state_dict(state_dict, strict=False)
if precision == "fp16":
# model.to("cuda")
convert_weights_to_fp16(model)
return model

View File

@@ -0,0 +1,4 @@
<Img><ImageHere></Img> Describe this image in detail.
<Img><ImageHere></Img> Take a look at this image and describe what you notice.
<Img><ImageHere></Img> Please provide a detailed description of the picture.
<Img><ImageHere></Img> Could you describe the contents of this image for me?

View File

@@ -0,0 +1,187 @@
import torch
import torch_mlir
from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.stablelm_model import (
StableLMModel,
)
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
class SharkStableLM(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
max_num_tokens=512,
device="cuda",
precision="fp32",
debug="False",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
self.precision = precision
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
def shouldStop(self, tokens):
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if tokens[0][-1] == stop_id:
return True
return False
def get_src_model(self):
model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, torch_dtype=torch.float32
)
return model
def get_model_inputs(self):
input_ids = torch.randint(3, (1, self.max_sequence_len))
attention_mask = torch.randint(3, (1, self.max_sequence_len))
return input_ids, attention_mask
def compile(self):
tmp_model_name = (
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
)
# device = "cuda" # "cpu"
# TODO: vmfb and mlir name should include precision and device
model_vmfb_name = None
vmfb_path = (
Path(tmp_model_name + f"_{self.device}.vmfb")
if model_vmfb_name is None
else Path(model_vmfb_name)
)
shark_module = get_vmfb_from_path(
vmfb_path, self.device, mlir_dialect="tm_tensor"
)
if shark_module is not None:
return shark_module
mlir_path = Path(tmp_model_name + ".mlir")
print(
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
model = StableLMModel(self.get_src_model())
model_inputs = self.get_model_inputs()
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(tmp_model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
)
print("Saved vmfb at ", str(path))
return shark_module
def get_tokenizer(self):
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
return tok
def generate(self, prompt):
words_list = []
for i in range(self.max_num_tokens):
params = {
"new_text": prompt,
}
generated_token_op = self.generate_new_token(params)
detok = generated_token_op["detok"]
stop_generation = generated_token_op["stop_generation"]
if stop_generation:
break
print(detok, end="", flush=True) # this is for CLI and DEBUG
words_list.append(detok)
if detok == "":
break
prompt = prompt + detok
return words_list
def generate_new_token(self, params):
new_text = params["new_text"]
model_inputs = self.tokenizer(
[new_text],
padding="max_length",
max_length=self.max_sequence_len,
truncation=True,
return_tensors="pt",
)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
output = self.shark_model(
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
)
output = torch.from_numpy(output)
next_toks = torch.topk(output, 1)
stop_generation = False
if self.shouldStop(next_toks.indices):
stop_generation = True
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
detok = self.tokenizer.decode(
new_token,
skip_special_tokens=True,
)
ret_dict = {
"new_token": new_token,
"detok": detok,
"stop_generation": stop_generation,
}
return ret_dict
# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""

View File

@@ -0,0 +1,48 @@
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from pathlib import Path
from shark.shark_downloader import download_public_file
# expects a Path / str as arg
# returns None if path not found or SharkInference module
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
if not isinstance(vmfb_path, Path):
vmfb_path = Path(vmfb_path)
from shark.shark_inference import SharkInference
if not vmfb_path.exists():
return None
print("Loading vmfb from: ", vmfb_path)
print("Device from get_vmfb_from_path - ", device)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
return shark_module
def get_vmfb_from_config(
shark_container,
model,
precision,
device,
vmfb_path,
padding=None,
device_id=None,
):
vmfb_url = (
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
)
if padding:
vmfb_url = vmfb_url + f"_{padding}"
vmfb_url = vmfb_url + ".vmfb"
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
return get_vmfb_from_path(
vmfb_path, device, "tm_tensor", device_id=device_id
)

View File

View File

@@ -0,0 +1,87 @@
Compile / Run Instructions:
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
```
Run / Benchmark Command (FP32 - NCHW):
(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.)
```shell
## Vulkan AMD:
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CUDA:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CPU:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
Run via vulkan_gui for RGP Profiling:
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
```shell
./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
</details>
<details>
<summary>Debug Commands</summary>
## Debug commands and other advanced usage follows.
```shell
python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
```
## dump all dispatch .spv and isa using amdllpc
```shell
python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
```
## Compile and save the .vmfb (using vulkan fp16 as an example):
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
```
## Capture an RGP trace
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
```
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
```shell
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
```
## Run the unet module with iree-benchmark-module (same config as above):
```shell
##if you want to use .npz inputs:
unzip ~/.local/shark_tank/<your unet>/inputs.npz
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy
```
</details>

View File

@@ -0,0 +1 @@
from apps.stable_diffusion.scripts.train_lora_word import lora_train

View File

@@ -0,0 +1,127 @@
import sys
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
StencilPipeline,
resize_stencil,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
def main():
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
image = Image.open(args.img_path).convert("RGB")
# When the models get uploaded, it should be default to False.
args.import_mlir = True
use_stencil = args.use_stencil
if use_stencil:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, args.width, args.height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model
if use_stencil:
img2img_obj = StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
start_time = time.time()
generated_imgs = img2img_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.batch_size,
args.height,
args.width,
args.steps,
args.strength,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, strength={args.strength}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += img2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
extra_info = {"STRENGTH": args.strength}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,105 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
def main():
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
if args.mask_path is None:
print("Flag --mask_path is required.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
model_id = (
args.hf_model_id
if "inpaint" in args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
image = Image.open(args.img_path)
mask_image = Image.open(args.mask_path)
inpaint_obj = InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
start_time = time.time()
generated_imgs = inpaint_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
mask_image,
args.batch_size,
args.height,
args.width,
args.inpaint_full_res,
args.inpaint_full_res_padding,
args.steps,
args.guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += f"seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += inpaint_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,19 @@
from apps.stable_diffusion.src import args
from apps.stable_diffusion.scripts import (
img2img,
txt2img,
# inpaint,
# outpaint,
)
if __name__ == "__main__":
if args.app == "txt2img":
txt2img.main()
elif args.app == "img2img":
img2img.main()
# elif args.app == "inpaint":
# inpaint.main()
# elif args.app == "outpaint":
# outpaint.main()
else:
print(f"args.app value is {args.app} but this isn't supported")

View File

@@ -0,0 +1,120 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
model_id = (
args.hf_model_id
if "inpaint" in args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
image = Image.open(args.img_path)
outpaint_obj = OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
start_time = time.time()
generated_imgs = outpaint_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.pixels,
args.mask_blur,
args.left,
args.right,
args.top,
args.bottom,
args.noise_q,
args.color_variation,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += f"seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += outpaint_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
# save this information as metadata of output generated image.
directions = []
if args.left:
directions.append("left")
if args.right:
directions.append("right")
if args.top:
directions.append("up")
if args.bottom:
directions.append("down")
extra_info = {
"PIXELS": args.pixels,
"MASK_BLUR": args.mask_blur,
"DIRECTIONS": directions,
"NOISE_Q": args.noise_q,
"COLOR_VARIATION": args.color_variation,
}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,240 @@
import logging
import os
from models.stable_diffusion.main import stable_diff_inf
from models.stable_diffusion.utils import get_available_devices
from dotenv import load_dotenv
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram import BotCommand
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
from io import BytesIO
import random
log = logging.getLogger("TG.Bot")
logging.basicConfig()
log.warning("Start")
load_dotenv()
os.environ["AMD_ENABLE_LLPC"] = "0"
TG_TOKEN = os.getenv("TG_TOKEN")
SELECTED_MODEL = "stablediffusion"
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
STEPS = 30
NEGATIVE_PROMPT = (
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
" legs, fused fingers, too many fingers"
)
GUIDANCE_SCALE = 6
available_devices = get_available_devices()
models_list = [
"stablediffusion",
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]
sheds_list = [
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
]
def image_to_bytes(image):
bio = BytesIO()
bio.name = "image.jpeg"
image.save(bio, "JPEG")
bio.seek(0)
return bio
def get_try_again_markup():
keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]]
reply_markup = InlineKeyboardMarkup(keyboard)
return reply_markup
def generate_image(prompt):
seed = random.randint(1, 10000)
log.warning(SELECTED_MODEL)
log.warning(STEPS)
image, text = stable_diff_inf(
prompt=prompt,
negative_prompt=NEGATIVE_PROMPT,
steps=STEPS,
guidance_scale=GUIDANCE_SCALE,
seed=seed,
scheduler_key=SELECTED_SCHEDULER,
variant=SELECTED_MODEL,
device_key=available_devices[0],
)
return image, seed
async def generate_and_send_photo(
update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
progress_msg = await update.message.reply_text(
"Generating image...", reply_to_message_id=update.message.message_id
)
im, seed = generate_image(prompt=update.message.text)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{update.message.text}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=update.message.message_id,
)
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
query = update.callback_query
if query.data in models_list:
global SELECTED_MODEL
SELECTED_MODEL = query.data
await query.answer()
await query.edit_message_text(text=f"Selected model: {query.data}")
return
if query.data in sheds_list:
global SELECTED_SCHEDULER
SELECTED_SCHEDULER = query.data
await query.answer()
await query.edit_message_text(text=f"Selected scheduler: {query.data}")
return
replied_message = query.message.reply_to_message
await query.answer()
progress_msg = await query.message.reply_text(
"Generating image...", reply_to_message_id=replied_message.message_id
)
if query.data == "TRYAGAIN":
prompt = replied_message.text
im, seed = generate_image(prompt)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{prompt}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=replied_message.message_id,
)
async def select_model_handler(update, context):
text = "Select model"
keyboard = []
for model in models_list:
keyboard.append(
[
InlineKeyboardButton(text=model, callback_data=model),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def select_scheduler_handler(update, context):
text = "Select schedule"
keyboard = []
for shed in sheds_list:
keyboard.append(
[
InlineKeyboardButton(text=shed, callback_data=shed),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def set_steps_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_steps ")[1]
global STEPS
STEPS = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_steps 30"
)
await update.message.reply_text(input_args)
async def set_negative_prompt_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_negative_prompt ")[1]
global NEGATIVE_PROMPT
NEGATIVE_PROMPT = input_args
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_negative_prompt ugly, bad art, mutated"
)
await update.message.reply_text(input_args)
async def set_guidance_scale_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_guidance_scale ")[1]
global GUIDANCE_SCALE
GUIDANCE_SCALE = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_guidance_scale 7"
)
await update.message.reply_text(input_args)
async def setup_bot_commands(application: Application) -> None:
await application.bot.set_my_commands(
[
BotCommand("select_model", "to select model"),
BotCommand("select_scheduler", "to select scheduler"),
BotCommand("set_steps", "to set steps"),
BotCommand("set_guidance_scale", "to set guidance scale"),
BotCommand("set_negative_prompt", "to set negative prompt"),
]
)
app = (
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
)
app.add_handler(CommandHandler("select_model", select_model_handler))
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
app.add_handler(CommandHandler("set_steps", set_steps_handler))
app.add_handler(
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
)
app.add_handler(
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
)
app.add_handler(
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
)
app.add_handler(CallbackQueryHandler(button))
log.warning("Start bot")
app.run_polling()

View File

@@ -0,0 +1,693 @@
# Install the required libs
# pip install -U git+https://github.com/huggingface/diffusers.git
# pip install accelerate transformers ftfy
# HuggingFace Token
# YOUR_TOKEN = "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"
# Import required libraries
import itertools
import math
import os
from typing import List
import random
import torch_mlir
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
import logging
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from shark.shark_inference import SharkInference
torch._dynamo.config.verbose = True
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
)
from io import BytesIO
from dataclasses import dataclass
from apps.stable_diffusion.src import (
args,
get_schedulers,
set_init_device_flags,
clear_all,
)
from apps.stable_diffusion.src.utils import update_lora_weight
# Setup the dataset
class LoraDataset(Dataset):
def __init__(
self,
data_root,
tokenizer,
size=512,
repeats=100,
interpolation="bicubic",
set="train",
prompt="myloraprompt",
center_crop=False,
):
self.data_root = data_root
self.tokenizer = tokenizer
self.size = size
self.center_crop = center_crop
self.prompt = prompt
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
self.num_images = len(self.image_paths)
self._length = self.num_images
if set == "train":
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
example["input_ids"] = self.tokenizer(
self.prompt,
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids[0]
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = np.array(image).astype(np.uint8)
image = (image / 127.5 - 1.0).astype(np.float32)
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
return example
def torch_device(device):
device_tokens = device.split("=>")
if len(device_tokens) == 1:
device_str = device_tokens[0].strip()
else:
device_str = device_tokens[1].strip()
device_type_tokens = device_str.split("://")
if device_type_tokens[0] == "metal":
device_type_tokens[0] = "vulkan"
if len(device_type_tokens) > 1:
return device_type_tokens[0] + ":" + device_type_tokens[1]
else:
return device_type_tokens[0]
########## Setting up the model ##########
def lora_train(
prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
training_images_dir: str,
lora_save_dir: str,
use_lora: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
print(
"Note LoRA training is not compatible with the latest torch-mlir branch"
)
print(
"To run LoRA training you'll need this to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
)
torch.manual_seed(seed)
args.prompts = [prompt]
args.steps = steps
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be "
"empty.",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = custom_model
else:
args.hf_model_id = custom_model
args.training_images_dir = training_images_dir
args.lora_save_dir = lora_save_dir
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = torch_device(device)
args.use_lora = use_lora
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(
args.hf_model_id, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(
args.hf_model_id, subfolder="unet"
)
def freeze_params(params):
for param in params:
param.requires_grad = False
# Freeze everything but LoRA
freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
# Move vae and unet to device
vae.to(args.device)
unet.to(args.device)
text_encoder.to(args.device)
if use_lora != "":
update_lora_weight(unet, args.use_lora, "unet")
else:
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[
block_id
]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = vae
def forward(self, input):
x = self.vae.encode(input, return_dict=False)[0]
return x
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = unet
def forward(self, x, y, z):
return self.unet.forward(x, y, z, return_dict=False)[0]
shark_vae = VaeModel()
shark_unet = UnetModel()
####### Creating our training data ########
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id,
subfolder="tokenizer",
)
# Let's create the Dataset and Dataloader
train_dataset = LoraDataset(
data_root=args.training_images_dir,
tokenizer=tokenizer,
size=vae.sample_size,
prompt=args.prompts[0],
repeats=100,
center_crop=False,
set="train",
)
def create_dataloader(train_batch_size=1):
return torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, shuffle=True
)
# Create noise_scheduler for training
noise_scheduler = DDPMScheduler.from_config(
args.hf_model_id, subfolder="scheduler"
)
######## Training ###########
# Define hyperparameters for our training. If you are not happy with your results,
# you can tune the `learning_rate` and the `max_train_steps`
# Setting up all training args
hyperparameters = {
"learning_rate": 5e-04,
"scale_lr": True,
"max_train_steps": steps,
"train_batch_size": batch_size,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"mixed_precision": "fp16",
"seed": 42,
"output_dir": "sd-concept-output",
}
# creating output directory
cwd = os.getcwd()
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
while not os.path.exists(str(out_dir)):
try:
os.mkdir(out_dir)
except OSError as error:
print("Output directory not created")
###### Torch-MLIR Compilation ######
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
@make_simple_dynamo_backend
def refbackend_torchdynamo_backend(
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
):
# handling usage of empty tensor without initializing
transform_fx(fx_graph)
fx_graph.recompile()
if _returns_nothing(fx_graph):
return fx_graph
removed_none_indexes = _remove_nones(fx_graph)
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
mlir_module = torch_mlir.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors"
)
bytecode_stream = BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
def compiled_callable(*inputs):
inputs = [x.numpy() for x in inputs]
result = shark_module("forward", inputs)
if was_unwrapped:
result = [
result,
]
if not isinstance(result, list):
result = torch.from_numpy(result)
else:
result = tuple(torch.from_numpy(x) for x in result)
result = list(result)
for removed_index in removed_none_indexes:
result.insert(removed_index, None)
result = tuple(result)
return result
return compiled_callable
def predictions(torch_func, jit_func, batchA, batchB):
res = jit_func(batchA.numpy(), batchB.numpy())
if res is not None:
# prediction = torch.from_numpy(res)
prediction = res
else:
prediction = None
return prediction
logger = logging.getLogger(__name__)
train_batch_size = hyperparameters["train_batch_size"]
gradient_accumulation_steps = hyperparameters[
"gradient_accumulation_steps"
]
learning_rate = hyperparameters["learning_rate"]
if hyperparameters["scale_lr"]:
learning_rate = (
learning_rate
* gradient_accumulation_steps
* train_batch_size
# * accelerator.num_processes
)
# Initialize the optimizer
optimizer = torch.optim.AdamW(
lora_layers.parameters(), # only optimize the embeddings
lr=learning_rate,
)
# Training function
def train_func(batch_pixel_values, batch_input_ids):
# Convert images to latent space
latents = shark_vae(batch_pixel_values).sample().detach()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.num_train_timesteps,
(bsz,),
device=latents.device,
).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch_input_ids)[0]
# Predict the noise residual
noise_pred = shark_unet(
noisy_latents,
timesteps,
encoder_hidden_states,
)
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
)
loss = (
F.mse_loss(noise_pred, target, reduction="none")
.mean([1, 2, 3])
.mean()
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return loss
def training_function():
max_train_steps = hyperparameters["max_train_steps"]
output_dir = hyperparameters["output_dir"]
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
train_dataloader = create_dataloader(train_batch_size)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
num_train_epochs = math.ceil(
max_train_steps / num_update_steps_per_epoch
)
# Train!
total_batch_size = (
train_batch_size
* gradient_accumulation_steps
# train_batch_size * accelerator.num_processes * gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(
f" Instantaneous batch size per device = {train_batch_size}"
)
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(
# range(max_train_steps), disable=not accelerator.is_local_main_process
range(max_train_steps)
)
progress_bar.set_description("Steps")
global_step = 0
params__ = [
i for i in text_encoder.get_input_embeddings().parameters()
]
for epoch in range(num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
dynamo_callable = dynamo.optimize(
refbackend_torchdynamo_backend
)(train_func)
lam_func = lambda x, y: dynamo_callable(
torch.from_numpy(x), torch.from_numpy(y)
)
loss = predictions(
train_func,
lam_func,
batch["pixel_values"],
batch["input_ids"],
)
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
if global_step >= max_train_steps:
break
training_function()
# Save the lora weights
unet.save_attn_procs(args.lora_save_dir)
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
if param.grad is not None:
del param.grad # free some memory
torch.cuda.empty_cache()
if __name__ == "__main__":
if args.clear_all:
clear_all()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
if len(args.prompts) != 1:
print("Need exactly one prompt for the LoRA word")
lora_train(
args.prompts[0],
args.height,
args.width,
args.training_steps,
args.guidance_scale,
args.seed,
args.batch_count,
args.batch_size,
args.scheduler,
"None",
args.hf_model_id,
args.precision,
args.device,
args.max_length,
args.training_images_dir,
args.lora_save_dir,
args.use_lora,
)

View File

@@ -0,0 +1,131 @@
import os
from pathlib import Path
from shark_tuner.codegen_tuner import SharkCodegenTuner
from shark_tuner.iree_utils import (
dump_dispatches,
create_context,
export_module_to_mlir_file,
)
from shark_tuner.model_annotation import model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.utils import set_init_device_flags
from apps.stable_diffusion.src.utils.sd_annotation import (
get_device_args,
load_winograd_configs,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
if "upscaler" in args.hf_model_id:
is_upscaler = True
else:
is_upscaler = False
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
max_len=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
is_upscaler=is_upscaler,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,
)
if args.annotation_model == "unet":
mlir_module = sd_model.unet()
model_name = sd_model.model_name["unet"]
elif args.annotation_model == "vae":
mlir_module = sd_model.vae()
model_name = sd_model.model_name["vae"]
else:
raise ValueError(
f"{args.annotation_model} is not supported for tuning."
)
return mlir_module, model_name
def main():
args.use_tuned = False
set_init_device_flags()
mlir_module, model_name = load_mlir_module()
# Get device and device specific arguments
device, device_spec_args = get_device_args()
device_spec = ""
vulkan_target_triple = ""
if device_spec_args:
device_spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
vulkan_target_triple = device_spec
device_spec = device_spec.split("-")[0]
# Add winograd annotation for vulkan device
use_winograd = (
True
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
else False
)
winograd_config = (
load_winograd_configs()
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
else ""
)
with create_context() as ctx:
input_module = model_annotation(
ctx,
input_contents=mlir_module,
config_path=winograd_config,
search_op="conv",
winograd=use_winograd,
)
# Dump model dispatches
generates_dir = Path.home() / "tmp"
if not os.path.exists(generates_dir):
os.makedirs(generates_dir)
dump_mlir = generates_dir / "temp.mlir"
dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches"
export_module_to_mlir_file(input_module, dump_mlir)
dump_dispatches(
dump_mlir,
device,
dispatch_dir,
vulkan_target_triple,
use_winograd=use_winograd,
)
# Tune each dispatch
dtype = "f16" if args.precision == "fp16" else "f32"
config_filename = f"{model_name}_{device_spec}_configs.json"
for f_path in os.listdir(dispatch_dir):
if not f_path.endswith(".mlir"):
continue
model_dir = os.path.join(dispatch_dir, f_path)
tuner = SharkCodegenTuner(
model_dir,
device,
"random",
args.num_iters,
args.tuned_config_dir,
dtype,
args.search_op,
batch_size=1,
config_filename=config_filename,
use_dispatch=True,
vulkan_target_triple=vulkan_target_triple,
)
tuner.tune()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,88 @@
import torch
import transformers
import time
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
if args.clear_all:
clear_all()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += (
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
)
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,92 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
if __name__ == "__main__":
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
# When the models get uploaded, it should be defaulted to False.
args.import_mlir = True
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
image = (
Image.open(args.img_path)
.convert("RGB")
.resize((args.height, args.width))
)
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model
upscaler_obj = UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ddpm_scheduler=schedulers["DDPM"],
ondemand=args.ondemand,
)
start_time = time.time()
generated_imgs = upscaler_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.batch_size,
args.height,
args.width,
args.steps,
args.noise_level,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, noise_level={args.noise_level}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += upscaler_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
extra_info = {"NOISE LEVEL": args.noise_level}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)

View File

@@ -0,0 +1,45 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports
binaries = []
block_cipher = None
a = Analysis(
['web/index.py'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='nodai_shark_studio',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -0,0 +1,85 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('py-cpuinfo')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
]
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['scripts/main.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -0,0 +1,87 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# python path for pyinstaller
pathex = [
".",
"./apps/language_models/langchain",
"./apps/language_models/src/pipelines/minigpt4_utils",
]
# datafiles for pyinstaller
datas = []
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("torch-mlir")
datas += copy_metadata("omegaconf")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
datas += collect_data_files("cv2")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
("src/utils/resources/opt_flags.json", "resources"),
("src/utils/resources/base_model.json", "resources"),
("web/ui/css/*", "ui/css"),
("web/ui/logos/*", "logos"),
(
"../language_models/src/pipelines/minigpt4_utils/configs/*",
"minigpt4_utils/configs",
),
(
"../language_models/src/pipelines/minigpt4_utils/prompts/*",
"minigpt4_utils/prompts",
),
]
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("diffusers") if "tests" not in x
]
blacklist = ["tests", "convert"]
hiddenimports += [
x
for x in collect_submodules("transformers")
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.utils import (
args,
set_init_device_flags,
prompt_examples,
get_available_devices,
clear_all,
save_output_img,
resize_stencil,
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,
StencilPipeline,
UpscalerPipeline,
)
from apps.stable_diffusion.src.schedulers import get_schedulers

View File

@@ -0,0 +1,12 @@
from apps.stable_diffusion.src.models.model_wrappers import (
SharkifyStableDiffusionModel,
)
from apps.stable_diffusion.src.models.opt_params import (
get_vae_encode,
get_vae,
get_unet,
get_clip,
get_tokenizer,
get_params,
get_variant_version,
)

View File

@@ -0,0 +1,896 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from collections import defaultdict
from pathlib import Path
import torch
import safetensors.torch
import traceback
import subprocess
import sys
import os
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
base_models,
args,
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
get_stencil_model_id,
update_lora_weight,
)
# These shapes are parameter dependent.
def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape = []
for i in range(len(shape)):
if shape[i] == "max_len":
new_shape.append(max_len)
elif shape[i] == "height":
new_shape.append(height)
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
if "*" in shape[i]:
mul_val = int(shape[i].split("*")[0])
if "batch_size" in shape[i]:
new_shape.append(batch_size * mul_val)
elif "height" in shape[i]:
new_shape.append(height * mul_val)
elif "width" in shape[i]:
new_shape.append(width * mul_val)
elif "/" in shape[i]:
import math
div_val = int(shape[i].split("/")[1])
if "batch_size" in shape[i]:
new_shape.append(math.ceil(batch_size / div_val))
elif "height" in shape[i]:
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
else:
new_shape.append(shape[i])
return new_shape
def check_compilation(model, model_name):
if not model:
raise Exception(
f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
)
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
custom_vae: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
low_cpu_mem_usage: bool = False,
debug: bool = False,
sharktank_dir: str = "",
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = (
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
)
if not os.path.isfile(weights_path):
subprocess.run(
["wget", custom_weights, "-O", weights_path]
)
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(
custom_weights
)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
self.model_id = "stabilityai/stable-diffusion-2-1-base"
self.custom_vae = custom_vae
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
"_"
+ str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
print(f"use_tuned? sharkify: {use_tuned}")
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
print(self.model_name)
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
self.sharktank_dir = sharktank_dir
self.generate_vmfb = generate_vmfb
self.inputs = dict()
self.model_to_run = ""
if self.custom_weights != "":
self.model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
else:
self.model_to_run = args.hf_model_id
self.custom_vae = self.process_custom_vae()
self.base_model_id = fetch_and_update_base_model_id(self.model_to_run)
if self.base_model_id != "" and args.ckpt_loc != "":
args.hf_model_id = self.base_model_id
self.return_mlir = return_mlir
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = [
"clip",
"unet",
"unet512",
"stencil_unet",
"stencil_unet_512",
"vae",
"vae_encode",
"stencil_adaptor",
"stencil_adaptor_512",
]
index = 0
for model in sub_model_list:
sub_model = model
model_config = self.model_name
if "vae" == model:
if self.custom_vae != "":
model_config = model_config + get_path_stem(
self.custom_vae
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 128):
sys.exit("width should be greater than 128 and multiple of 8")
if not (height % 8 == 0 and height >= 128):
sys.exit("height should be greater than 128 and multiple of 8")
# Get the input info for a model i.e. "unet", "clip", "vae", etc.
def get_input_info_for(self, model_info):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = []
for inp in model_info:
shape = model_info[inp]["shape"]
dtype = dtype_config[model_info[inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape,
self.max_len,
self.width,
self.height,
self.batch_size,
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
latents = self.vae.encode(input).latent_dist.sample()
return 0.18215 * latents
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = (
True
if not self.is_upscaler and self.precision == "fp16"
else False
)
shark_vae_encode, vae_encode_mlir = compile_through_fx(
vae_encode,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae_encode"],
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae_encode",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae_encode, vae_encode_mlir
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae
def forward(self, input):
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0
return x.round()
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = (
True
if not self.is_upscaler and self.precision == "fp16"
else False
)
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae, vae_mlir
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
def forward(
self,
latent,
timestep,
text_embedding,
guidance_scale,
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
control13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple(
[
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
)
mb_res_samples = control13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
down_block_additional_residuals=db_res_samples,
mid_block_additional_residual=mb_res_samples,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = ControlledUnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
model_name = "stencil_unet"
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[:2]
+ (torch.nn.functional.pad(inputs[2], pad),)
+ inputs[3:]
)
model_name = "stencil_unet_512"
input_mask = [
True,
True,
True,
False,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self, use_large=False):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.cnet.in_channels
self.train(False)
def forward(
self,
latent,
timestep,
text_embedding,
stencil_image_input,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
# TODO: guidance NOT NEEDED change in `get_input_info` later
latents = torch.cat(
[latent] * 2
) # needs to be same as controlledUNET latents
stencil_image = torch.cat(
[stencil_image_input] * 2
) # needs to be same as controlledUNET latents
(
down_block_res_samples,
mid_block_res_sample,
) = self.cnet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
controlnet_cond=stencil_image,
return_dict=False,
)
return tuple(
list(down_block_res_samples) + [mid_block_res_sample]
)
scnet = StencilControlNetModel(
low_cpu_mem_usage=self.low_cpu_mem_usage
)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_cnet, cnet_mlir
def get_unet(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.config.in_channels
self.train(False)
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
# TODO: Instead of flattening the `control` try to use the list.
def forward(
self,
latent,
timestep,
text_embedding,
guidance_scale,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["unet512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["unet"]
)
input_mask = [True, True, True, False]
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_unet_upscaler(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, latent, timestep, text_embedding, noise_level):
unet_out = self.unet.forward(
latent,
timestep,
text_embedding,
noise_level,
return_dict=False,
)[0]
return unet_out
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
input_mask = [True, True, True, False]
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(
self.text_encoder, use_lora, "text_encoder"
)
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = ""
if self.debug:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["clip"]
)
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
extended_model_name=self.model_name["clip"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip, clip_mlir
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
return self.custom_vae
try:
preprocessCKPT(self.custom_vae)
return get_path_to_diffusers_checkpoint(self.custom_vae)
except:
print("Processing standalone Vae checkpoint")
vae_checkpoint = None
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
if custom_vae.endswith(".ckpt"):
vae_checkpoint = torch.load(
self.custom_vae, map_location="cpu"
)
else:
vae_checkpoint = safetensors.torch.load_file(
self.custom_vae, device="cpu"
)
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
try:
vae_checkpoint = convert_original_vae(vae_checkpoint)
finally:
vae_dict = {
k: v
for k, v in vae_checkpoint.items()
if k[0:4] != "loss" and k not in vae_ignore_keys
}
return vae_dict
def compile_unet_variants(self, model, use_large=False):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler(use_large=use_large)
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import (
get_unet,
)
return get_unet()
else:
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet(use_large=use_large)
def vae_encode(self):
try:
self.inputs["vae_encode"] = self.get_input_info_for(
base_models["vae_encode"]
)
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
if self.return_mlir:
return vae_encode_mlir
return compiled_vae_encode
except Exception as e:
sys.exit(e)
def clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(base_models["clip"])
compiled_clip, clip_mlir = self.get_clip()
check_compilation(compiled_clip, "Clip")
if self.return_mlir:
return clip_mlir
return compiled_clip
except Exception as e:
sys.exit(e)
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[model_id]
)
try:
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
)
except Exception as e:
print(e)
print(
"Retrying with a different base model configuration"
)
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(self.model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
break
check_compilation(compiled_unet, "Unet")
if self.return_mlir:
return unet_mlir
return compiled_unet
except Exception as e:
sys.exit(e)
def vae(self):
try:
vae_input = (
base_models["vae"]["vae_upscaler"]
if self.is_upscaler
else base_models["vae"]["vae"]
)
self.inputs["vae"] = self.get_input_info_for(vae_input)
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")
if self.return_mlir:
return vae_mlir
return compiled_vae
except Exception as e:
sys.exit(e)
def controlnet(self, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
)
check_compilation(compiled_stencil_adaptor, "Stencil")
if self.return_mlir:
return controlnet_mlir
return compiled_stencil_adaptor
except Exception as e:
sys.exit(e)

View File

@@ -0,0 +1,130 @@
import sys
from transformers import CLIPTokenizer
from apps.stable_diffusion.src.utils import (
models_db,
args,
get_shark_model,
get_opt_flags,
)
hf_model_variant_map = {
"Linaqruf/anything-v3.0": ["anythingv3", "v1_4"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"],
"prompthero/openjourney": ["openjourney", "v1_4"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
"stabilityai/stable-diffusion-2-inpainting": [
"stablediffusion",
"inpaint_v2",
],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
bucket_key = "gs://shark_tank/prashant_nod"
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 and args.width != 512 and args.max_length != 77:
sys.exit(
"The int8 quantized model currently requires the height and width to be 512, and max_length to be 77"
)
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
def get_params(bucket_key, model_key, model, is_tuned, precision):
try:
bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
except KeyError:
raise Exception(
f"{bucket_key}/{model_key} is not present in the models database"
)
iree_flags = get_opt_flags(model, precision="fp16")
return bucket, model_name, iree_flags
def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
# TODO: Get the quantize model from model_db.json
if args.use_quantize == "int8":
bk, mk, flags = get_quantize_model()
return get_shark_model(bk, mk, flags)
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "unet", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae_encode():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
is_base = "/base" if args.use_base_vae else ""
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
variant, version = get_variant_version(args.hf_model_id)
bucket_key = f"{variant}/untuned"
model_key = (
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
)
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "clip", "untuned", "fp32"
)
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
)
return tokenizer

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
InpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import (
OutpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import (
StencilPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_upscaler import (
UpscalerPipeline,
)

View File

@@ -0,0 +1,238 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class Image2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_image_latents(
self,
image,
batch_size,
height,
width,
generator,
num_inference_steps,
strength,
dtype,
resample_type,
):
# Pre process image -> get image encoded -> process latents
# TODO: process with variable HxW combos
# Pre-process image
if resample_type == "Lanczos":
resample_type = Image.LANCZOS
elif resample_type == "Nearest Neighbor":
resample_type = Image.NEAREST
elif resample_type == "Bilinear":
resample_type = Image.BILINEAR
elif resample_type == "Bicubic":
resample_type = Image.BICUBIC
elif resample_type == "Adaptive":
resample_type = Image.ADAPTIVE
elif resample_type == "Antialias":
resample_type = Image.ANTIALIAS
elif resample_type == "Box":
resample_type = Image.BOX
elif resample_type == "Affine":
resample_type = Image.AFFINE
elif resample_type == "Cubic":
resample_type = Image.CUBIC
else: # Fallback to Lanczos
resample_type = Image.LANCZOS
image = image.resize((width, height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
image_arr = 2 * (image_arr - 0.5)
# set scheduler steps
self.scheduler.set_timesteps(num_inference_steps)
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps
)
t_start = max(num_inference_steps - init_timestep, 0)
# timesteps reduced as per strength
timesteps = self.scheduler.timesteps[t_start:]
# new number of steps to be used as per strength will be
# num_inference_steps = num_inference_steps - t_start
# image encode
latents = self.encode_image((image_arr,))
latents = torch.from_numpy(latents).to(dtype)
# add noise to data
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
latents = self.scheduler.add_noise(
latents, noise, timesteps[0].repeat(1)
)
return latents, timesteps
def encode_image(self, input_image):
self.load_vae_encode()
vae_encode_start = time.time()
latents = self.vae_encode("forward", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
if self.ondemand:
self.unload_vae_encode()
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
return latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
strength,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare input image latent
image_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
# Get Image latents
latents = self.produce_img_latents(
latents=image_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -0,0 +1,487 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from PIL import Image, ImageOps
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class InpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
latents = latents * self.scheduler.init_noise_sigma
return latents
def get_crop_region(self, mask, pad=0):
h, w = mask.shape
crop_left = 0
for i in range(w):
if not (mask[:, i] == 0).all():
break
crop_left += 1
crop_right = 0
for i in reversed(range(w)):
if not (mask[:, i] == 0).all():
break
crop_right += 1
crop_top = 0
for i in range(h):
if not (mask[i] == 0).all():
break
crop_top += 1
crop_bottom = 0
for i in reversed(range(h)):
if not (mask[i] == 0).all():
break
crop_bottom += 1
return (
int(max(crop_left - pad, 0)),
int(max(crop_top - pad, 0)),
int(min(w - crop_right + pad, w)),
int(min(h - crop_bottom + pad, h)),
)
def expand_crop_region(
self,
crop_region,
processing_width,
processing_height,
image_width,
image_height,
):
x1, y1, x2, y2 = crop_region
ratio_crop_region = (x2 - x1) / (y2 - y1)
ratio_processing = processing_width / processing_height
if ratio_crop_region > ratio_processing:
desired_height = (x2 - x1) / ratio_processing
desired_height_diff = int(desired_height - (y2 - y1))
y1 -= desired_height_diff // 2
y2 += desired_height_diff - desired_height_diff // 2
if y2 >= image_height:
diff = y2 - image_height
y2 -= diff
y1 -= diff
if y1 < 0:
y2 -= y1
y1 -= y1
if y2 >= image_height:
y2 = image_height
else:
desired_width = (y2 - y1) * ratio_processing
desired_width_diff = int(desired_width - (x2 - x1))
x1 -= desired_width_diff // 2
x2 += desired_width_diff - desired_width_diff // 2
if x2 >= image_width:
diff = x2 - image_width
x2 -= diff
x1 -= diff
if x1 < 0:
x2 -= x1
x1 -= x1
if x2 >= image_width:
x2 = image_width
return x1, y1, x2, y2
def resize_image(self, resize_mode, im, width, height):
"""
resize_mode:
0: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
1: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
"""
if resize_mode == 0:
ratio = width / height
src_ratio = im.width / im.height
src_w = (
width if ratio > src_ratio else im.width * height // im.height
)
src_h = (
height if ratio <= src_ratio else im.height * width // im.width
)
resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (width, height))
res.paste(
resized,
box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
)
else:
ratio = width / height
src_ratio = im.width / im.height
src_w = (
width if ratio < src_ratio else im.width * height // im.height
)
src_h = (
height if ratio >= src_ratio else im.height * width // im.width
)
resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (width, height))
res.paste(
resized,
box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
)
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
res.paste(
resized.resize((width, fill_height), box=(0, 0, width, 0)),
box=(0, 0),
)
res.paste(
resized.resize(
(width, fill_height),
box=(0, resized.height, width, resized.height),
),
box=(0, fill_height + src_h),
)
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
res.paste(
resized.resize(
(fill_width, height), box=(0, 0, 0, height)
),
box=(0, 0),
)
res.paste(
resized.resize(
(fill_width, height),
box=(resized.width, 0, resized.width, height),
),
box=(fill_width + src_w, 0),
)
return res
def prepare_mask_and_masked_image(
self,
image,
mask,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
):
# preprocess image
image = image.resize((width, height))
mask = mask.resize((width, height))
paste_to = ()
overlay_image = None
if inpaint_full_res:
# prepare overlay image
overlay_image = Image.new("RGB", (image.width, image.height))
overlay_image.paste(
image.convert("RGB"),
mask=ImageOps.invert(mask.convert("L")),
)
# prepare mask
mask = mask.convert("L")
crop_region = self.get_crop_region(
np.array(mask), inpaint_full_res_padding
)
crop_region = self.expand_crop_region(
crop_region, width, height, mask.width, mask.height
)
x1, y1, x2, y2 = crop_region
mask = mask.crop(crop_region)
mask = self.resize_image(1, mask, width, height)
paste_to = (x1, y1, x2 - x1, y2 - y1)
# prepare image
image = image.crop(crop_region)
image = self.resize_image(1, image, width, height)
if isinstance(image, (Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image, paste_to, overlay_image
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
):
mask = torch.nn.functional.interpolate(
mask, size=(height // 8, width // 8)
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
return mask, masked_image_latents
def apply_overlay(self, image, paste_loc, overlay):
x, y, w, h = paste_loc
image = self.resize_image(0, image, w, h)
overlay.paste(image, (x, y))
return overlay
def generate_images(
self,
prompts,
neg_prompts,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Preprocess mask and image
(
mask,
masked_image,
paste_to,
overlay_image,
) = self.prepare_mask_and_masked_image(
image,
mask_image,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
)
# Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask=mask,
masked_image=masked_image,
batch_size=batch_size,
height=height,
width=width,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
mask=mask,
masked_image_latents=masked_image_latents,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
if inpaint_full_res:
output_image = self.apply_overlay(
all_imgs[0], paste_to, overlay_image
)
return [output_image]
return all_imgs

View File

@@ -0,0 +1,581 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from PIL import Image, ImageDraw, ImageFilter
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
import math
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class OutpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_mask_and_masked_image(
self, image, mask, mask_blur, width, height
):
if mask_blur > 0:
mask = mask.filter(ImageFilter.GaussianBlur(mask_blur))
image = image.resize((width, height))
mask = mask.resize((width, height))
# preprocess image
if isinstance(image, (Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
):
mask = torch.nn.functional.interpolate(
mask, size=(height // 8, width // 8)
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
return mask, masked_image_latents
def get_matched_noise(
self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
):
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data):
if data.ndim > 2: # has channels
out_fft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]),
dtype=np.complex128,
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_fft[:, :, c] = np.fft.fft2(
np.fft.fftshift(c_data), norm="ortho"
)
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
else: # one channel
out_fft = np.zeros(
(data.shape[0], data.shape[1]), dtype=np.complex128
)
out_fft[:, :] = np.fft.fft2(
np.fft.fftshift(data), norm="ortho"
)
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
return out_fft
def _ifft2(data):
if data.ndim > 2: # has channels
out_ifft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]),
dtype=np.complex128,
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_ifft[:, :, c] = np.fft.ifft2(
np.fft.fftshift(c_data), norm="ortho"
)
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
else: # one channel
out_ifft = np.zeros(
(data.shape[0], data.shape[1]), dtype=np.complex128
)
out_ifft[:, :] = np.fft.ifft2(
np.fft.fftshift(data), norm="ortho"
)
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
return out_ifft
def _get_gaussian_window(width, height, std=3.14, mode=0):
window_scale_x = float(width / min(width, height))
window_scale_y = float(height / min(width, height))
window = np.zeros((width, height))
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
for y in range(height):
fy = (y / height * 2.0 - 1.0) * window_scale_y
if mode == 0:
window[:, y] = np.exp(-(x**2 + fy**2) * std)
else:
window[:, y] = (
1 / ((x**2 + 1.0) * (fy**2 + 1.0))
) ** (std / 3.14)
return window
def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
np_mask_rgb = np.zeros(
(np_mask_grey.shape[0], np_mask_grey.shape[1], 3)
)
if hardness != 1.0:
hardened = np_mask_grey[:] ** hardness
else:
hardened = np_mask_grey[:]
for c in range(3):
np_mask_rgb[:, :, c] = hardened[:]
return np_mask_rgb
def _match_cumulative_cdf(source, template):
src_values, src_unique_indices, src_counts = np.unique(
source.ravel(), return_inverse=True, return_counts=True
)
tmpl_values, tmpl_counts = np.unique(
template.ravel(), return_counts=True
)
# calculate normalized quantiles for each array
src_quantiles = np.cumsum(src_counts) / source.size
tmpl_quantiles = np.cumsum(tmpl_counts) / template.size
interp_a_values = np.interp(
src_quantiles, tmpl_quantiles, tmpl_values
)
return interp_a_values[src_unique_indices].reshape(source.shape)
def _match_histograms(image, reference):
if image.ndim != reference.ndim:
raise ValueError(
"Image and reference must have the same number of channels."
)
if image.shape[-1] != reference.shape[-1]:
raise ValueError(
"Number of channels in the input image and reference image must match!"
)
matched = np.empty(image.shape, dtype=image.dtype)
for channel in range(image.shape[-1]):
matched_channel = _match_cumulative_cdf(
image[..., channel], reference[..., channel]
)
matched[..., channel] = matched_channel
matched = matched.astype(np.float64, copy=False)
return matched
width = _np_src_image.shape[0]
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
np_src_image = _np_src_image[:] * (1.0 - np_mask_rgb)
np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
# rather than leave the masked area black, we get better results from fft by filling the average unmasked color
windowed_image = _np_src_image * (
1.0 - _get_masked_window_rgb(np_mask_grey)
)
windowed_image /= np.max(windowed_image)
windowed_image += np.average(_np_src_image) * np_mask_rgb
src_fft = _fft2(
windowed_image
) # get feature statistics from masked src img
src_dist = np.absolute(src_fft)
src_phase = src_fft / src_dist
# create a generator with a static seed to make outpainting deterministic / only follow global seed
rng = np.random.default_rng(0)
noise_window = _get_gaussian_window(
width, height, mode=1
) # start with simple gaussian noise
noise_rgb = rng.random((width, height, num_channels))
noise_grey = np.sum(noise_rgb, axis=2) / 3.0
# the colorfulness of the starting noise is blended to greyscale with a parameter
noise_rgb *= color_variation
for c in range(num_channels):
noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey
noise_fft = _fft2(noise_rgb)
for c in range(num_channels):
noise_fft[:, :, c] *= noise_window
noise_rgb = np.real(_ifft2(noise_fft))
shaped_noise_fft = _fft2(noise_rgb)
shaped_noise_fft[:, :, :] = (
np.absolute(shaped_noise_fft[:, :, :]) ** 2
* (src_dist**noise_q)
* src_phase
) # perform the actual shaping
# color_variation
brightness_variation = 0.0
contrast_adjusted_np_src = (
_np_src_image[:] * (brightness_variation + 1.0)
- brightness_variation * 2.0
)
shaped_noise = np.real(_ifft2(shaped_noise_fft))
shaped_noise -= np.min(shaped_noise)
shaped_noise /= np.max(shaped_noise)
shaped_noise[img_mask, :] = _match_histograms(
shaped_noise[img_mask, :] ** 1.0,
contrast_adjusted_np_src[ref_mask, :],
)
shaped_noise = (
_np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
)
matched_noise = shaped_noise[:]
return np.clip(matched_noise, 0.0, 1.0)
def generate_images(
self,
prompts,
neg_prompts,
image,
pixels,
mask_blur,
is_left,
is_right,
is_top,
is_bottom,
noise_q,
color_variation,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
process_width = width
process_height = height
left = pixels if is_left else 0
right = pixels if is_right else 0
up = pixels if is_top else 0
down = pixels if is_bottom else 0
target_w = math.ceil((image.width + left + right) / 64) * 64
target_h = math.ceil((image.height + up + down) / 64) * 64
if left > 0:
left = left * (target_w - image.width) // (left + right)
if right > 0:
right = target_w - image.width - left
if up > 0:
up = up * (target_h - image.height) // (up + down)
if down > 0:
down = target_h - image.height - up
def expand(
init_img,
expand_pixels,
is_left=False,
is_right=False,
is_top=False,
is_bottom=False,
):
is_horiz = is_left or is_right
is_vert = is_top or is_bottom
pixels_horiz = expand_pixels if is_horiz else 0
pixels_vert = expand_pixels if is_vert else 0
res_w = init_img.width + pixels_horiz
res_h = init_img.height + pixels_vert
process_res_w = math.ceil(res_w / 64) * 64
process_res_h = math.ceil(res_h / 64) * 64
img = Image.new("RGB", (process_res_w, process_res_h))
img.paste(
init_img,
(pixels_horiz if is_left else 0, pixels_vert if is_top else 0),
)
msk = Image.new("RGB", (process_res_w, process_res_h), "white")
draw = ImageDraw.Draw(msk)
draw.rectangle(
(
expand_pixels + mask_blur if is_left else 0,
expand_pixels + mask_blur if is_top else 0,
msk.width - expand_pixels - mask_blur
if is_right
else res_w,
msk.height - expand_pixels - mask_blur
if is_bottom
else res_h,
),
fill="black",
)
np_image = (np.asarray(img) / 255.0).astype(np.float64)
np_mask = (np.asarray(msk) / 255.0).astype(np.float64)
noised = self.get_matched_noise(
np_image, np_mask, noise_q, color_variation
)
output_image = Image.fromarray(
np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8),
mode="RGB",
)
target_width = (
min(width, init_img.width + pixels_horiz)
if is_horiz
else img.width
)
target_height = (
min(height, init_img.height + pixels_vert)
if is_vert
else img.height
)
crop_region = (
0 if is_left else output_image.width - target_width,
0 if is_top else output_image.height - target_height,
target_width if is_left else output_image.width,
target_height if is_top else output_image.height,
)
mask_to_process = msk.crop(crop_region)
image_to_process = output_image.crop(crop_region)
# Preprocess mask and image
mask, masked_image = self.prepare_mask_and_masked_image(
image_to_process, mask_to_process, mask_blur, width, height
)
# Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask=mask,
masked_image=masked_image,
batch_size=batch_size,
height=height,
width=width,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
mask=mask,
masked_image_latents=masked_image_latents,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
res_img = all_imgs[0].resize(
(image_to_process.width, image_to_process.height)
)
output_image.paste(
res_img,
(
0 if is_left else output_image.width - res_img.width,
0 if is_top else output_image.height - res_img.height,
),
)
output_image = output_image.crop((0, 0, res_w, res_h))
return output_image
img = image.resize((width, height))
if left > 0:
img = expand(img, left, is_left=True)
if right > 0:
img = expand(img, right, is_right=True)
if up > 0:
img = expand(img, up, is_top=True)
if down > 0:
img = expand(img, down, is_bottom=True)
return [img]

View File

@@ -0,0 +1,346 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class StencilPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
def load_controlnet(self):
if self.controlnet is not None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def load_controlnet_512(self):
if self.controlnet_512 is not None:
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_stencil_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
else:
print(self.unet_512)
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
strength,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
)
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare initial latent.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
final_timesteps = self.scheduler.timesteps
# Get Image latents
latents = self.produce_stencil_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -0,0 +1,156 @@
import torch
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -0,0 +1,357 @@
import inspect
import torch
import time
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_IDLE,
SD_STATE_CANCEL,
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from PIL import Image
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, Image.Image):
image = [image]
if isinstance(image[0], Image.Image):
w, h = image[0].size
w, h = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
class UpscalerPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
low_res_scheduler: Union[
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.low_res_scheduler = low_res_scheduler
self.status = SD_STATE_IDLE
def prepare_extra_step_kwargs(self, generator, eta):
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
latents = 1 / 0.08333 * (latents.float())
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height,
width,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_img_latents(
self,
latents,
image,
text_embeddings,
guidance_scale,
noise_level,
total_timesteps,
dtype,
cpu_scheduling,
extra_step_kwargs,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.status = SD_STATE_IDLE
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = torch.cat([latent_model_input, image], dim=1)
timestep = torch.tensor([t]).to(dtype).detach().numpy()
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if cpu_scheduling:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
else:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
noise_level,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# 4. Preprocess image
image = preprocess(image).to(dtype)
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long)
noise = torch.randn(
image.shape,
generator=generator,
).to(dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
image = torch.cat([image] * 2)
noise_level = torch.cat([noise_level] * image.shape[0])
height, width = image.shape[2:]
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
eta = 0.0
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# guidance scale as a float32 tensor.
# guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
image=image,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
noise_level=noise_level,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
extra_step_kwargs=extra_step_kwargs,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -0,0 +1,939 @@
import torch
import numpy as np
from transformers import CLIPTokenizer
from PIL import Image
from tqdm.auto import tqdm
import time
from typing import Union
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
get_clip,
get_unet,
get_tokenizer,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
import sys
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
class StableDiffusionPipeline:
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
self.vae = None
self.text_encoder = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
self.tokenizer = get_tokenizer()
except:
self.load_unet()
self.unload_unet()
self.tokenizer = get_tokenizer()
def load_clip(self):
if self.text_encoder is not None:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
try:
self.text_encoder = get_clip()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.text_encoder = self.sd_model.clip()
def unload_clip(self):
del self.text_encoder
self.text_encoder = None
def load_unet(self):
if self.unet is not None:
return
if self.import_mlir or self.use_lora:
self.unet = self.sd_model.unet()
else:
try:
self.unet = get_unet()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.unet = self.sd_model.unet()
def unload_unet(self):
del self.unet
self.unet = None
def load_unet_512(self):
if self.unet_512 is not None:
return
if self.import_mlir or self.use_lora:
self.unet_512 = self.sd_model.unet(use_large=True)
else:
try:
self.unet_512 = get_unet(use_large=True)
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.unet_512 = self.sd_model.unet(use_large=True)
def unload_unet_512(self):
del self.unet_512
self.unet_512 = None
def load_vae(self):
if self.vae is not None:
return
if self.import_mlir or self.use_lora:
self.vae = self.sd_model.vae()
else:
try:
self.vae = get_vae()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.vae = self.sd_model.vae()
def unload_vae(self):
del self.vae
self.vae = None
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
text_input = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# Get unconditional embeddings as well
uncond_input = self.tokenizer(
neg_prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
self.load_clip()
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
if use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
if use_base_vae:
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def produce_img_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
@classmethod
def from_pretrained(
cls,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
import_mlir: bool,
model_id: str,
ckpt_loc: str,
custom_vae: str,
precision: str,
max_length: int,
batch_size: int,
height: int,
width: int,
use_base_vae: bool,
use_tuned: bool,
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
if (
not import_mlir
and not use_lora
and cls.__name__ == "StencilPipeline"
):
sys.exit("StencilPipeline not supported with SharkTank currently.")
is_inpaint = cls.__name__ in [
"InpaintPipeline",
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
sd_model = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
use_quantize=use_quantize,
)
if cls.__name__ in ["UpscalerPipeline"]:
return cls(
scheduler,
ddpm_scheduler,
sd_model,
import_mlir,
use_lora,
ondemand,
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
# #####################################################
# Implements text embeddings with weights from prompts
# https://huggingface.co/AlanB/lpw_stable_diffusion_mod
# #####################################################
def encode_prompts_weight(
self,
prompt,
negative_prompt,
model_max_length,
do_classifier_free_guidance=True,
max_embeddings_multiples=1,
num_images_per_prompt=1,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation.
Ignored when not using guidance
(i.e., ignored if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on
pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings
(defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the
max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error
(defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
"""
# SHARK: Save model_max_length, load the clip and init inference time
self.model_max_length = model_max_length
self.load_clip()
clip_inf_start = time.time()
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: "
f"{negative_prompt} has batch size {len(negative_prompt)}, "
f"but `prompt`: {prompt} has batch size {batch_size}. "
f"Please make sure that passed `negative_prompt` matches "
"the batch size of `prompt`."
)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt
if do_classifier_free_guidance
else None,
max_embeddings_multiples=max_embeddings_multiples,
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# text_embeddings = (
# text_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = (
# uncond_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# )
# uncond_embeddings = (
# uncond_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if text_embeddings.shape[1] > model_max_length:
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (0, 512 - text_embeddings.shape[1])
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()
from typing import List, Optional, Union
import re
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs:
text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(
pipe: StableDiffusionPipeline, prompt: List[str], max_length: int
):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens,
weights,
max_length,
bos,
eos,
no_boseos_middle=True,
chunk_length=77,
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
max_length
if no_boseos_middle
else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
tokens[i] = (
[bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
)
if no_boseos_middle:
weights[i] = (
[1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
)
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j
* (chunk_length - 2) : min(
len(weights[i]), (j + 1) * (chunk_length - 2)
)
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
# text_embedding = pipe.text_encoder(text_input_chunk)[0]
# SHARK: deplicate the text_input as Shark runner expects tokens and neg tokens
formatted_text_input_chunk = torch.cat(
[text_input_chunk, text_input_chunk]
)
text_embedding = pipe.text_encoder(
"forward", (formatted_text_input_chunk,)
)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
else:
# SHARK: deplicate the text_input as Shark runner expects tokens and neg tokens
# Convert the result to tensor
# text_embeddings = pipe.text_encoder(text_input)[0]
formatted_text_input = torch.cat([text_input, text_input])
text_embeddings = pipe.text_encoder(
"forward", (formatted_text_input,)
)[0]
text_embeddings = torch.from_numpy(text_embeddings)[None, :]
return text_embeddings
# This function deals with NoneType values occuring in tokens after padding
# It switches out None with 49407 as truncating None values causes matrix dimension errors,
def filter_nonetype_tokens(tokens: List[List]):
return [[49407 if token is None else token for token in tokens[0]]]
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
Args:
pipe (`StableDiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`):
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
ending token in each of the chunk in the middle.
skip_parsing (`bool`, *optional*, defaults to `False`):
Skip the parsing of brackets.
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True.
"""
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(
max_length, max([len(token) for token in uncond_tokens])
)
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
prompt_tokens = filter_nonetype_tokens(prompt_tokens)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
uncond_tokens = filter_nonetype_tokens(uncond_tokens)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(
uncond_tokens, dtype=torch.long, device="cpu"
)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(
prompt_weights, dtype=torch.float, device="cpu"
)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(
uncond_weights, dtype=torch.float, device="cpu"
)
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None

View File

@@ -0,0 +1,4 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)

View File

@@ -0,0 +1,104 @@
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
def get_schedulers(model_id):
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
)
schedulers[
"DPMSolverMultistep++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
)
schedulers[
"DPMSolverMultistepKarras"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
use_karras_sigmas=True,
)
schedulers[
"DPMSolverMultistepKarras++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerDiscrete"
] = SharkEulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverSinglestep"
] = DPMSolverSinglestepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"KDPM2AncestralDiscrete"
] = KDPM2AncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers

View File

@@ -0,0 +1,154 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_shark_model,
args,
)
import torch
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
):
super().__init__(
num_train_timesteps,
beta_start,
beta_end,
beta_schedule,
trained_betas,
prediction_type,
)
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
device = args.device.split(":", 1)[0].strip()
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
}
example_latent = model_input["euler"]["latent"]
example_output = model_input["euler"]["output"]
if args.precision == "fp16":
example_latent = example_latent.half()
example_output = example_output.half()
example_sigma = model_input["euler"]["sigma"]
example_dt = model_input["euler"]["dt"]
class ScalingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma, latent, dt):
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
return latent + derivative * dt
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
def _import(self):
scaling_model = ScalingModel()
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
self.step_model, _ = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
if args.import_mlir:
_import(self)
else:
try:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_step_" + args.precision,
iree_flags,
)
except:
print(
"failed to download model, falling back and using import_mlir"
)
args.import_mlir = True
_import(self)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
return self.scaling_model(
"forward",
(
sample,
sigma,
),
send_to_host=False,
)
def step(self, noise_pred, timestep, latent):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
dt = self.sigmas[step_index + 1] - sigma
return self.step_model(
"forward",
(
noise_pred,
sigma,
latent,
dt,
),
send_to_host=False,
)

View File

@@ -0,0 +1,43 @@
from apps.stable_diffusion.src.utils.profiler import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.utils.resources import (
prompt_examples,
models_db,
base_models,
opt_flags,
resource_path,
)
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
controlnet_hint_conversion,
get_stencil_model_id,
)
from apps.stable_diffusion.src.utils.utils import (
get_shark_model,
compile_through_fx,
set_iree_runtime_flags,
map_device_to_name_path,
set_init_device_flags,
get_available_devices,
get_opt_flags,
preprocessCKPT,
convert_original_vae,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
parse_seed_input,
batch_seeds,
get_path_stem,
get_extended_name,
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
clear_all,
save_output_img,
get_generation_text_info,
update_lora_weight,
resize_stencil,
_compile_module,
)

View File

@@ -0,0 +1,20 @@
from apps.stable_diffusion.src.utils.stable_args import args
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
from shark.parser import shark_args
if shark_args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")
vulkan_device = iree.runtime.get_device(args.device)
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
return vulkan_device
return None
def end_profiling(device):
if device:
return device.end_profiling()

View File

@@ -0,0 +1,37 @@
import os
import json
import sys
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def get_json_file(path):
json_var = []
loc_json = resource_path(path)
if os.path.exists(loc_json):
with open(loc_json, encoding="utf-8") as fopen:
json_var = json.load(fopen)
if not json_var:
print(f"Unable to fetch {path}")
return json_var
# TODO: This shouldn't be called from here, every time the file imports
# it will run all the global vars.
prompt_examples = get_json_file("resources/prompts.json")
models_db = get_json_file("resources/model_db.json")
# The base_model contains the input configuration for the different
# models and also helps in providing information for the variants.
base_models = get_json_file("resources/base_model.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")

View File

@@ -0,0 +1,296 @@
{
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"vae_upscaler": {
"latents" : {
"shape" : [
"1*batch_size",4,"8*height","8*width"
],
"dtype":"f32"
}
}
},
"unet": {
"stabilityai/stable-diffusion-2-1": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-2-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"runwayml/stable-diffusion-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-x4-upscaler": {
"latents": {
"shape": [
"2*batch_size",
7,
"8*height",
"8*width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"noise_level": {
"shape": [2],
"dtype": "i64"
}
}
},
"stencil_adaptor": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
"stencil_unet": {
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
},
"control1": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control2": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control3": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control4": {
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"control5": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control6": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control7": {
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"control8": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control9": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control10": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control11": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control12": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
}
}
}

View File

@@ -0,0 +1,23 @@
[
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting",
"stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]

View File

@@ -0,0 +1,19 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/nightly"
},
{
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan"
}
]

View File

@@ -0,0 +1,84 @@
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
"fp32": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
}
}

View File

@@ -0,0 +1,11 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"],
["A photo of a beach, sunset, calm, beautiful landscape, waves, water"],
["(a large body of water with snowy mountains in the background), (fog, foggy, rolling fog), (clouds, cloudy, rolling clouds), dramatic sky and landscape, extraordinary landscape, (beautiful snow capped mountain background), (forest, dirt path)"],
["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smokes coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"]]

View File

@@ -0,0 +1,300 @@
import os
import io
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
download_model,
download_public_file,
WORKDIR,
)
from shark.parser import shark_args
from apps.stable_diffusion.src.utils.stable_args import args
def get_device():
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
return device
def get_device_args():
device = get_device()
device_spec_args = []
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
gpu_flags = get_iree_gpu_args()
for flag in gpu_flags:
device_spec_args.append(flag)
elif device == "vulkan":
device_spec_args.append(
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
return device, device_spec_args
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
from apps.stable_diffusion.src.models import (
get_params,
get_variant_version,
)
variant, version = get_variant_version(args.hf_model_id)
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{variant}/untuned"
if args.annotation_model == "unet":
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
is_base = "/base" if args.use_base_vae else ""
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
)
return mlir_model, model_name
# Download the tuned config files from shark_tank
def load_winograd_configs():
device = get_device()
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
if not os.path.exists(WORKDIR):
os.mkdir(WORKDIR)
winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
def load_lower_configs(base_model_id=None):
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
if not base_model_id:
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
if version == "inpaint_v1":
version = "v1_4"
elif version == "inpaint_v2":
version = "v2_1base"
config_bucket = "gs://shark_tank/sd_tuned_configs/"
device, device_spec_args = get_device_args()
spec = ""
if device_spec_args:
spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
spec = spec.split("-")[0]
if args.annotation_model == "vae":
if not spec or spec in ["sm_80"]:
config_name = (
f"{args.annotation_model}_{args.precision}_{device}.json"
)
else:
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["sm_80"]:
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
and args.width == 768
):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
elif spec in ["rdna3"] and version in [
"v2_1",
"v2_1base",
"v1_4",
"v1_5",
]:
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.max_length}_"
f"{args.precision}_"
f"{device}_"
f"{spec}_"
f"{args.width}x{args.height}.json"
)
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.precision}_"
f"{device}_"
f"{spec}_"
f"{args.width}x{args.height}.json"
)
else:
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.precision}_"
f"{device}_"
f"{spec}.json"
)
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
full_gs_url = config_bucket + config_name
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
with create_context() as ctx:
winograd_model = model_annotation(
ctx,
input_contents=input_mlir,
config_path=winograd_config_dir,
search_op="conv",
winograd=True,
)
bytecode_stream = io.BytesIO()
winograd_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
if args.save_annotation:
if model_name.split("_")[-1] != "tuned":
out_file_path = os.path.join(
args.annotation_output, model_name + "_tuned_torch.mlir"
)
else:
out_file_path = os.path.join(
args.annotation_output, model_name + "_torch.mlir"
)
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return bytecode
def dump_after_mlir(input_mlir, use_winograd):
import iree.compiler as ireec
device, device_spec_args = get_device_args()
if use_winograd:
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))"
)
else:
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
)
dump_module = ireec.compile_str(
input_mlir,
target_backends=[iree_target_map(device)],
extra_args=device_spec_args
+ [
preprocess_flag,
"--compile-to=preprocessing",
],
)
return dump_module
# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
# Dump IR after padding/img2col/winograd passes
dump_module = dump_after_mlir(input_mlir, use_winograd)
print("Applying tuned configs on", model_name)
# Annotate the model with lowering configs in the config file
with create_context() as ctx:
tuned_model = model_annotation(
ctx,
input_contents=dump_module,
config_path=lowering_config_dir,
search_op="all",
)
bytecode_stream = io.BytesIO()
tuned_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
if args.save_annotation:
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return bytecode
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
tuned_model = mlir_model
else:
use_winograd = False
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
return tuned_model
if __name__ == "__main__":
mlir_model, model_name = load_model_from_tank()
sd_model_annotation(mlir_model, model_name)

View File

@@ -0,0 +1,732 @@
import argparse
import os
from pathlib import Path
def path_expand(s):
return Path(s).expanduser().resolve()
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
# Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
)
p.add_argument(
"-p",
"--prompts",
nargs="+",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smokes coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
],
help="Text of which images to be generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
"blurry, ugly, blur, oversaturated, cropped"
],
help="Text you don't want to see in the generated image.",
)
p.add_argument(
"--img_path",
type=str,
help="Path to the image input for img2img/inpainting.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="The number of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=str,
default=-1,
help="The seed or list of seeds to use. -1 for a random one.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="The number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 769, 8),
help="The height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 769, 8),
help="The width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="The value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="The value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="Max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max "
"output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="The strength of change applied on the given input image for "
"img2img.",
)
p.add_argument(
"--use_hiresfix",
type=bool,
default=False,
help="Use Hires Fix to do higher resolution images, while trying to "
"avoid the issues that come with it. This is accomplished by first "
"generating an image using txt2img, then running it through img2img.",
)
p.add_argument(
"--hiresfix_height",
type=int,
default=768,
choices=range(128, 769, 8),
help="The height of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_width",
type=int,
default=768,
choices=range(128, 769, 8),
help="The width of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_strength",
type=float,
default=0.6,
help="The denoising strength to apply for the Hires Fix.",
)
p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
##############################################################################
# Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model.",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt.",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The number of steps to train.",
)
##############################################################################
# Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting.",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture.",
)
p.add_argument(
"--inpaint_full_res_padding",
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding.",
)
p.add_argument(
"--pixels",
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting.",
)
p.add_argument(
"--mask_blur",
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting.",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend right for outpainting.",
)
p.add_argument(
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend top for outpainting.",
)
p.add_argument(
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend bottom for outpainting.",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) "
"(min=0.0, max=4.0).",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0).",
)
##############################################################################
# Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="Device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="Precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="Attempts to load the model from a precompiled flat-buffer "
"and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="Saves the compiled flat-buffer to the local directory.",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available.",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
"HeunDiscrete].",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="Specify the format in which output image is save. "
"Supported options: jpg / png.",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json.",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="Number of batches to be generated with random seeds in "
"single execution.",
)
p.add_argument(
"--repeatable_seeds",
default=False,
action=argparse.BooleanOptionalAction,
help="The seed of the first batch will be used as the rng seed to "
"generate the subsequent seeds for subsequent batches in that run.",
)
p.add_argument(
"--ckpt_loc",
type=str,
default="",
help="Path to SD's .ckpt file.",
)
p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
"needs to be plugged in.",
)
p.add_argument(
"--hf_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption.",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
"or an integer).",
)
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble"],
help="Enable the stencil feature.",
)
p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint "
"file (~3 MB).",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="Runs the quantized version of stable diffusion model. "
"This is currently in experimental phase. "
"Currently, only runs the stable-diffusion-2-1-base model in "
"int8 quantization.",
)
p.add_argument(
"--ondemand",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM.",
)
p.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
p.add_argument(
"--device_allocator_heap_key",
type=str,
default="",
help="Specify heap key for device caching allocator."
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan.",
)
p.add_argument(
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal.",
)
##############################################################################
# Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="Use the default scheduler precompiled into the model if available.",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. "
"If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. "
"Use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help="Dispatches to return benchmark data on. "
'Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help="Directory where you want to store dispatch data "
'generated with "--dispatch_benchmarks".',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for inserting debug frames between iterations "
"for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="Flag setting warmup count for CLIP and VAE [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to clear all mlir and vmfb from common locations. "
"Recompiling will take several minutes.",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save a generation information "
"json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save generation information in "
"PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="If import_mlir is True, saves mlir via the debug option "
"in shark importer. Does nothing if import_mlir is false (the default).",
)
p.add_argument(
"--compile_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to toggle debug assert/verify flags for imported IR in the"
"iree-compiler. Default to false.",
)
p.add_argument(
"--iree_constant_folding",
default=True,
action=argparse.BooleanOptionalAction,
help="Controls constant folding in iree-compile for all SD models.",
)
##############################################################################
# Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the progress bar animation during "
"image generation.",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="One of: [api, app, web].",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for generating a public URL.",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="Flag for setting server port.",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling rest API.",
)
p.add_argument(
"--debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling debugging log in WebUI.",
)
p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the output gallery tab, and avoid exposing "
"images under --output_dir in the UI.",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether the output gallery tab in the UI should "
"follow symlinks when listing subdirectories under --output_dir.",
)
##############################################################################
# SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file.",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file.",
)
##############################################################################
# SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file.",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning.",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all.",
)
##############################################################################
# DocuChat Flags
##############################################################################
p.add_argument(
"--run_docuchat_web",
default=False,
action=argparse.BooleanOptionalAction,
help="Specifies whether the docuchat's web version is running or not.",
)
args, unknown = p.parse_known_args()
if args.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
os.getcwd(), args.hf_model_id.replace("/", "_")
)

View File

@@ -0,0 +1,2 @@
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector

View File

@@ -0,0 +1,6 @@
import cv2
class CannyDetector:
def __call__(self, img, low_threshold, high_threshold):
return cv2.Canny(img, low_threshold, high_threshold)

View File

@@ -0,0 +1,62 @@
import requests
from pathlib import Path
import torch
import numpy as np
# from annotator.util import annotator_ckpts_path
from apps.stable_diffusion.src.utils.stencils.openpose.body import Body
from apps.stable_diffusion.src.utils.stencils.openpose.hand import Hand
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
draw_bodypose,
draw_handpose,
handDetect,
)
body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"
class OpenposeDetector:
def __init__(self):
cwd = Path.cwd()
ckpt_path = Path(cwd, "stencil_annotator")
ckpt_path.mkdir(parents=True, exist_ok=True)
body_modelpath = ckpt_path / "body_pose_model.pth"
hand_modelpath = ckpt_path / "hand_pose_model.pth"
if not body_modelpath.is_file():
r = requests.get(body_model_path, allow_redirects=True)
open(body_modelpath, "wb").write(r.content)
if not hand_modelpath.is_file():
r = requests.get(hand_model_path, allow_redirects=True)
open(hand_modelpath, "wb").write(r.content)
self.body_estimation = Body(body_modelpath)
self.hand_estimation = Hand(hand_modelpath)
def __call__(self, oriImg, hand=False):
oriImg = oriImg[:, :, ::-1].copy()
with torch.no_grad():
candidate, subset = self.body_estimation(oriImg)
canvas = np.zeros_like(oriImg)
canvas = draw_bodypose(canvas, candidate, subset)
if hand:
hands_list = handDetect(candidate, subset, oriImg)
all_hand_peaks = []
for x, y, w, is_left in hands_list:
peaks = self.hand_estimation(
oriImg[y : y + w, x : x + w, :]
)
peaks[:, 0] = np.where(
peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x
)
peaks[:, 1] = np.where(
peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y
)
all_hand_peaks.append(peaks)
canvas = draw_handpose(canvas, all_hand_peaks)
return canvas, dict(
candidate=candidate.tolist(), subset=subset.tolist()
)

View File

@@ -0,0 +1,499 @@
import cv2
import numpy as np
import math
from scipy.ndimage.filters import gaussian_filter
import torch
import torch.nn as nn
from collections import OrderedDict
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
make_layers,
transfer,
padRightDownCorner,
)
class BodyPoseModel(nn.Module):
def __init__(self):
super(BodyPoseModel, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv5_5_CPM_L1",
"conv5_5_CPM_L2",
"Mconv7_stage2_L1",
"Mconv7_stage2_L2",
"Mconv7_stage3_L1",
"Mconv7_stage3_L2",
"Mconv7_stage4_L1",
"Mconv7_stage4_L2",
"Mconv7_stage5_L1",
"Mconv7_stage5_L2",
"Mconv7_stage6_L1",
"Mconv7_stage6_L1",
]
blocks = {}
block0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3_CPM", [512, 256, 3, 1, 1]),
("conv4_4_CPM", [256, 128, 3, 1, 1]),
]
)
# Stage 1
block1_1 = OrderedDict(
[
("conv5_1_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L1", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L1", [512, 38, 1, 1, 0]),
]
)
block1_2 = OrderedDict(
[
("conv5_1_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L2", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L2", [512, 19, 1, 1, 0]),
]
)
blocks["block1_1"] = block1_1
blocks["block1_2"] = block1_2
self.model0 = make_layers(block0, no_relu_layers)
# Stages 2 - 6
for i in range(2, 7):
blocks["block%d_1" % i] = OrderedDict(
[
("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]),
]
)
blocks["block%d_2" % i] = OrderedDict(
[
("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_1 = blocks["block1_1"]
self.model2_1 = blocks["block2_1"]
self.model3_1 = blocks["block3_1"]
self.model4_1 = blocks["block4_1"]
self.model5_1 = blocks["block5_1"]
self.model6_1 = blocks["block6_1"]
self.model1_2 = blocks["block1_2"]
self.model2_2 = blocks["block2_2"]
self.model3_2 = blocks["block3_2"]
self.model4_2 = blocks["block4_2"]
self.model5_2 = blocks["block5_2"]
self.model6_2 = blocks["block6_2"]
def forward(self, x):
out1 = self.model0(x)
out1_1 = self.model1_1(out1)
out1_2 = self.model1_2(out1)
out2 = torch.cat([out1_1, out1_2, out1], 1)
out2_1 = self.model2_1(out2)
out2_2 = self.model2_2(out2)
out3 = torch.cat([out2_1, out2_2, out1], 1)
out3_1 = self.model3_1(out3)
out3_2 = self.model3_2(out3)
out4 = torch.cat([out3_1, out3_2, out1], 1)
out4_1 = self.model4_1(out4)
out4_2 = self.model4_2(out4)
out5 = torch.cat([out4_1, out4_2, out1], 1)
out5_1 = self.model5_1(out5)
out5_2 = self.model5_2(out5)
out6 = torch.cat([out5_1, out5_2, out1], 1)
out6_1 = self.model6_1(out6)
out6_2 = self.model6_2(out6)
return out6_1, out6_2
class Body(object):
def __init__(self, model_path):
self.model = BodyPoseModel()
if torch.cuda.is_available():
self.model = self.model.cuda()
model_dict = transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def __call__(self, oriImg):
scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre1 = 0.1
thre2 = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = cv2.resize(
oriImg,
(0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_CUBIC,
)
imageToTest_padded, pad = padRightDownCorner(
imageToTest, stride, padValue
)
im = (
np.transpose(
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
(3, 2, 0, 1),
)
/ 256
- 0.5
)
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
if torch.cuda.is_available():
data = data.cuda()
with torch.no_grad():
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
# extract outputs, resize, and remove padding
heatmap = np.transpose(
np.squeeze(Mconv7_stage6_L2), (1, 2, 0)
) # output 1 is heatmaps
heatmap = cv2.resize(
heatmap,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
heatmap = heatmap[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
heatmap = cv2.resize(
heatmap,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
paf = np.transpose(
np.squeeze(Mconv7_stage6_L1), (1, 2, 0)
) # output 0 is PAFs
paf = cv2.resize(
paf,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
paf = paf[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
paf = cv2.resize(
paf,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
paf_avg += +paf / len(multiplier)
all_peaks = []
peak_counter = 0
for part in range(18):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
map_left = np.zeros(one_heatmap.shape)
map_left[1:, :] = one_heatmap[:-1, :]
map_right = np.zeros(one_heatmap.shape)
map_right[:-1, :] = one_heatmap[1:, :]
map_up = np.zeros(one_heatmap.shape)
map_up[:, 1:] = one_heatmap[:, :-1]
map_down = np.zeros(one_heatmap.shape)
map_down[:, :-1] = one_heatmap[:, 1:]
peaks_binary = np.logical_and.reduce(
(
one_heatmap >= map_left,
one_heatmap >= map_right,
one_heatmap >= map_up,
one_heatmap >= map_down,
one_heatmap > thre1,
)
)
peaks = list(
zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])
) # note reverse
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
peak_id = range(peak_counter, peak_counter + len(peaks))
peaks_with_score_and_id = [
peaks_with_score[i] + (peak_id[i],)
for i in range(len(peak_id))
]
all_peaks.append(peaks_with_score_and_id)
peak_counter += len(peaks)
# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
[3, 17],
[6, 18],
]
# the middle joints heatmap correpondence
mapIdx = [
[31, 32],
[39, 40],
[33, 34],
[35, 36],
[41, 42],
[43, 44],
[19, 20],
[21, 22],
[23, 24],
[25, 26],
[27, 28],
[29, 30],
[47, 48],
[49, 50],
[53, 54],
[51, 52],
[55, 56],
[37, 38],
[45, 46],
]
connection_all = []
special_k = []
mid_num = 10
for k in range(len(mapIdx)):
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
candA = all_peaks[limbSeq[k][0] - 1]
candB = all_peaks[limbSeq[k][1] - 1]
nA = len(candA)
nB = len(candB)
indexA, indexB = limbSeq[k]
if nA != 0 and nB != 0:
connection_candidate = []
for i in range(nA):
for j in range(nB):
vec = np.subtract(candB[j][:2], candA[i][:2])
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
norm = max(0.001, norm)
vec = np.divide(vec, norm)
startend = list(
zip(
np.linspace(
candA[i][0], candB[j][0], num=mid_num
),
np.linspace(
candA[i][1], candB[j][1], num=mid_num
),
)
)
vec_x = np.array(
[
score_mid[
int(round(startend[I][1])),
int(round(startend[I][0])),
0,
]
for I in range(len(startend))
]
)
vec_y = np.array(
[
score_mid[
int(round(startend[I][1])),
int(round(startend[I][0])),
1,
]
for I in range(len(startend))
]
)
score_midpts = np.multiply(
vec_x, vec[0]
) + np.multiply(vec_y, vec[1])
score_with_dist_prior = sum(score_midpts) / len(
score_midpts
) + min(0.5 * oriImg.shape[0] / norm - 1, 0)
criterion1 = len(
np.nonzero(score_midpts > thre2)[0]
) > 0.8 * len(score_midpts)
criterion2 = score_with_dist_prior > 0
if criterion1 and criterion2:
connection_candidate.append(
[
i,
j,
score_with_dist_prior,
score_with_dist_prior
+ candA[i][2]
+ candB[j][2],
]
)
connection_candidate = sorted(
connection_candidate, key=lambda x: x[2], reverse=True
)
connection = np.zeros((0, 5))
for c in range(len(connection_candidate)):
i, j, s = connection_candidate[c][0:3]
if i not in connection[:, 3] and j not in connection[:, 4]:
connection = np.vstack(
[connection, [candA[i][3], candB[j][3], s, i, j]]
)
if len(connection) >= min(nA, nB):
break
connection_all.append(connection)
else:
special_k.append(k)
connection_all.append([])
# last number in each row is the total parts number of that person
# the second last number in each row is the score of the overall configuration
subset = -1 * np.ones((0, 20))
candidate = np.array(
[item for sublist in all_peaks for item in sublist]
)
for k in range(len(mapIdx)):
if k not in special_k:
partAs = connection_all[k][:, 0]
partBs = connection_all[k][:, 1]
indexA, indexB = np.array(limbSeq[k]) - 1
for i in range(len(connection_all[k])): # = 1:size(temp,1)
found = 0
subset_idx = [-1, -1]
for j in range(len(subset)): # 1:size(subset,1):
if (
subset[j][indexA] == partAs[i]
or subset[j][indexB] == partBs[i]
):
subset_idx[found] = j
found += 1
if found == 1:
j = subset_idx[0]
if subset[j][indexB] != partBs[i]:
subset[j][indexB] = partBs[i]
subset[j][-1] += 1
subset[j][-2] += (
candidate[partBs[i].astype(int), 2]
+ connection_all[k][i][2]
)
elif found == 2: # if found 2 and disjoint, merge them
j1, j2 = subset_idx
membership = (
(subset[j1] >= 0).astype(int)
+ (subset[j2] >= 0).astype(int)
)[:-2]
if len(np.nonzero(membership == 2)[0]) == 0: # merge
subset[j1][:-2] += subset[j2][:-2] + 1
subset[j1][-2:] += subset[j2][-2:]
subset[j1][-2] += connection_all[k][i][2]
subset = np.delete(subset, j2, 0)
else: # as like found == 1
subset[j1][indexB] = partBs[i]
subset[j1][-1] += 1
subset[j1][-2] += (
candidate[partBs[i].astype(int), 2]
+ connection_all[k][i][2]
)
# if find no partA in the subset, create a new subset
elif not found and k < 17:
row = -1 * np.ones(20)
row[indexA] = partAs[i]
row[indexB] = partBs[i]
row[-1] = 2
row[-2] = (
sum(
candidate[
connection_all[k][i, :2].astype(int), 2
]
)
+ connection_all[k][i][2]
)
subset = np.vstack([subset, row])
# delete some rows of subset which has few parts occur
deleteIdx = []
for i in range(len(subset)):
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
deleteIdx.append(i)
subset = np.delete(subset, deleteIdx, axis=0)
# candidate: x, y, score, id
return candidate, subset

View File

@@ -0,0 +1,205 @@
import cv2
import numpy as np
from scipy.ndimage.filters import gaussian_filter
import torch
import torch.nn as nn
from skimage.measure import label
from collections import OrderedDict
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
make_layers,
transfer,
padRightDownCorner,
npmax,
)
class HandPoseModel(nn.Module):
def __init__(self):
super(HandPoseModel, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv6_2_CPM",
"Mconv7_stage2",
"Mconv7_stage3",
"Mconv7_stage4",
"Mconv7_stage5",
"Mconv7_stage6",
]
# stage 1
block1_0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3", [512, 512, 3, 1, 1]),
("conv4_4", [512, 512, 3, 1, 1]),
("conv5_1", [512, 512, 3, 1, 1]),
("conv5_2", [512, 512, 3, 1, 1]),
("conv5_3_CPM", [512, 128, 3, 1, 1]),
]
)
block1_1 = OrderedDict(
[
("conv6_1_CPM", [128, 512, 1, 1, 0]),
("conv6_2_CPM", [512, 22, 1, 1, 0]),
]
)
blocks = {}
blocks["block1_0"] = block1_0
blocks["block1_1"] = block1_1
# stage 2-6
for i in range(2, 7):
blocks["block%d" % i] = OrderedDict(
[
("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]),
("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_0 = blocks["block1_0"]
self.model1_1 = blocks["block1_1"]
self.model2 = blocks["block2"]
self.model3 = blocks["block3"]
self.model4 = blocks["block4"]
self.model5 = blocks["block5"]
self.model6 = blocks["block6"]
def forward(self, x):
out1_0 = self.model1_0(x)
out1_1 = self.model1_1(out1_0)
concat_stage2 = torch.cat([out1_1, out1_0], 1)
out_stage2 = self.model2(concat_stage2)
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
out_stage3 = self.model3(concat_stage3)
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
out_stage4 = self.model4(concat_stage4)
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
out_stage5 = self.model5(concat_stage5)
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
out_stage6 = self.model6(concat_stage6)
return out_stage6
class Hand(object):
def __init__(self, model_path):
self.model = HandPoseModel()
if torch.cuda.is_available():
self.model = self.model.cuda()
model_dict = transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def __call__(self, oriImg):
scale_search = [0.5, 1.0, 1.5, 2.0]
# scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
# paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = cv2.resize(
oriImg,
(0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_CUBIC,
)
imageToTest_padded, pad = padRightDownCorner(
imageToTest, stride, padValue
)
im = (
np.transpose(
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
(3, 2, 0, 1),
)
/ 256
- 0.5
)
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
if torch.cuda.is_available():
data = data.cuda()
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
with torch.no_grad():
output = self.model(data).cpu().numpy()
# output = self.model(data).numpy()q
# extract outputs, resize, and remove padding
heatmap = np.transpose(
np.squeeze(output), (1, 2, 0)
) # output 1 is heatmaps
heatmap = cv2.resize(
heatmap,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
heatmap = heatmap[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
heatmap = cv2.resize(
heatmap,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
heatmap_avg += heatmap / len(multiplier)
all_peaks = []
for part in range(21):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
# 全部小于阈值
if np.sum(binary) == 0:
all_peaks.append([0, 0])
continue
label_img, label_numbers = label(
binary, return_num=True, connectivity=binary.ndim
)
max_index = (
np.argmax(
[
np.sum(map_ori[label_img == i])
for i in range(1, label_numbers + 1)
]
)
+ 1
)
label_img[label_img != max_index] = 0
map_ori[label_img == 0] = 0
y, x = npmax(map_ori)
all_peaks.append([x, y])
return np.array(all_peaks)

Some files were not shown because too many files have changed in this diff Show More