Compare commits

..

135 Commits

Author SHA1 Message Date
dan
489a858af1 enforce fp32 accumulates for cpu 2023-10-29 18:59:00 +00:00
Vivek Khandelwal
b83d32fafe Fix Falcon GPTQ Pipeline 2023-10-11 20:09:32 +05:30
Vivek Khandelwal
0a618e1863 Add support for Falcon GPTQ 2023-10-11 10:47:48 +05:30
Phaneesh Barwaria
a731eb6ed4 Macos fixes (#1883)
* fix venv setup for MacOS

* allow stream fuse binding on mac

* clean iree metal args
2023-10-09 23:36:12 -07:00
Ean Garvey
2004d16945 Revert "[SDXL] Add SDXL pipeline to SHARK (#1731)" (#1882)
This reverts commit 9f0a421764.
2023-10-09 18:01:44 -07:00
Gaurav Shukla
6e409bfb77 fix else if syntax error
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-10 06:23:56 +05:30
Gaurav Shukla
77727d149c [warning] Fix dropdown warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-10 05:18:43 +05:30
Ean Garvey
66f6e79d68 Split CPU/GPU definitions conditionally outside of torch contexts. (#1879) 2023-10-09 16:46:41 -07:00
Ean Garvey
3b825579a7 (LLaMa-2) Point to int4 + f32 acc .mlir for cpu (#1878)
- fixes some issues with non-system prompt invocation

Co-authored-by: Gaurav Shukla <gauravshukla789@gmail.com>
2023-10-09 14:37:35 -05:00
Abhishek Varma
9f0a421764 [SDXL] Add SDXL pipeline to SHARK (#1731)
-- This commit adds SDXL pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-10-09 13:01:37 -05:00
Gaurav Shukla
c28682110c [chatbot] Flag to add system prompt
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-09 22:17:39 +05:30
Ean Garvey
caf6cc5d8f Switch most compile flows to use ireec.compile_file. (#1863)
* Switch most compile flows to use ireec.compile_file.

* re-add input type to compile_str path.

* Check if mlir_module exists before checking if it's a path or pyobject.

* Fix some save_dir cases
2023-10-06 23:04:43 -05:00
Ean Garvey
8614a18474 Remove tf dependencies from importer path. (#1874)
* Remove tf dependencies from import path.

* Fix formatting.
2023-10-06 12:27:12 -07:00
Jakub Kuderski
86c1c0c215 Add aggregate statistics to microbenchmark (#1871)
Print averaged results at the end of all iterations. Increase the
default number of iterations to 5.

Example:
```
Number of iterations: 5
Prefill: avg. 0.03 s, stddev 0.00
Decode: avg. 43.34 tokens/s, stdev 0.13
```

Also remove the -2 in the number of generated tokens -- I did not find
any evidence we need it.
2023-10-06 10:03:07 -07:00
Daniel Garvey
8bb364bcb8 enforce fp32 accumulates for cpu (#1873) 2023-10-06 11:34:49 -05:00
Daniel Garvey
7abddd01ec argmax inside model + brevitas pin (#1872) 2023-10-05 20:15:21 -07:00
Abhishek Varma
2a451fa0c7 [Llama2] Add a standalone utility for dynamic and combining IRs
-- This script adds a standalone utility for converting Llama IRs
   to dynamic and combining them as well.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-10-05 20:01:06 +05:30
Jakub Kuderski
9c4610b9da Add microbenchmark mode to vicuna CLI (#1864)
Add flags to enable a non-internactive mode for microbenchmarking llama
models. In this mode, the system and user prompts are specified with CLI
flags, and the number of generated tokens and iterations is fixed.

Also move the stats below the response and trim any response blankspace.
2023-10-05 00:12:08 -04:00
powderluv
a38cc9d216 Update vulkan_utils.py for Radeon 780m igpu (#1866) 2023-10-04 20:33:07 -07:00
Jakub Kuderski
1c382449ec [vulkan] Print note about module load times. NFC. (#1862)
Print a note ahead of a potentially long inactivity to set the right expectations.

Separately, we should add progress to the UI and make this loading faster.
2023-10-03 17:27:27 -04:00
Gaurav Shukla
7cc9b3f8e8 [llama cli] Fix llama cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-03 20:39:53 +05:30
Gaurav Shukla
e54517e967 [UI] Disable config generator, lora train and model manager (#1858)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-02 22:34:40 -07:00
Ean Garvey
326327a799 Collect pipeline submodules for diffusers ckpt preprocessing. (#1859) 2023-10-03 00:29:28 -04:00
Ean Garvey
785b65c7b0 Add flag for specifying device-local caching allocator heap key. (#1856) 2023-10-03 00:28:39 -04:00
Sungsoon Cho
0d16c81687 Remove unused import. (#1857) 2023-10-02 11:36:08 -05:00
Vivek Khandelwal
8dd7850c69 Add Falcon-GPTQ support 2023-10-02 16:39:57 +05:30
Gaurav Shukla
e930ba85b4 [os] Remove os dependency from vmfb naming (#1854)
Also fixes a small ui issue for chatbot.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 12:38:17 -05:00
Gaurav Shukla
cd732e7a38 [chatbot] split execution time to prefill and decode
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
Gaurav Shukla
8e0f8b3227 [ui] Update chatbot UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
Gaurav Shukla
b8210ef796 [chatbot] Re-instantiate the chatbot object if device id changes
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
PhaneeshB
94594542a9 remove use of vulkaninfo 2023-09-28 21:57:00 +05:30
Gaurav Shukla
82f833e87d [vulkan] Update vmfb naming
Update vmfb naming for vulkan devices in order to resolve naming
conflicts in the presence of multiple vulkan devices.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-28 14:52:11 +05:30
Vivek Khandelwal
c9d6870105 Modify falcon pipeline for 180b support 2023-09-28 12:39:35 +05:30
Jakub Kuderski
4fec03a6cc [vulkan] Switch from coop matrix NV to KHR (#1848) 2023-09-27 21:43:37 -04:00
harsh-nod
9a27f51378 Deprecate inference directory
This patch removes the inference directory that was no longer being used.
2023-09-27 14:29:00 -07:00
Abhishek Varma
ad1a0f35ff Fix misdirection while saving vmfb
-- Currently SHARK suggests that vmfb has been saved, while
    that is not the case and no vmfb is generated. 
    This creates a misdirection for IR/vmfbs which are of larger
    size.
-- This commit therefore fixes that misdirection.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-27 16:25:29 +05:30
Nelson Sharpe
6773278ec2 Fix checkpoint_path unexpected argument (#1832) 2023-09-24 14:17:52 -07:00
Abhishek Varma
9a0efffcca [Llama2] Fix wrong Vulkan device ID + Add Vulkan compile flags
-- This commit fixes the wrong Vulkan device being selected during
   runtime.
-- It also adds couple of IREE compilation flags to target specific
   Vulkan device.
-- It also changes the Vulkan device listing to be more in tune with
   lowering control flow.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-22 22:24:18 +05:30
gpetters94
61c6f153d9 Switch to keras-nightly to fix a Linux issue (#1835) 2023-09-21 12:33:45 -04:00
Phaneesh Barwaria
effd42e8f5 pin gradio to v3.44.3 2023-09-21 17:33:43 +05:30
Sungsoon Cho
b5fbb1a8a0 Rename the func arg save_json to avoid name collision. (#1837)
* Rename the func arg save_json to avoid name collision.

* black formatted.
2023-09-19 17:29:27 -05:00
Quinn Dawkins
ded74d09cd [vicuna.py] Keep past key values on device (#1836)
The past key values are only used within the models themselves and can
be kept on device. For vulkan int4, this gives 44 tok/s (for the first
prompt) and settles at around 26 tok/s on 7900xtx.
2023-09-19 18:17:41 -04:00
Boian Petkantchin
79267931c1 Add argument --additional_compile_args (#1119)
This allows to pass more arguemnts to the IREE compiler
Example:
python my-app.py --additional_compile_args="--mlir-pretty-debuginfo --mlir-timing"

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-09-19 11:26:03 -05:00
zjgarvey
9eceba69b7 local_tank_cache included into clear_all (#1833) 2023-09-18 00:27:23 -05:00
Ean Garvey
ca609afb6a Update README.md (#1830) 2023-09-14 10:33:57 -05:00
Gaurav Shukla
11bdce9790 [flags] Fix vulkan runtime flags as vma is dropped from iree (#1831) 2023-09-14 08:58:59 -05:00
Ean Garvey
684943a4a6 (SD) Fix tokenizers imports in pyinstaller builds. (#1828)
* Fix tokenizers metadata.

* (SD) Disable VAE lowering configs (rdna3) and add versioned tunings.

* Update sd_annotation.py

* (SD) Add cv2 to spec.

* Update stencil pipeline with the new img2img arg.
2023-09-12 12:23:48 -05:00
PhaneeshB
b817bb8455 add roles for llama2 2023-09-12 10:59:28 +05:30
Ean Garvey
780f520f02 Fix vk.target_env extensions and remove redundant SD imports. (#1826)
* Remove redundant IREE runtime imports.

* Fix vulkan target env extensions.
2023-09-11 13:42:52 -05:00
Dom
c61b6f8d65 Code refactoring (#1817)
* use join

* fix bug

* further code optimizations

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
2023-09-11 11:30:56 -05:00
Abhishek Varma
c854208d49 [Llama2] Prefetch llama2 tokenizer configs (#1824)
-- This commit prefetches llama2 tokenizer configs from shark_tank.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-08 11:29:54 -07:00
Gaurav Shukla
c5dcfc1f13 [vicuna] Exit when mlir is not present in shark tank (#1825)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-08 10:30:29 -07:00
Abhishek Varma
bde63ee8ae Add logging feature in WebUI (#1821) 2023-09-08 05:48:05 -07:00
Vivek Khandelwal
9681d494eb Update decomp list and shark trainer for DLRM 2023-09-06 21:24:50 +05:30
Gaurav Shukla
ede6bf83e2 [vicuna] Disabling the IR generation path
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-06 20:13:17 +05:30
Ean Garvey
2c2693fb7d Fix torchvision versioning in Linux importer setup. (#1809) 2023-09-05 12:57:03 -05:00
Vivek Khandelwal
1d31b2b2c6 Fix StableHLO Compilation flag 2023-09-05 21:32:33 +05:30
Gaurav Shukla
d2f64eefa3 [chatbot] Remove few outdated models from list (#1814) 2023-09-04 09:26:32 -07:00
Abhishek Varma
87ae14b6ff [SD] Add sdpfa decomposition + update IREE flag
-- This commit adds Scaled Dot Product Flash Attention's decomposition
   in shark_importer.
-- It also updates `iree-flow-enable-data-tiling` to `iree-opt-data-tiling`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-04 18:03:53 +05:30
Phaneesh Barwaria
1ccafa1fc1 fix llama2-70b rewrite tensor dim 2023-09-01 17:27:06 +05:30
jinchen62
4c3d8a0a7f Enable downloading vmfb/mlir for webui (#1807) 2023-08-31 11:05:47 -07:00
jinchen62
3601dc7c3b Fix llama2 13b combined ir (#1803) 2023-08-28 11:34:44 -07:00
Daniel Garvey
671881cf87 Llama2 70b (#1783)
* llama2 70b IR gen

* fix IR sec llama2 + debug

* llama270b

---------

Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
2023-08-25 23:04:28 -07:00
Gaurav Shukla
4e9be6be59 [chatbot] Add debug as class attribute (#1799)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-25 21:46:29 -07:00
Ean Garvey
9c8cbaf498 Add support for ROCM (Windows) in Studio + compile utils (#1770)
* WIP: MSVC ROCM support for SHARK Studio

* Make get_iree_rocm_args platform-agnostic.

* Update stable_args.py

* Update rocm arg handling in SD utils

* Guard quantization imports.

Co-authored-by: jam https://github.com/jammm
2023-08-25 20:56:05 -07:00
Ean Garvey
9e348a114e Revert changes process_skipfiles.py (#1798)
Keeps a small typo fix but reverts the rest of changes to this file from 450c231171
2023-08-25 15:31:49 -07:00
jinchen62
51f90a4d56 Update conversion passes for brevitas quant op (#1795) 2023-08-25 17:28:07 -05:00
Abhishek Varma
310d5d0a49 Fix llama2 13b crashing + add spec file for CLI execution of Llama (#1797)
* [Llama2] Add a fix for Llama2 13B downloading/crashing

-- This commit fixes downloading/crashing of llama2 13B on wrong
   .mlir file.
-- Also adds support for downloading vmfb from shark_tank in CLI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [llama2] Add a spec file to run Llama/Vicuna CLI exe

-- This commit adds a spec file to run Llama/Vicuna CLI exe.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-08-25 09:36:09 -05:00
Ean Garvey
9697981004 Pipe through a debug option to iree compile utils. (#1796)
* Update compile_utils.py

* Pipe through a flag to toggle debug options in compile utils.

* Update SharkLLMBase.py
2023-08-25 07:11:11 -07:00
Ean Garvey
450c231171 Add tokenizers to requirements.txt (#1790)
* Add tokenizers to requirements and pin version

* Update process_skipfiles.py
2023-08-24 19:44:04 -05:00
Ean Garvey
07f6f4a2f7 Add a short README for the OPT examples and small tweaks. (#1793)
* Small changes to OPT example.

* Update opt README.

* Add a few modes to batch script.

* Update README.md
2023-08-24 17:26:11 -07:00
jinchen62
610813c72f Add iree flag to strip assertions (#1791) 2023-08-24 10:51:19 -07:00
Ean Garvey
8e3860c9e6 Remove flags that are default in upstream IREE (#1785)
* Remove index bits flags now set by default

* Update shark_studio_imports.py
2023-08-24 11:57:54 -05:00
xzuyn
e37d6720eb Add Hires Fix (#1787)
* improper test hiresfix

* add sliders & use `clear_cache`

* add resample choices & fix step adjustment

* add step adjustment to img2img

* add resample options to img2img

* simplify hiresfix
- import `img2img_inf` from `img2img_ui.py` instead of just copying it into `txt2img_ui.py`

* set `hri` to None after using

* add more resample types, and don't show output until hiresfix is done

* cleaner implementation

* ran black

* ran black again with jupyter dependencies
2023-08-24 09:01:41 -07:00
Vivek Khandelwal
16160d9a7d Fix combine mlir script 2023-08-24 19:10:49 +05:30
Sungsoon Cho
79075a1a07 Opt perf (#1786)
* Define command line args, model-name, max-seq-len, platform, etc.

* Add usage example.

* Add opt_perf_comparision_batch.py.

* Use shlex instead.
2023-08-24 08:33:12 -05:00
Abhishek Varma
db990826d3 Add Llama2 13B int4 fp16 support (#1784)
Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-08-23 10:00:32 -07:00
gpetters94
7ee3e4ba5d Add stencil_unet_512 support (#1778)
This should fix any remaining issues with stencils and long prompts.
2023-08-22 12:23:46 -04:00
Vivek Khandelwal
05889a8fe1 Add LLaMa2-int4-fp16 support (#1782) 2023-08-22 07:45:50 -07:00
jinchen62
b87efe7686 Fix venv setup for brevitas (#1779) 2023-08-21 11:58:51 -07:00
gpetters94
82b462de3a Fix stencils for long prompts (#1777) 2023-08-19 00:26:51 -07:00
Daniel Garvey
d8f0f7bade replace public with private (#1776)
unload footguns
2023-08-18 14:22:46 -07:00
gpetters94
79bd0b84a1 Fix an issue with diffusers>0.19.3 (#1775) 2023-08-18 14:06:06 -04:00
jinchen62
8738571d1e Adapt the change of brevitas custom op name (#1772) 2023-08-17 14:24:43 -07:00
Gaurav Shukla
a4c354ce54 [version] Pin diffusers==0.19.3
Once the latest works with LORA train, unpin it.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
cc53efa89f [cli] Fix chatbot cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
9ae8bc921e [chatbot] Fix chatbot cli and webview warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
32eb78f0f9 [chatbot] Fix switching parameters in chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 19:14:17 +05:30
Ean Garvey
cb509343d9 Fix pytest benchmarks and shark_tank generation. (#1632)
- fix setup_venv.sh for benchmarks/imports etc.
- fix torch benchmarks in SharkBenchmarkRunner
- generate SD artifacts using build_tools/stable_diffusion_testing.py and --import_mlir
- decouple SD gen from tank/generate_sharktank for now
2023-08-16 17:48:47 -05:00
powderluv
6da391c9b1 update signtool to use /fd certHash 2023-08-15 15:11:40 -07:00
Ean Garvey
9dee7ae652 fix tkinter window (#1766) 2023-08-15 13:23:09 -07:00
Ean Garvey
343dfd901c Update SHARK-Runtime links to SRT (#1765)
* Update nightly.yml

* Update setup_venv.ps1

* Update CMakeLists.txt

* Update shark_iree_profiling.md

* Update setup_venv.sh

* Update README.md

* Update .gitmodules

* Update CMakeLists.txt

* Update README.md

* fix signtool flags

* Update nightly.yml

* Update benchmark_utils.py

* uncomment tkinter launch
2023-08-15 12:40:44 -07:00
Ean Garvey
57260b9c37 (Studio) Add hf-hub to pyinstaller metadata (#1761) 2023-08-14 23:01:50 -05:00
Ean Garvey
18e7d2d061 Enable vae tunings for rdna3. (#1764) 2023-08-14 21:00:14 -07:00
Stanley Winata
51a1009796 Add Forward method to SHARKRunner and fix examples. (#1756) 2023-08-14 19:20:37 -07:00
Daniel Garvey
045c3c3852 enable iree-opt-const-expr-hoisting in vicuna (#1742)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-08-14 18:43:42 -07:00
Ean Garvey
0139dd58d9 Specify max allocation size in IREE compile args. (#1760) 2023-08-14 15:43:09 -05:00
Ean Garvey
c96571855a prevents recompiles for cuda benchmarks + update benchmark_module path (#1759)
* xfail resnet50_fp16

* Fix cuda benchmarks and prevent recompilation.
2023-08-14 15:30:32 -05:00
PhaneeshB
4f61d69d86 add support passing iree flags for LLMs 2023-08-15 00:22:56 +05:30
Phaneesh Barwaria
531d447768 set default allocator for metal device creation (#1755) 2023-08-14 06:17:52 -07:00
Vivek Khandelwal
16f46f8de9 Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
c4723f469f Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
d804f45a61 Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
d22177f936 Update requirements.txt 2023-08-14 14:32:19 +05:30
George Petterson
75e68f02f4 Remove CUDNN 2023-08-14 14:32:19 +05:30
Gaurav Shukla
4dc9c59611 [chatbot] Add tokens generated per second (#1753) 2023-08-13 11:25:41 -07:00
Gaurav Shukla
18801dcabc [chat] Update chatbot ui
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-13 18:39:22 +05:30
Gaurav Shukla
3c577f7168 [vicuna] fix shard config generator script (#1747)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-10 11:26:03 -07:00
Stefan Kapusniak
f5e4fa6ffe UI/Web - Revert tab order (#1724)
* Revert ui tab order

* Reverts the tab order, so that SD, LLM, and Experimental are grouped
together again as far as is possible.
* Labelled "Generate Sharding Config" as experimental as pressing the
'Get Model Config' errors for me.

* Fix formatting in index.py
2023-08-10 11:25:36 -07:00
powderluv
48de445325 Enable caching and disable vma (#1746)
* Enable caching allocator by default

Going to toggle VMA off too and this is required for performance.  Will have to monitor in the wild reports.

* Disable VMA

Disable VMA
2023-08-10 10:49:44 -07:00
Gaurav Shukla
8e90f1b81a [vicuna] add default config in case of sharded vicuna
Signed-Off-by: Gaurav Shukla<gaurav@nod-labs.com>
2023-08-10 21:28:08 +05:30
Vivek Khandelwal
e8c1203be2 Fix vicuna script (#1745) 2023-08-10 06:11:14 -07:00
Vivek Khandelwal
e4d7abb519 Final patch for fixing Langchain token streaming issue (#1744) 2023-08-09 10:09:41 -07:00
powderluv
96185c9dc1 pin safetensors to 0.3.1 (#1740) 2023-08-08 19:24:44 -07:00
powderluv
bc22a81925 re-enable constant folding (#1739)
Tested and works well. (modulo unrelated driver issue)
2023-08-08 17:17:38 -07:00
Eliasj42
5203679f1f Bandaid fix 2 (#1728)
* download all mlirs

* fixed install method

* download all mlirs (#1727)

Co-authored-by: Elias Joseph <elias@nod-labs.com>

* added taggs

* fix name check for file existence

* Remove SD from all_models.csv (#1706)

Removes SD from pytests as it has its own test suite.

* gpt_langchain.py fixes for pydantic (#1722)

* removed dead code

---------

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>
2023-08-08 12:14:57 -05:00
Vivek Khandelwal
bf073f8f37 [Langchain] Expand pipelines to fix token streaming issue 2023-08-08 10:27:23 +05:30
Stella Laurenzo
cec6eda6b4 Optimize device enumeration overhead and log details on long operations. (#1734)
* Optimize device enumeration overhead and log details on long operations.

* Various fixes to add `@functools.cache` to what should be one time, expensive, device enumeration and setup activities. Cuts several seconds off of initialization on my machine.
* Add detailed tracing to actual invocations if they exceed a certain timeout or have an exception.
* Add detailed tracing to loading status.
* By default detail logging is only printed if an operation takes an excessive amount of time. All logging/timing can be printed by setting the variable `$env:SHARK_DETAIL_TRACE = "1"`

* Remove cache from unhashable functions
2023-08-07 17:20:53 -07:00
Stella Laurenzo
9e37e03741 Clearly differentiate phases of loading modules to better understand if things are taking a long time. (#1733) 2023-08-07 14:03:12 -07:00
Stefan Kapusniak
9b8c4401b5 gpt_langchain.py fixes for pydantic (#1722) 2023-08-07 00:55:38 -07:00
Ean Garvey
a9f95a218b Remove SD from all_models.csv (#1706)
Removes SD from pytests as it has its own test suite.
2023-08-05 15:55:52 -05:00
PhaneeshB
872bd72d0b fix name check for file existence 2023-08-05 21:33:53 +05:30
Eliasj42
fd1c4db5d0 download all mlirs (#1727)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-08-04 18:22:06 -05:00
Daniel Garvey
759664bb48 add py files to pyinstaller for shark (#1723) 2023-08-04 14:10:43 -07:00
Daniel Garvey
14fd0cdd87 add missing subprocess import (#1721) 2023-08-04 15:15:22 -05:00
Daniel Garvey
a57eccc997 fix lint (#1720) 2023-08-04 14:54:33 -05:00
Daniel Garvey
a686d7d89f temporarily disable langchain stuff in webui (#1719)
its breaking the exe
2023-08-04 12:48:06 -07:00
Eliasj42
ed484b8253 added functionality for int8 vicuna and 4 shards (#1712)
combined vicuna_4_shards.py and vicuna.py to reduce code duplication

Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-08-04 14:05:05 -05:00
gpetters94
7fe57ebaaf Add vector database and add support on the web UI (#1699) 2023-08-04 13:47:19 -04:00
Nithin Meganathan
c287fd2be8 Add GPU ID's in model_confg.json by default for manual annotation (#1718) 2023-08-04 12:46:27 -05:00
Gaurav Shukla
51ec1a1360 [vicuna] Integrate sharded vicuna in web (#1717)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-04 11:46:53 -05:00
Gaurav Shukla
bd30044c0b [Shard] Add sharding generation in shark studio
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-04 21:51:14 +05:30
Ean Garvey
c9de2729b2 Add flag for toggling constant folding. (#1714) 2023-08-04 04:55:52 -07:00
Vivek Khandelwal
a5b13fcc2f [Langchain] Patch for fixing streaming of tokens (#1709) 2023-08-03 10:06:49 -07:00
Stefan Kapusniak
6bb329c4af Unsharded Vicuna: Fix Memory Error compiling mlir for lmsys/vicuna-7b-v1.3 fp16 with 64 GiB (#1702) 2023-08-01 06:07:56 -07:00
100 changed files with 7328 additions and 4545 deletions

View File

@@ -51,11 +51,11 @@ jobs:
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets
@@ -104,7 +104,7 @@ jobs:
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -144,7 +144,7 @@ jobs:
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models

6
.gitignore vendored
View File

@@ -193,3 +193,9 @@ stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData
# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/

2
.gitmodules vendored
View File

@@ -1,4 +1,4 @@
[submodule "inference/thirdparty/shark-runtime"]
path = inference/thirdparty/shark-runtime
url =https://github.com/nod-ai/SHARK-Runtime.git
url =https://github.com/nod-ai/SRT.git
branch = shark-06032022

View File

@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
@@ -170,7 +170,7 @@ python -m pip install --upgrade pip
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
### Run shark tank model tests.

View File

@@ -1,4 +1,3 @@
"""Load question answering chains."""
from __future__ import annotations
from typing import (
Any,
@@ -11,23 +10,34 @@ from typing import (
Union,
Protocol,
)
import inspect
import json
import warnings
from pathlib import Path
import yaml
from abc import ABC, abstractmethod
import langchain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.question_answering import stuff_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.docstore.document import Document
from abc import ABC, abstractmethod
from langchain.chains.base import Chain
from langchain.callbacks.manager import (
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
from pydantic import Extra, Field, root_validator
from pydantic import Extra, Field, root_validator, validator
def _get_verbosity() -> bool:
return langchain.verbose
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
@@ -48,6 +58,413 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
return prompt.format(**document_info)
class Chain(Serializable, ABC):
"""Base interface that all chains should implement."""
memory: Optional[BaseMemory] = None
callbacks: Callbacks = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(
default=None, exclude=True
)
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
tags: Optional[List[str]] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Input keys this chain expects."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
"""Output keys this chain expects."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@abstractmethod
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
input_docs = inputs["input_documents"]
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
if "is_first" in inputs.keys() and not inputs["is_first"]:
run_manager_ = run_manager
input_list = [inputs]
stop = None
prompts = []
for inputs in input_list:
selected_inputs = {
k: inputs[k] for k in self.prompt.input_variables
}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager_:
run_manager_.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
prompt_strings = [p.to_string() for p in prompts]
prompts = prompt_strings
callbacks = run_manager_.get_child() if run_manager_ else None
tags = None
"""Run the LLM on the given prompt and input."""
# If string is passed in directly no errors will be raised but outputs will
# not make sense.
if not isinstance(prompts, list):
raise ValueError(
"Argument 'prompts' is expected to be of type List[str], received"
f" argument of type {type(prompts)}."
)
params = self.llm.dict()
params["stop"] = stop
options = {"stop": stop}
disregard_cache = self.llm.cache is not None and not self.llm.cache
callback_manager = CallbackManager.configure(
callbacks,
self.llm.callbacks,
self.llm.verbose,
tags,
self.llm.tags,
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.llm.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager_ = callback_manager.on_llm_start(
dumpd(self),
prompts,
invocation_params=params,
options=options,
)
generations = []
for prompt in prompts:
inputs_ = prompt
num_workers = None
batch_size = None
if num_workers is None:
if self.llm.pipeline._num_workers is None:
num_workers = 0
else:
num_workers = self.llm.pipeline._num_workers
if batch_size is None:
if self.llm.pipeline._batch_size is None:
batch_size = 1
else:
batch_size = self.llm.pipeline._batch_size
preprocess_params = {}
generate_kwargs = {}
preprocess_params.update(generate_kwargs)
forward_params = generate_kwargs
postprocess_params = {}
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
preprocess_params = {
**self.llm.pipeline._preprocess_params,
**preprocess_params,
}
forward_params = {
**self.llm.pipeline._forward_params,
**forward_params,
}
postprocess_params = {
**self.llm.pipeline._postprocess_params,
**postprocess_params,
}
self.llm.pipeline.call_count += 1
if (
self.llm.pipeline.call_count > 10
and self.llm.pipeline.framework == "pt"
and self.llm.pipeline.device.type == "cuda"
):
warnings.warn(
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
" dataset",
UserWarning,
)
model_inputs = self.llm.pipeline.preprocess(
inputs_, **preprocess_params
)
model_outputs = self.llm.pipeline.forward(
model_inputs, **forward_params
)
model_outputs["process"] = False
return model_outputs
output = LLMResult(generations=generations)
run_manager_.on_llm_end(output)
if run_manager_:
output.run = RunInfo(run_id=run_manager_.run_id)
response = output
outputs = [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
][0]
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
else:
_run_manager = (
run_manager or CallbackManagerForChainRun.get_noop_manager()
)
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {
k: v for k, v in inputs.items() if k != self.input_key
}
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in other_keys.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
inputs["is_first"] = False
inputs["input_documents"] = input_docs
# Call predict on the LLM.
output = self.llm_chain(inputs, callbacks=_run_manager.get_child())
if "process" in output.keys() and not output["process"]:
return output
output = output[self.llm_chain.output_key]
extra_return_dict = {}
extra_return_dict[self.output_key] = output
outputs = extra_return_dict
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(
self, inputs: Union[Dict[str, Any], Any]
) -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(
self.memory.memory_variables
)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
def run(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
if args and not kwargs:
if len(args) != 1:
raise ValueError(
"`run` supports only one positional argument."
)
return self(args[0], callbacks=callbacks, tags=tags)[
self.output_keys[0]
]
if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags)[
self.output_keys[0]
]
if not kwargs and not args:
raise ValueError(
"`run` supported with either positional arguments or keyword arguments,"
" but none were provided."
)
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of chain."""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict()
_dict["_type"] = self._chain_type
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the chain.
Args:
file_path: Path to file to save the chain to.
Example:
.. code-block:: python
chain.save(file_path="path/chain.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
chain_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(chain_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents."""
@@ -79,12 +496,6 @@ class BaseCombineDocumentsChain(Chain, ABC):
"""
return None
@abstractmethod
def combine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents into a single string."""
def _call(
self,
inputs: Dict[str, List[Document]],
@@ -96,13 +507,49 @@ class BaseCombineDocumentsChain(Chain, ABC):
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
output, extra_return_dict = self.combine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in other_keys.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
# Call predict on the LLM.
output, extra_return_dict = (
self.llm_chain(inputs, callbacks=_run_manager.get_child())[
self.llm_chain.output_key
],
{},
)
extra_return_dict[self.output_key] = output
return extra_return_dict
from pydantic import BaseModel
class Generation(Serializable):
"""Output of a single generation."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw generation info response from the provider"""
"""May include things like reason for finishing (e.g. in OpenAI)"""
# TODO: add log probs
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
class LLMChain(Chain):
"""Chain to run queries against LLMs.
@@ -153,21 +600,13 @@ class LLMChain(Chain):
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts, stop = self.prep_prompts([inputs], run_manager=run_manager)
response = self.llm.generate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
)
return self.create_outputs(response)[0]
def prep_prompts(
self,
@@ -223,23 +662,6 @@ class LLMChain(Chain):
for generation in response.generations
]
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return self(kwargs, callbacks=callbacks)[self.output_key]
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
@@ -350,14 +772,6 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt)
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
@property
def _chain_type(self) -> str:
return "stuff_documents_chain"

View File

@@ -1129,7 +1129,7 @@ class Langchain:
max_time=max_time,
num_return_sequences=num_return_sequences,
)
for r in run_qa_db(
out = run_qa_db(
query=instruction,
iinput=iinput,
context=context,
@@ -1170,689 +1170,8 @@ class Langchain:
auto_reduce_chunks=auto_reduce_chunks,
max_chunks=max_chunks,
device=self.device,
):
(
outr,
extra,
) = r # doesn't accumulate, new answer every yield, so only save that full answer
yield dict(response=outr, sources=extra)
if save_dir:
extra_dict = gen_hyper_langchain.copy()
extra_dict.update(
prompt_type=prompt_type,
inference_server=inference_server,
langchain_mode=langchain_mode,
langchain_action=langchain_action,
document_choice=document_choice,
num_prompt_tokens=num_prompt_tokens,
instruction=instruction,
iinput=iinput,
context=context,
)
save_generate_output(
prompt=prompt,
output=outr,
base_model=base_model,
save_dir=save_dir,
where_from="run_qa_db",
extra_dict=extra_dict,
)
if verbose:
print(
"Post-Generate Langchain: %s decoded_output: %s"
% (str(datetime.now()), len(outr) if outr else -1),
flush=True,
)
if outr or base_model in non_hf_types:
# if got no response (e.g. not showing sources and got no sources,
# so nothing to give to LLM), then slip through and ask LLM
# Or if llama/gptj, then just return since they had no response and can't go down below code path
# clear before return, since .then() never done if from API
clear_torch_cache()
return
if inference_server.startswith(
"openai"
) or inference_server.startswith("http"):
if inference_server.startswith("openai"):
import openai
where_from = "openai_client"
openai.api_key = os.getenv("OPENAI_API_KEY")
stop_sequences = list(
set(prompter.terminate_response + [prompter.PreResponse])
)
stop_sequences = [x for x in stop_sequences if x]
# OpenAI will complain if ask for too many new tokens, takes it as min in some sense, wrongly so.
max_new_tokens_openai = min(
max_new_tokens, model_max_length - num_prompt_tokens
)
gen_server_kwargs = dict(
temperature=temperature if do_sample else 0,
max_tokens=max_new_tokens_openai,
top_p=top_p if do_sample else 1,
frequency_penalty=0,
n=num_return_sequences,
presence_penalty=1.07
- repetition_penalty
+ 0.6, # so good default
)
if inference_server == "openai":
response = openai.Completion.create(
model=base_model,
prompt=prompt,
**gen_server_kwargs,
stop=stop_sequences,
stream=stream_output,
)
if not stream_output:
text = response["choices"][0]["text"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
collected_events = []
text = ""
for event in response:
collected_events.append(
event
) # save the event response
event_text = event["choices"][0][
"text"
] # extract the text
text += event_text # append the text
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
elif inference_server == "openai_chat":
response = openai.ChatCompletion.create(
model=base_model,
messages=[
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": prompt,
},
],
stream=stream_output,
**gen_server_kwargs,
)
if not stream_output:
text = response["choices"][0]["message"]["content"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
text = ""
for chunk in response:
delta = chunk["choices"][0]["delta"]
if "content" in delta:
text += delta["content"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
raise RuntimeError(
"No such OpenAI mode: %s" % inference_server
)
elif inference_server.startswith("http"):
inference_server, headers = get_hf_server(inference_server)
from gradio_utils.grclient import GradioClient
from text_generation import Client as HFClient
if isinstance(model, GradioClient):
gr_client = model
hf_client = None
elif isinstance(model, HFClient):
gr_client = None
hf_client = model
else:
(
inference_server,
gr_client,
hf_client,
) = self.get_client_from_inference_server(
inference_server, base_model=base_model
)
# quick sanity check to avoid long timeouts, just see if can reach server
requests.get(
inference_server,
timeout=int(os.getenv("REQUEST_TIMEOUT_FAST", "10")),
)
if gr_client is not None:
# Note: h2oGPT gradio server could handle input token size issues for prompt,
# but best to handle here so send less data to server
chat_client = False
where_from = "gr_client"
client_langchain_mode = "Disabled"
client_langchain_action = LangChainAction.QUERY.value
gen_server_kwargs = dict(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
early_stopping=early_stopping,
max_time=max_time,
repetition_penalty=repetition_penalty,
num_return_sequences=num_return_sequences,
do_sample=do_sample,
chat=chat_client,
)
# account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection
if prompt_type in [
None,
"",
PromptType.plain.name,
PromptType.plain.value,
str(PromptType.plain.value),
]:
# if our prompt is plain, assume either correct or gradio server knows different prompt type,
# so pass empty prompt_Type
gr_prompt_type = ""
gr_prompt_dict = ""
gr_prompt = prompt # already prepared prompt
gr_context = ""
gr_iinput = ""
else:
# if already have prompt_type that is not plain, None, or '', then already applied some prompting
# But assume server can handle prompting, and need to avoid double-up.
# Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle
# So avoid "prompt" and let gradio server reconstruct from prompt_type we passed
# Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed,
# because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter
# since those won't appear
gr_context = context
gr_prompt = instruction
gr_iinput = iinput
gr_prompt_type = prompt_type
gr_prompt_dict = prompt_dict
client_kwargs = dict(
instruction=gr_prompt
if chat_client
else "", # only for chat=True
iinput=gr_iinput, # only for chat=True
context=gr_context,
# streaming output is supported, loops over and outputs each generation in streaming mode
# but leave stream_output=False for simple input/output mode
stream_output=stream_output,
**gen_server_kwargs,
prompt_type=gr_prompt_type,
prompt_dict=gr_prompt_dict,
instruction_nochat=gr_prompt
if not chat_client
else "",
iinput_nochat=gr_iinput, # only for chat=False
langchain_mode=client_langchain_mode,
langchain_action=client_langchain_action,
top_k_docs=top_k_docs,
chunk=chunk,
chunk_size=chunk_size,
document_choice=[DocumentChoices.All_Relevant.name],
)
api_name = "/submit_nochat_api" # NOTE: like submit_nochat but stable API for string dict passing
if not stream_output:
res = gr_client.predict(
str(dict(client_kwargs)), api_name=api_name
)
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
else:
job = gr_client.submit(
str(dict(client_kwargs)), api_name=api_name
)
text = ""
sources = ""
res_dict = dict(response=text, sources=sources)
while not job.done():
outputs_list = job.communicator.job.outputs
if outputs_list:
res = job.communicator.job.outputs[-1]
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
if gr_prompt_type == "plain":
# then gradio server passes back full prompt + text
prompt_and_text = text
else:
prompt_and_text = prompt + text
yield dict(
response=prompter.get_response(
prompt_and_text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
time.sleep(0.01)
# ensure get last output to avoid race
res_all = job.outputs()
if len(res_all) > 0:
res = res_all[-1]
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
else:
# go with old text if last call didn't work
e = job.future._exception
if e is not None:
stre = str(e)
strex = "".join(
traceback.format_tb(e.__traceback__)
)
else:
stre = ""
strex = ""
print(
"Bad final response: %s %s %s %s %s: %s %s"
% (
base_model,
inference_server,
res_all,
prompt,
text,
stre,
strex,
),
flush=True,
)
if gr_prompt_type == "plain":
# then gradio server passes back full prompt + text
prompt_and_text = text
else:
prompt_and_text = prompt + text
yield dict(
response=prompter.get_response(
prompt_and_text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
elif hf_client:
# HF inference server needs control over input tokens
where_from = "hf_client"
# prompt must include all human-bot like tokens, already added by prompt
# https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types
stop_sequences = list(
set(
prompter.terminate_response
+ [prompter.PreResponse]
)
)
stop_sequences = [x for x in stop_sequences if x]
gen_server_kwargs = dict(
do_sample=do_sample,
max_new_tokens=max_new_tokens,
# best_of=None,
repetition_penalty=repetition_penalty,
return_full_text=True,
seed=SEED,
stop_sequences=stop_sequences,
temperature=temperature,
top_k=top_k,
top_p=top_p,
# truncate=False, # behaves oddly
# typical_p=top_p,
# watermark=False,
# decoder_input_details=False,
)
# work-around for timeout at constructor time, will be issue if multi-threading,
# so just do something reasonable or max_time if larger
# lower bound because client is re-used if multi-threading
hf_client.timeout = max(300, max_time)
if not stream_output:
text = hf_client.generate(
prompt, **gen_server_kwargs
).generated_text
yield dict(
response=prompter.get_response(
text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
text = ""
for response in hf_client.generate_stream(
prompt, **gen_server_kwargs
):
if not response.token.special:
# stop_sequences
text_chunk = response.token.text
text += text_chunk
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
raise RuntimeError(
"Failed to get client: %s" % inference_server
)
else:
raise RuntimeError(
"No such inference_server %s" % inference_server
)
if save_dir and text:
# save prompt + new text
extra_dict = gen_server_kwargs.copy()
extra_dict.update(
dict(
inference_server=inference_server,
num_prompt_tokens=num_prompt_tokens,
)
)
save_generate_output(
prompt=prompt,
output=text,
base_model=base_model,
save_dir=save_dir,
where_from=where_from,
extra_dict=extra_dict,
)
return
else:
assert not inference_server, (
"inferene_server=%s not supported" % inference_server
)
if isinstance(tokenizer, str):
# pipeline
if tokenizer == "summarization":
key = "summary_text"
else:
raise RuntimeError("No such task type %s" % tokenizer)
# NOTE: uses max_length only
yield dict(
response=model(prompt, max_length=max_new_tokens)[0][key],
sources="",
)
if "mbart-" in base_model.lower():
assert src_lang is not None
tokenizer.src_lang = self.languages_covered()[src_lang]
stopping_criteria = get_stopping(
prompt_type,
prompt_dict,
tokenizer,
self.device,
model_max_length=tokenizer.model_max_length,
)
print(prompt)
# exit(0)
inputs = tokenizer(prompt, return_tensors="pt")
if debug and len(inputs["input_ids"]) > 0:
print("input_ids length", len(inputs["input_ids"][0]), flush=True)
input_ids = inputs["input_ids"].to(self.device)
# CRITICAL LIMIT else will fail
max_max_tokens = tokenizer.model_max_length
max_input_tokens = max_max_tokens - min_new_tokens
# NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py
input_ids = input_ids[:, -max_input_tokens:]
# required for falcon if multiple threads or asyncio accesses to model during generation
if use_cache is None:
use_cache = False if "falcon" in base_model else True
gen_config_kwargs = dict(
temperature=float(temperature),
top_p=float(top_p),
top_k=top_k,
num_beams=num_beams,
do_sample=do_sample,
repetition_penalty=float(repetition_penalty),
num_return_sequences=num_return_sequences,
renormalize_logits=True,
remove_invalid_values=True,
use_cache=use_cache,
)
token_ids = [
"eos_token_id",
"pad_token_id",
"bos_token_id",
"cls_token_id",
"sep_token_id",
]
for token_id in token_ids:
if (
hasattr(tokenizer, token_id)
and getattr(tokenizer, token_id) is not None
):
gen_config_kwargs.update(
{token_id: getattr(tokenizer, token_id)}
)
generation_config = GenerationConfig(**gen_config_kwargs)
gen_kwargs = dict(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens, # prompt + new
min_new_tokens=min_new_tokens, # prompt + new
early_stopping=early_stopping, # False, True, "never"
max_time=max_time,
stopping_criteria=stopping_criteria,
)
if "gpt2" in base_model.lower():
gen_kwargs.update(
dict(
bos_token_id=tokenizer.bos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
)
elif "mbart-" in base_model.lower():
assert tgt_lang is not None
tgt_lang = self.languages_covered()[tgt_lang]
gen_kwargs.update(
dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])
)
else:
token_ids = ["eos_token_id", "bos_token_id", "pad_token_id"]
for token_id in token_ids:
if (
hasattr(tokenizer, token_id)
and getattr(tokenizer, token_id) is not None
):
gen_kwargs.update({token_id: getattr(tokenizer, token_id)})
decoder_kwargs = dict(
skip_special_tokens=True, clean_up_tokenization_spaces=True
)
decoder = functools.partial(tokenizer.decode, **decoder_kwargs)
decoder_raw_kwargs = dict(
skip_special_tokens=False, clean_up_tokenization_spaces=True
)
decoder_raw = functools.partial(tokenizer.decode, **decoder_raw_kwargs)
with torch.no_grad():
have_lora_weights = lora_weights not in [no_lora_str, "", None]
context_class_cast = (
NullContext
if self.device == "cpu"
or have_lora_weights
or self.device == "mps"
else torch.autocast
)
with context_class_cast(self.device):
# protection for gradio not keeping track of closed users,
# else hit bitsandbytes lack of thread safety:
# https://github.com/h2oai/h2ogpt/issues/104
# but only makes sense if concurrency_count == 1
context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
if verbose:
print("Pre-Generate: %s" % str(datetime.now()), flush=True)
decoded_output = None
with context_class("generate.lock"):
if verbose:
print("Generate: %s" % str(datetime.now()), flush=True)
# decoded tokenized prompt can deviate from prompt due to special characters
inputs_decoded = decoder(input_ids[0])
inputs_decoded_raw = decoder_raw(input_ids[0])
if inputs_decoded == prompt:
# normal
pass
elif inputs_decoded.lstrip() == prompt.lstrip():
# sometimes extra space in front, make prompt same for prompt removal
prompt = inputs_decoded
elif inputs_decoded_raw == prompt:
# some models specify special tokens that are part of normal prompt, so can't skip them
inputs_decoded = prompt = inputs_decoded_raw
decoder = decoder_raw
decoder_kwargs = decoder_raw_kwargs
elif inputs_decoded_raw.replace("<unk> ", "").replace(
"<unk>", ""
).replace("\n", " ").replace(" ", "") == prompt.replace(
"\n", " "
).replace(
" ", ""
):
inputs_decoded = prompt = inputs_decoded_raw
decoder = decoder_raw
decoder_kwargs = decoder_raw_kwargs
else:
if verbose:
print(
"WARNING: Special characters in prompt",
flush=True,
)
if stream_output:
skip_prompt = False
streamer = H2OTextIteratorStreamer(
tokenizer,
skip_prompt=skip_prompt,
block=False,
**decoder_kwargs,
)
gen_kwargs.update(dict(streamer=streamer))
target = wrapped_partial(
self.generate_with_exceptions,
model.generate,
prompt=prompt,
inputs_decoded=inputs_decoded,
raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
**gen_kwargs,
)
bucket = queue.Queue()
thread = EThread(
target=target, streamer=streamer, bucket=bucket
)
thread.start()
outputs = ""
try:
for new_text in streamer:
if bucket.qsize() > 0 or thread.exc:
thread.join()
outputs += new_text
yield dict(
response=prompter.get_response(
outputs,
prompt=inputs_decoded,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
except BaseException:
# if any exception, raise that exception if was from thread, first
if thread.exc:
raise thread.exc
raise
finally:
# clear before return, since .then() never done if from API
clear_torch_cache()
# in case no exception and didn't join with thread yet, then join
if not thread.exc:
thread.join()
# in case raise StopIteration or broke queue loop in streamer, but still have exception
if thread.exc:
raise thread.exc
decoded_output = outputs
else:
try:
outputs = model.generate(**gen_kwargs)
finally:
clear_torch_cache() # has to be here for API submit_nochat_api since.then() not called
outputs = [decoder(s) for s in outputs.sequences]
yield dict(
response=prompter.get_response(
outputs,
prompt=inputs_decoded,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
if outputs and len(outputs) >= 1:
decoded_output = prompt + outputs[0]
if save_dir and decoded_output:
extra_dict = gen_config_kwargs.copy()
extra_dict.update(
dict(num_prompt_tokens=num_prompt_tokens)
)
save_generate_output(
prompt=prompt,
output=decoded_output,
base_model=base_model,
save_dir=save_dir,
where_from="evaluate_%s" % str(stream_output),
extra_dict=gen_config_kwargs,
)
if verbose:
print(
"Post-Generate: %s decoded_output: %s"
% (
str(datetime.now()),
len(decoded_output) if decoded_output else -1,
),
flush=True,
)
return outputs[0]
return out
inputs_list_names = list(inspect.signature(evaluate).parameters)
global inputs_kwargs_list

View File

@@ -436,7 +436,7 @@ class GradioInference(LLM):
chat_client: bool = False
return_full_text: bool = True
stream: bool = False
stream_output: bool = Field(False, alias="stream")
sanitize_bot_response: bool = False
prompter: Any = None
@@ -481,7 +481,7 @@ class GradioInference(LLM):
# so server should get prompt_type or '', not plain
# This is good, so gradio server can also handle stopping.py conditions
# this is different than TGI server that uses prompter to inject prompt_type prompting
stream_output = self.stream
stream_output = self.stream_output
gr_client = self.client
client_langchain_mode = "Disabled"
client_langchain_action = LangChainAction.QUERY.value
@@ -596,7 +596,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
inference_server_url: str = ""
timeout: int = 300
headers: dict = None
stream: bool = False
stream_output: bool = Field(False, alias="stream")
sanitize_bot_response: bool = False
prompter: Any = None
tokenizer: Any = None
@@ -663,7 +663,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
# lower bound because client is re-used if multi-threading
self.client.timeout = max(300, self.timeout)
if not self.stream:
if not self.stream_output:
res = self.client.generate(
prompt,
**gen_server_kwargs,
@@ -852,7 +852,7 @@ def get_llm(
top_p=top_p,
# typical_p=top_p,
callbacks=callbacks if stream_output else None,
stream=stream_output,
stream_output=stream_output,
prompter=prompter,
tokenizer=tokenizer,
client=hf_client,
@@ -2510,8 +2510,7 @@ def _run_qa_db(
formatted_doc_chunks = "\n\n".join(
[get_url(x) + "\n\n" + x.page_content for x in docs]
)
yield formatted_doc_chunks, ""
return
return formatted_doc_chunks, ""
if not docs and langchain_action in [
LangChainAction.SUMMARIZE_MAP.value,
LangChainAction.SUMMARIZE_ALL.value,
@@ -2523,8 +2522,7 @@ def _run_qa_db(
else "No documents to summarize."
)
extra = ""
yield ret, extra
return
return ret, extra
if not docs and langchain_mode not in [
LangChainMode.DISABLED.value,
LangChainMode.CHAT_LLM.value,
@@ -2536,8 +2534,7 @@ def _run_qa_db(
else "No documents to query."
)
extra = ""
yield ret, extra
return
return ret, extra
if chain is None and model_name not in non_hf_types:
# here if no docs at all and not HF type
@@ -2557,22 +2554,7 @@ def _run_qa_db(
)
with context_class_cast(args.device):
answer = chain()
if not use_context:
ret = answer["output_text"]
extra = ""
yield ret, extra
elif answer is not None:
ret, extra = get_sources_answer(
query,
answer,
scores,
show_rank,
answer_with_sources,
verbose=verbose,
)
yield ret, extra
return
return answer
def get_similarity_chain(

View File

@@ -3,13 +3,11 @@ from apps.stable_diffusion.src.utils.utils import _compile_module
from io import BytesIO
import torch_mlir
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from stopping import get_stopping
from prompter import Prompter, PromptType
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
@@ -22,7 +20,7 @@ import gc
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from apps.stable_diffusion.src import args
# Brevitas
@@ -31,14 +29,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
def brevitasmatmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -47,30 +39,21 @@ def brevitasmatmul_rhs_group_quant〡shape(
raise ValueError("Input shapes not supported.")
def brevitasmatmul_rhs_group_quant〡dtype(
lhs_rank_dtype: Tuple[int, int],
rhs_rank_dtype: Tuple[int, int],
rhs_scale_rank_dtype: Tuple[int, int],
rhs_zero_point_rank_dtype: Tuple[int, int],
rhs_bit_width: int,
rhs_group_size: int,
) -> int:
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def brevitasmatmul_rhs_group_quant〡has_value_semantics(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics,
]
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
global_device = "cuda"
global_precision = "fp16"
@@ -246,7 +229,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
ts_graph,
[*h2ogptCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -254,7 +237,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
@@ -273,6 +256,11 @@ class H2OGPTSHARKModel(torch.nn.Module):
bytecode = bytecode_stream.getvalue()
del module
bytecode = save_mlir(
bytecode,
model_name=f"h2ogpt_{precision}",
frontend="torch",
)
return bytecode
def forward(self, input_ids, attention_mask):
@@ -285,7 +273,215 @@ class H2OGPTSHARKModel(torch.nn.Module):
return result
h2ogpt_model = H2OGPTSHARKModel()
def decode_tokens(tokenizer, res_tokens):
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
res_tokens[i] = int(res_tokens[i][0])
res_str = tokenizer.decode(res_tokens, skip_special_tokens=True)
return res_str
def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs):
del generate_kwargs["max_time"]
generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to(
device=tensor_device
)
generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to(
device=tensor_device
)
truncated_input_ids = []
stopping_criteria = generate_kwargs["stopping_criteria"]
generation_config_ = GenerationConfig.from_model_config(model.config)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
input_ids_seq_length = input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
logits_warper = model._get_logits_warper(generation_config)
(
input_ids,
model_kwargs,
) = model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=tensor_device)
if eos_token_id is not None
else None
)
pad_token_id = generation_config.pad_token_id
eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(
input_ids.shape[0],
dtype=torch.long,
device=input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
res_tokens = []
while True:
model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs
)
outputs = h2ogpt_shark_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = next_token * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
model_kwargs["past_key_values"] = None
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
truncated_input_ids.append(input_ids[:, 0])
input_ids = input_ids[:, 1:]
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 1:]
new_word = tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
res_tokens.append(next_token)
if new_word == "<0x0A>":
print("\n", end="", flush=True)
else:
print(f"{new_word}", end=" ", flush=True)
part_str = decode_tokens(tokenizer, res_tokens)
yield part_str
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_token.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if unfinished_sequences.max() == 0 or stopping_criteria(
input_ids, scores
):
break
timesRan = timesRan + 1
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
torch.cuda.empty_cache()
gc.collect()
res_str = decode_tokens(tokenizer, res_tokens)
yield res_str
def pad_or_truncate_inputs(
@@ -498,233 +694,6 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
)
return records
def generate_new_token(self):
model_inputs = self.model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = h2ogpt_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.truncated_input_ids.append(self.input_ids[:, 0])
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
def generate_token(self, **generate_kwargs):
del generate_kwargs["max_time"]
self.truncated_input_ids = []
generation_config_ = GenerationConfig.from_model_config(
self.model.config
)
generation_config = copy.deepcopy(generation_config_)
self.model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
self.stopping_criteria = (
self.stopping_criteria
if self.stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
self.model_kwargs,
) = self.model._prepare_model_inputs(
None, generation_config.bos_token_id, self.model_kwargs
)
batch_size = inputs_tensor.shape[0]
self.model_kwargs[
"output_attentions"
] = generation_config.output_attentions
self.model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
self.model_kwargs["use_cache"] = generation_config.use_cache
self.input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else self.model_kwargs.pop("input_ids")
)
input_ids_seq_length = self.input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
self.logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=self.stopping_criteria,
)
self.logits_warper = self.model._get_logits_warper(generation_config)
(
self.input_ids,
self.model_kwargs,
) = self.model._expand_inputs_for_generation(
input_ids=self.input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.model.config.is_encoder_decoder, # False
**self.model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=tensor_device)
if eos_token_id is not None
else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
self.input_ids.shape[0],
dtype=torch.long,
device=self.input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
while True:
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
print(f"{new_word}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(self.input_ids, self.scores)
):
break
timesRan = timesRan + 1
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
self.input_ids = torch.cat(
[
torch.tensor(self.truncated_input_ids)
.to(device=tensor_device)
.unsqueeze(dim=0),
self.input_ids,
],
dim=-1,
)
torch.cuda.empty_cache()
gc.collect()
return self.input_ids
def _forward(self, model_inputs, **generate_kwargs):
if self.can_stop:
stopping_criteria = get_stopping(
@@ -784,19 +753,13 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
input_ids, attention_mask = pad_or_truncate_inputs(
input_ids, attention_mask, max_padding_length=max_padding_length
)
self.stopping_criteria = generate_kwargs["stopping_criteria"]
generated_sequence = self.generate_token(
input_ids=input_ids,
attention_mask=attention_mask,
**generate_kwargs,
)
out_b = generated_sequence.shape[0]
generated_sequence = generated_sequence.reshape(
in_b, out_b // in_b, *generated_sequence.shape[1:]
)
return {
"generated_sequence": generated_sequence,
return_dict = {
"model": self.model,
"tokenizer": self.tokenizer,
"input_ids": input_ids,
"prompt_text": prompt_text,
"attention_mask": attention_mask,
"attention_mask": attention_mask,
}
return_dict = {**return_dict, **generate_kwargs}
return return_dict

View File

@@ -1,5 +1,4 @@
import os
import fire
from gpt_langchain import (
path_to_docs,
@@ -202,7 +201,3 @@ def make_db_main(
if verbose:
print("DONE", flush=True)
return db, collection_name
if __name__ == "__main__":
fire.Fire(make_db_main)

View File

@@ -0,0 +1,442 @@
from pathlib import Path
import argparse
from argparse import RawTextHelpFormatter
import re, gc
"""
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
Following are the various ways this script can be used :-
a. To convert a single Linalg IR to dynamic IR:
--dynamic --first_ir_path=<PATH TO FIRST IR>
b. To convert two Linalg IRs to dynamic IR:
--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>
c. To combine two Linalg IRs into one:
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
d. To convert both IRs into dynamic as well as combine the IRs:
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
NOTE: For dynamic you'll also need to provide the following set of flags:-
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
--precision (DEFAULT: 'int4')
You may use --save_dynamic to also save the dynamic IR in option d above.
Else for option a. and b. the dynamic IR(s) will get saved by default.
"""
def combine_mlir_scripts(
first_vicuna_mlir,
second_vicuna_mlir,
output_name,
return_ir=True,
):
print(f"[DEBUG] combining first and second mlir")
print(f"[DEBUG] output_name = {output_name}")
maps1 = []
maps2 = []
constants = set()
f1 = []
f2 = []
print(f"[DEBUG] processing first vicuna mlir")
first_vicuna_mlir = first_vicuna_mlir.splitlines()
while first_vicuna_mlir:
line = first_vicuna_mlir.pop(0)
if re.search("#map\d*\s*=", line):
maps1.append(line)
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "first_vicuna_forward", line)
f1.append(line)
f1 = f1[:-1]
del first_vicuna_mlir
gc.collect()
for i, map_line in enumerate(maps1):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
maps1[i] = map_line
f1 = [
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
for func_line in f1
]
print(f"[DEBUG] processing second vicuna mlir")
second_vicuna_mlir = second_vicuna_mlir.splitlines()
while second_vicuna_mlir:
line = second_vicuna_mlir.pop(0)
if re.search("#map\d*\s*=", line):
maps2.append(line)
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
f2 = f2[:-1]
del second_vicuna_mlir
gc.collect()
for i, map_line in enumerate(maps2):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
maps2[i] = map_line
f2 = [
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
for func_line in f2
]
module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
module_end = "}"
global_vars = []
vnames = []
global_var_loading1 = []
global_var_loading2 = []
print(f"[DEBUG] processing constants")
counter = 0
constants = list(constants)
while constants:
constant = constants.pop(0)
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
vbody = re.sub("arith.constant", "", vbody)
vbody = vbody.strip()
if len(vbody.split(":")) < 2:
print(constant)
vdtype = vbody.split(":")[-1].strip()
fixed_vdtype = vdtype
if "c1_i64" in vname:
print(constant)
counter += 1
if counter == 2:
counter = 0
print("detected duplicate")
continue
vnames.append(vname)
if "true" not in vname:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
else:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : i1"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
new_f1, new_f2 = [], []
print(f"[DEBUG] processing f1")
for line in f1:
if "func.func" in line:
new_f1.append(line)
for global_var in global_var_loading1:
new_f1.append(global_var)
else:
new_f1.append(line)
print(f"[DEBUG] processing f2")
for line in f2:
if "func.func" in line:
new_f2.append(line)
for global_var in global_var_loading2:
if (
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
in global_var
):
print(global_var)
new_f2.append(global_var)
else:
new_f2.append(line)
f1 = new_f1
f2 = new_f2
del new_f1
del new_f2
gc.collect()
print(
[
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
for x in [maps1, maps2, global_vars, f1, f2]
]
)
# doing it this way rather than assembling the whole string
# to prevent OOM with 64GiB RAM when encoding the file.
print(f"[DEBUG] Saving mlir to {output_name}")
with open(output_name, "w+") as f_:
f_.writelines(line + "\n" for line in maps1)
f_.writelines(line + "\n" for line in maps2)
f_.writelines(line + "\n" for line in [module_start])
f_.writelines(line + "\n" for line in global_vars)
f_.writelines(line + "\n" for line in f1)
f_.writelines(line + "\n" for line in f2)
f_.writelines(line + "\n" for line in [module_end])
del maps1
del maps2
del module_start
del global_vars
del f1
del f2
del module_end
gc.collect()
if return_ir:
print(f"[DEBUG] Reading combined mlir back in")
with open(output_name, "rb") as f:
return f.read()
def write_in_dynamic_inputs0(module, dynamic_input_size):
print("[DEBUG] writing dynamic inputs to first vicuna")
# Current solution for ensuring mlir files support dynamic inputs
# TODO: find a more elegant way to implement this
new_lines = []
module = module.splitlines()
while module:
line = module.pop(0)
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
continue
new_lines.append(line)
return "\n".join(new_lines)
def write_in_dynamic_inputs1(module, model_name, precision):
print("[DEBUG] writing dynamic inputs to second vicuna")
def remove_constant_dim(line):
if "c19_i64" in line:
line = re.sub("c19_i64", "dim_i64", line)
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "x20x" in line or "<20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
if " 20," in line:
line = re.sub(" 20,", " %dimp1,", line)
return line
module = module.splitlines()
new_lines = []
# Using a while loop and the pop method to avoid creating a copy of module
if "llama2_13b" in model_name:
pkv_tensor_shape = "tensor<1x40x?x128x"
elif "llama2_70b" in model_name:
pkv_tensor_shape = "tensor<1x8x?x128x"
else:
pkv_tensor_shape = "tensor<1x32x?x128x"
if precision in ["fp16", "int4", "int8"]:
pkv_tensor_shape += "f16>"
else:
pkv_tensor_shape += "f32>"
while module:
line = module.pop(0)
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
)
continue
if "%c2 = arith.constant 2 : index" in line:
continue
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
)
continue
line = remove_constant_dim(line)
new_lines.append(line)
return "\n".join(new_lines)
def save_dynamic_ir(ir_to_save, output_file):
if not ir_to_save:
return
# We only get string output from the dynamic conversion utility.
from contextlib import redirect_stdout
with open(output_file, "w") as f:
with redirect_stdout(f):
print(ir_to_save)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="llama ir utility",
description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
+ "\tFollowing are the various ways this script can be used :-\n"
+ "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
+ "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
+ "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
+ "\t\t\t--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>\n"
+ "\t\tc. To combine two Linalg IRs into one:\n"
+ "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
+ "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
+ "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
+ "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
+ "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
+ "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
+ "\t\t\t--precision (DEFAULT: 'int4')\n"
+ "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
+ "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--precision",
"-p",
default="int4",
choices=["fp32", "fp16", "int8", "int4"],
help="Precision of the concerned IR",
)
parser.add_argument(
"--model_name",
type=str,
default="llama2_7b",
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
help="Specify which model to run.",
)
parser.add_argument(
"--first_ir_path",
default=None,
help="path to first llama mlir file",
)
parser.add_argument(
"--second_ir_path",
default=None,
help="path to second llama mlir file",
)
parser.add_argument(
"--dynamic_input_size",
type=int,
default=19,
help="Specify the static input size to replace with dynamic dim.",
)
parser.add_argument(
"--dynamic",
default=False,
action=argparse.BooleanOptionalAction,
help="Converts the IR(s) to dynamic",
)
parser.add_argument(
"--save_dynamic",
default=False,
action=argparse.BooleanOptionalAction,
help="Save the individual IR(s) after converting to dynamic",
)
parser.add_argument(
"--combine",
default=False,
action=argparse.BooleanOptionalAction,
help="Converts the IR(s) to dynamic",
)
args, unknown = parser.parse_known_args()
dynamic = args.dynamic
combine = args.combine
assert (
dynamic or combine
), "neither `dynamic` nor `combine` flag is turned on"
first_ir_path = args.first_ir_path
second_ir_path = args.second_ir_path
assert first_ir_path or second_ir_path, "no input ir has been provided"
if combine:
assert (
first_ir_path and second_ir_path
), "you will need to provide both IRs to combine"
precision = args.precision
model_name = args.model_name
dynamic_input_size = args.dynamic_input_size
save_dynamic = args.save_dynamic
print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")
if dynamic and not combine:
save_dynamic = True
first_ir = None
first_dynamic_ir_name = None
second_ir = None
second_dynamic_ir_name = None
if first_ir_path:
first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
with open(first_ir_path, "r") as f:
first_ir = f.read()
if second_ir_path:
second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
with open(second_ir_path, "r") as f:
second_ir = f.read()
if dynamic:
first_ir = (
write_in_dynamic_inputs0(first_ir, dynamic_input_size)
if first_ir
else None
)
second_ir = (
write_in_dynamic_inputs1(second_ir, model_name, precision)
if second_ir
else None
)
if save_dynamic:
save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")
if combine:
combine_mlir_scripts(
first_ir,
second_ir,
f"{model_name}_{precision}.mlir",
return_ir=False,
)

View File

@@ -46,6 +46,7 @@ def compile_stableLM(
model_vmfb_name,
device="cuda",
precision="fp32",
debug=False,
):
from shark.shark_inference import SharkInference
@@ -92,7 +93,7 @@ def compile_stableLM(
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
)
print("Saved vmfb at ", str(path))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,94 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('huggingface-hub')
datas += copy_metadata('sentencepiece')
datas += copy_metadata("pyyaml")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('py-cpuinfo')
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['scripts/vicuna.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_llama_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -0,0 +1,876 @@
import argparse
import json
import re
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from typing import List, Optional, Tuple, Union
import numpy as np
import iree.runtime
import itertools
import subprocess
import torch
import torch_mlir
from torch_mlir import TensorPlaceholder
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LlamaPreTrainedModel,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
CompiledVicunaLayer,
ShardedVicunaModel,
LMHead,
LMHeadCompiled,
VicunaEmbedding,
VicunaEmbeddingCompiled,
VicunaNorm,
VicunaNormCompiled,
)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna7B,
)
from apps.language_models.utils import (
get_vmfb_from_path,
)
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_inference import SharkInference
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaRMSNorm,
_make_causal_mask,
_expand_mask,
)
from torch import nn
from time import time
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(config)
for _ in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
t1 = time()
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = (
use_cache if use_cache is not None else self.config.use_cache
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = (
seq_length_with_past + past_key_values_length
)
if position_ids is None:
device = (
input_ids.device
if input_ids is not None
else inputs_embeds.device
)
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.compressedlayers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[8 * idx : 8 * (idx + 1)]
if past_key_values is not None
else None
)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer.forward(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[1:],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
try:
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except:
_ = 10
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = tuple(itertools.chain.from_iterable(next_cache))
print(f"Token generated in {time() - t1} seconds")
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class EightLayerLayerSV(torch.nn.Module):
def __init__(self, layers):
super().__init__()
assert len(layers) == 8
self.layers = layers
def forward(
self,
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
):
pkvs = [
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
]
new_pkvs = []
for layer, pkv in zip(self.layers, pkvs):
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
pkv[0],
pkv[1],
),
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
return (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
class EightLayerLayerFV(torch.nn.Module):
def __init__(self, layers):
super().__init__()
assert len(layers) == 8
self.layers = layers
def forward(self, hidden_states, attention_mask, position_ids):
new_pkvs = []
for layer in self.layers:
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=None,
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
return (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
class CompiledEightLayerLayerSV(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_value
pkv00 = pkv00.detatch()
pkv01 = pkv01.detatch()
pkv10 = pkv10.detatch()
pkv11 = pkv11.detatch()
pkv20 = pkv20.detatch()
pkv21 = pkv21.detatch()
pkv30 = pkv30.detatch()
pkv31 = pkv31.detatch()
pkv40 = pkv40.detatch()
pkv41 = pkv41.detatch()
pkv50 = pkv50.detatch()
pkv51 = pkv51.detatch()
pkv60 = pkv60.detatch()
pkv61 = pkv61.detatch()
pkv70 = pkv70.detatch()
pkv71 = pkv71.detatch()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
),
send_to_host=False,
)
return (
output[0],
(output[1][0], output[1][1]),
(output[2][0], output[2][1]),
(output[3][0], output[3][1]),
(output[4][0], output[4][1]),
(output[5][0], output[5][1]),
(output[6][0], output[6][1]),
(output[7][0], output[7][1]),
(output[8][0], output[8][1]),
)
def forward_compressed(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = (
input_ids.device if input_ids is not None else inputs_embeds.device
)
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.compressedlayers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[8 * idx : 8 * (idx + 1)]
if past_key_values is not None
else None
)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1],
)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class CompiledEightLayerLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
t2 = time()
if past_key_value is None:
try:
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except:
pass
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
t1 = time()
output = self.model(
"first_vicuna_forward",
(hidden_states, attention_mask, position_ids),
send_to_host=False,
)
output2 = (
output[0],
(
output[1],
output[2],
),
(
output[3],
output[4],
),
(
output[5],
output[6],
),
(
output[7],
output[8],
),
(
output[9],
output[10],
),
(
output[11],
output[12],
),
(
output[13],
output[14],
),
(
output[15],
output[16],
),
)
return output2
else:
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_value
try:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv00 = pkv00.detach()
pkv01 = pkv01.detach()
pkv10 = pkv10.detach()
pkv11 = pkv11.detach()
pkv20 = pkv20.detach()
pkv21 = pkv21.detach()
pkv30 = pkv30.detach()
pkv31 = pkv31.detach()
pkv40 = pkv40.detach()
pkv41 = pkv41.detach()
pkv50 = pkv50.detach()
pkv51 = pkv51.detach()
pkv60 = pkv60.detach()
pkv61 = pkv61.detach()
pkv70 = pkv70.detach()
pkv71 = pkv71.detach()
except:
x = 10
t1 = time()
if type(hidden_states) == iree.runtime.array_interop.DeviceArray:
hidden_states = np.array(hidden_states, hidden_states.dtype)
hidden_states = torch.tensor(hidden_states)
hidden_states = hidden_states.detach()
output = self.model(
"second_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
),
send_to_host=False,
)
print(f"{time() - t1}")
del pkv00
del pkv01
del pkv10
del pkv11
del pkv20
del pkv21
del pkv30
del pkv31
del pkv40
del pkv41
del pkv50
del pkv51
del pkv60
del pkv61
del pkv70
del pkv71
output2 = (
output[0],
(
output[1],
output[2],
),
(
output[3],
output[4],
),
(
output[5],
output[6],
),
(
output[7],
output[8],
),
(
output[9],
output[10],
),
(
output[11],
output[12],
),
(
output[13],
output[14],
),
(
output[15],
output[16],
),
)
return output2

View File

@@ -1,15 +1,13 @@
import torch
from transformers import AutoModelForCausalLM
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class FirstVicuna(torch.nn.Module):
def __init__(
self,
model_path,
precision="fp32",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
@@ -18,15 +16,24 @@ class FirstVicuna(torch.nn.Module):
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("First Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -40,19 +47,22 @@ class FirstVicuna(torch.nn.Module):
def forward(self, input_ids):
op = self.model(input_ids=input_ids, use_cache=True)
return_vals = []
return_vals.append(op.logits)
token = torch.argmax(op.logits[:, -1, :], dim=1)
return_vals.append(token)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return_vals.append(item[0].transpose(1,2))
return_vals.append(item[1].transpose(1,2))
return tuple(return_vals)
class SecondVicuna(torch.nn.Module):
class SecondVicuna7B(torch.nn.Module):
def __init__(
self,
model_path,
precision="fp32",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
@@ -64,12 +74,21 @@ class SecondVicuna(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -148,8 +167,6 @@ class SecondVicuna(torch.nn.Module):
i63,
i64,
):
# input_ids = input_tuple[0]
# input_tuple = torch.unbind(pkv, dim=0)
token = i0
past_key_values = (
(i1, i2),
@@ -278,6 +295,845 @@ class SecondVicuna(torch.nn.Module):
i64,
),
)
past_key_values = [(x[0].transpose(1,2), x[0].transpose(1,2)) for x in past_key_values]
past_key_values = tuple(past_key_values)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
token = torch.argmax(op.logits[:, -1, :], dim=1)
return_vals.append(token)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0].transpose(1,2))
return_vals.append(item[1].transpose(1,2))
return tuple(return_vals)
class SecondVicuna13B(torch.nn.Module):
def __init__(
self,
model_path,
precision="int8",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
i65,
i66,
i67,
i68,
i69,
i70,
i71,
i72,
i73,
i74,
i75,
i76,
i77,
i78,
i79,
i80,
):
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
(
i65,
i66,
),
(
i67,
i68,
),
(
i69,
i70,
),
(
i71,
i72,
),
(
i73,
i74,
),
(
i75,
i76,
),
(
i77,
i78,
),
(
i79,
i80,
),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class SecondVicuna70B(torch.nn.Module):
def __init__(
self,
model_path,
precision="fp32",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
i65,
i66,
i67,
i68,
i69,
i70,
i71,
i72,
i73,
i74,
i75,
i76,
i77,
i78,
i79,
i80,
i81,
i82,
i83,
i84,
i85,
i86,
i87,
i88,
i89,
i90,
i91,
i92,
i93,
i94,
i95,
i96,
i97,
i98,
i99,
i100,
i101,
i102,
i103,
i104,
i105,
i106,
i107,
i108,
i109,
i110,
i111,
i112,
i113,
i114,
i115,
i116,
i117,
i118,
i119,
i120,
i121,
i122,
i123,
i124,
i125,
i126,
i127,
i128,
i129,
i130,
i131,
i132,
i133,
i134,
i135,
i136,
i137,
i138,
i139,
i140,
i141,
i142,
i143,
i144,
i145,
i146,
i147,
i148,
i149,
i150,
i151,
i152,
i153,
i154,
i155,
i156,
i157,
i158,
i159,
i160,
):
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
(
i65,
i66,
),
(
i67,
i68,
),
(
i69,
i70,
),
(
i71,
i72,
),
(
i73,
i74,
),
(
i75,
i76,
),
(
i77,
i78,
),
(
i79,
i80,
),
(
i81,
i82,
),
(
i83,
i84,
),
(
i85,
i86,
),
(
i87,
i88,
),
(
i89,
i90,
),
(
i91,
i92,
),
(
i93,
i94,
),
(
i95,
i96,
),
(
i97,
i98,
),
(
i99,
i100,
),
(
i101,
i102,
),
(
i103,
i104,
),
(
i105,
i106,
),
(
i107,
i108,
),
(
i109,
i110,
),
(
i111,
i112,
),
(
i113,
i114,
),
(
i115,
i116,
),
(
i117,
i118,
),
(
i119,
i120,
),
(
i121,
i122,
),
(
i123,
i124,
),
(
i125,
i126,
),
(
i127,
i128,
),
(
i129,
i130,
),
(
i131,
i132,
),
(
i133,
i134,
),
(
i135,
i136,
),
(
i137,
i138,
),
(
i139,
i140,
),
(
i141,
i142,
),
(
i143,
i144,
),
(
i145,
i146,
),
(
i147,
i148,
),
(
i149,
i150,
),
(
i151,
i152,
),
(
i153,
i154,
),
(
i155,
i156,
),
(
i157,
i158,
),
(
i159,
i160,
),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
@@ -298,15 +1154,17 @@ class CombinedModel(torch.nn.Module):
):
super().__init__()
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
# NOT using this path for 13B currently, hence using `SecondVicuna7B`.
self.second_vicuna = SecondVicuna7B(second_vicuna_model_path)
def forward(self, input_ids):
first_output = self.first_vicuna(input_ids=input_ids, use_cache=True)
logits = first_output[0]
pkv = first_output[1:]
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
second_output = self.second_vicuna(secondVicunaInput)
first_output = self.first_vicuna(input_ids=input_ids)
# generate second vicuna
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
second_output = self.second_vicuna(*secondVicunaCompileInput)
return second_output

File diff suppressed because it is too large Load Diff

View File

@@ -66,7 +66,7 @@ class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers) == len(model.model.layers)
# assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
@@ -132,7 +132,10 @@ class VicunaNormCompiled(torch.nn.Module):
self.model = shark_module
def forward(self, hidden_states):
hidden_states.detach()
try:
hidden_states.detach()
except:
pass
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output

View File

@@ -3,7 +3,10 @@ from abc import ABC, abstractmethod
class SharkLLMBase(ABC):
def __init__(
self, model_name, hf_model_path=None, max_num_tokens=512
self,
model_name,
hf_model_path=None,
max_num_tokens=512,
) -> None:
self.model_name = model_name
self.hf_model_path = hf_model_path

View File

@@ -7,9 +7,9 @@ from io import BytesIO
from pathlib import Path
from contextlib import redirect_stdout
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
@@ -28,9 +28,11 @@ parser = argparse.ArgumentParser(
description="runs a falcon model",
)
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
@@ -49,7 +51,7 @@ parser.add_argument(
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
@@ -59,32 +61,52 @@ parser.add_argument(
action=argparse.BooleanOptionalAction,
help="Run model in cli mode",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for falcon-180B model.",
)
class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_auth_token: str = None,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)
if "180b" in self.model_name and hf_auth_token == None:
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.src_model = self.get_src_model()
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, trust_remote_code=True
self.hf_model_path,
trust_remote_code=True,
token=self.hf_auth_token,
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
@@ -92,13 +114,24 @@ class Falcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
kwargs = {
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile_falcon(self):
def compile(self):
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
@@ -120,37 +153,37 @@ class Falcon(SharkLLMBase):
if vmfb is not None:
return vmfb
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
)
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
print(
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
if args.load_mlir_from_shark_tank:
# Downloading MLIR from shark_tank
print(f"[DEBUG] Trying to download mlir from shark_tank")
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
mlir_generated = True
if not mlir_generated:
print(f"[DEBUG] generating MLIR locally")
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 100)
)
@@ -167,9 +200,10 @@ class Falcon(SharkLLMBase):
ts_graph = import_with_fx(
model,
falconCompileInput,
is_f16=self.precision == "fp16",
is_f16=self.precision in ["fp16", "int4"],
f16_input_mask=[False, False],
mlir_type="torchscript",
is_gptq=self.precision == "int4",
)
del model
print(f"[DEBUG] generating torch mlir")
@@ -189,35 +223,37 @@ class Falcon(SharkLLMBase):
bytecode = bytecode_stream.getvalue()
del module
print(f"[DEBUG] writing mlir to file")
with open(f"{self.model_name}.mlir", "wb") as f_:
with redirect_stdout(f_):
print(module.operation.get_asm())
f_ = open(self.falcon_mlir_path, "wb")
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
f_.close()
del bytecode
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
]
+ [
"--iree-llvmcpu-use-fast-min-max-ops",
]
if self.precision == "int4"
else [],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile(self):
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
@@ -387,7 +423,7 @@ class Falcon(SharkLLMBase):
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision == "fp16":
if self.precision in ["fp16", "int4"]:
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
@@ -466,11 +502,26 @@ if __name__ == "__main__":
else Path(args.falcon_vmfb_path)
)
if args.precision == "int4":
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
else:
hf_model_path_value = (
"TheBloke/falcon-"
+ args.falcon_variant_to_use
+ "-instruct-GPTQ"
)
else:
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "tiiuae/falcon-180B-chat"
else:
hf_model_path_value = (
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
)
falcon = Falcon(
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
@@ -497,7 +548,11 @@ if __name__ == "__main__":
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)
res_str = falcon.generate(prompt)
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
User: {prompt}
Assistant:"""
res_str = falcon.generate(prompt_template)
torch.cuda.empty_cache()
gc.collect()
print(

View File

@@ -126,7 +126,7 @@ def is_url(input_url):
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import torch
import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -136,7 +136,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
def brevitasmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -145,20 +146,21 @@ def brevitasmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
raise ValueError("Input shapes not supported.")
def brevitasmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def brevitasmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics]
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
@@ -176,7 +178,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
def compile_module(
shark_module, extended_model_name, generate_vmfb, extra_args=[]
shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
):
if generate_vmfb:
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
@@ -188,7 +190,7 @@ def compile_module(
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
)
path = shark_module.save_module(
os.getcwd(), extended_model_name, extra_args
os.getcwd(), extended_model_name, extra_args, debug=debug
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -197,7 +199,7 @@ def compile_module(
def compile_int_precision(
model, inputs, precision, device, generate_vmfb, extended_model_name
model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
):
torchscript_module = import_with_fx(
model,
@@ -209,7 +211,7 @@ def compile_int_precision(
torchscript_module,
inputs,
output_type="torch",
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -217,7 +219,7 @@ def compile_int_precision(
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
mlir_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
from contextlib import redirect_stdout
@@ -233,6 +235,12 @@ def compile_int_precision(
mlir_module = BytesIO(mlir_module)
bytecode = mlir_module.read()
print(f"Elided IR written for {extended_model_name}")
bytecode = save_mlir(
bytecode,
model_name=extended_model_name,
frontend="torch",
dir=os.getcwd(),
)
return bytecode
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
@@ -249,6 +257,7 @@ def compile_int_precision(
extended_model_name=extended_model_name,
generate_vmfb=generate_vmfb,
extra_args=extra_args,
debug=debug,
),
bytecode,
)
@@ -292,6 +301,7 @@ def shark_compile_through_fx_int(
device,
generate_or_load_vmfb,
extended_model_name,
debug,
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",

View File

@@ -32,11 +32,13 @@ class SharkStableLM(SharkLLMBase):
max_num_tokens=512,
device="cuda",
precision="fp32",
debug="False",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
self.precision = precision
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
@@ -111,7 +113,7 @@ class SharkStableLM(SharkLLMBase):
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
)
print("Saved vmfb at ", str(path))

View File

@@ -8,7 +8,7 @@ from shark.shark_downloader import download_public_file
# expects a Path / str as arg
# returns None if path not found or SharkInference module
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
if not isinstance(vmfb_path, Path):
vmfb_path = Path(vmfb_path)
@@ -20,7 +20,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
print("Loading vmfb from: ", vmfb_path)
print("Device from get_vmfb_from_path - ", device)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
@@ -28,7 +28,13 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
def get_vmfb_from_config(
shark_container, model, precision, device, vmfb_path, padding=None
shark_container,
model,
precision,
device,
vmfb_path,
padding=None,
device_id=None,
):
vmfb_url = (
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
@@ -37,4 +43,6 @@ def get_vmfb_from_config(
vmfb_url = vmfb_url + f"_{padding}"
vmfb_url = vmfb_url + ".vmfb"
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
return get_vmfb_from_path(
vmfb_path, device, "tm_tensor", device_id=device_id
)

View File

@@ -7,16 +7,16 @@ Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
```

View File

@@ -34,7 +34,7 @@ from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
@@ -287,7 +287,7 @@ def lora_train(
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)

View File

@@ -15,8 +15,8 @@ pathex = [
# datafiles for pyinstaller
datas = []
datas += collect_data_files("torch")
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
@@ -30,26 +30,29 @@ datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("opencv_python")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark")
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
datas += collect_data_files("cv2")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
@@ -72,6 +75,13 @@ datas += [
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("transformers") if "tests" not in x
x for x in collect_submodules("diffusers") if "tests" not in x
]
blacklist = ["tests", "convert"]
hiddenimports += [
x
for x in collect_submodules("transformers")
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]

View File

@@ -177,9 +177,11 @@ class SharkifyStableDiffusionModel:
"unet",
"unet512",
"stencil_unet",
"stencil_unet_512",
"vae",
"vae_encode",
"stencil_adaptor",
"stencil_adaptor_512",
]
index = 0
for model in sub_model_list:
@@ -339,7 +341,7 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir
def get_controlled_unet(self):
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self,
@@ -415,6 +417,16 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
model_name = "stencil_unet"
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[:2]
+ (torch.nn.functional.pad(inputs[2], pad),)
+ inputs[3:]
)
model_name = "stencil_unet_512"
input_mask = [
True,
True,
@@ -437,19 +449,19 @@ class SharkifyStableDiffusionModel:
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["stencil_unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self):
def get_control_net(self, use_large=False):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
@@ -497,17 +509,34 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name["stencil_adaptor"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_adaptor",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
@@ -681,8 +710,11 @@ class SharkifyStableDiffusionModel:
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
save_dir = ""
if self.debug:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["clip"]
)
os.makedirs(
save_dir,
exist_ok=True,
@@ -748,7 +780,7 @@ class SharkifyStableDiffusionModel:
else:
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet()
return self.get_controlled_unet(use_large=use_large)
def vae_encode(self):
try:
@@ -847,12 +879,14 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def controlnet(self):
def controlnet(self, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
)
check_compilation(compiled_stencil_adaptor, "Stencil")
if self.return_mlir:

View File

@@ -84,13 +84,35 @@ class Image2ImagePipeline(StableDiffusionPipeline):
num_inference_steps,
strength,
dtype,
resample_type,
):
# Pre process image -> get image encoded -> process latents
# TODO: process with variable HxW combos
# Pre process image
image = image.resize((width, height))
# Pre-process image
if resample_type == "Lanczos":
resample_type = Image.LANCZOS
elif resample_type == "Nearest Neighbor":
resample_type = Image.NEAREST
elif resample_type == "Bilinear":
resample_type = Image.BILINEAR
elif resample_type == "Bicubic":
resample_type = Image.BICUBIC
elif resample_type == "Adaptive":
resample_type = Image.ADAPTIVE
elif resample_type == "Antialias":
resample_type = Image.ANTIALIAS
elif resample_type == "Box":
resample_type = Image.BOX
elif resample_type == "Affine":
resample_type = Image.AFFINE
elif resample_type == "Cubic":
resample_type = Image.CUBIC
else: # Fallback to Lanczos
resample_type = Image.LANCZOS
image = image.resize((width, height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
@@ -147,6 +169,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -186,6 +209,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
# Get Image latents

View File

@@ -58,6 +58,7 @@ class StencilPipeline(StableDiffusionPipeline):
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
def load_controlnet(self):
if self.controlnet is not None:
@@ -68,6 +69,15 @@ class StencilPipeline(StableDiffusionPipeline):
del self.controlnet
self.controlnet = None
def load_controlnet_512(self):
if self.controlnet_512 is not None:
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def prepare_latents(
self,
batch_size,
@@ -111,8 +121,12 @@ class StencilPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
self.load_controlnet()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
@@ -135,43 +149,82 @@ class StencilPipeline(StableDiffusionPipeline):
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
else:
print(self.unet_512)
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
@@ -191,7 +244,9 @@ class StencilPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -218,6 +273,7 @@ class StencilPipeline(StableDiffusionPipeline):
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.

View File

@@ -84,9 +84,6 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
def _import(self):
scaling_model = ScalingModel()

View File

@@ -109,7 +109,7 @@ def load_lower_configs(base_model_id=None):
spec = spec.split("-")[0]
if args.annotation_model == "vae":
if not spec or spec in ["rdna3", "sm_80"]:
if not spec or spec in ["sm_80"]:
config_name = (
f"{args.annotation_model}_{args.precision}_{device}.json"
)
@@ -158,9 +158,9 @@ def load_lower_configs(base_model_id=None):
f"{spec}.json"
)
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
full_gs_url = config_bucket + config_name
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir

View File

@@ -132,6 +132,57 @@ p.add_argument(
"img2img.",
)
p.add_argument(
"--use_hiresfix",
type=bool,
default=False,
help="Use Hires Fix to do higher resolution images, while trying to "
"avoid the issues that come with it. This is accomplished by first "
"generating an image using txt2img, then running it through img2img.",
)
p.add_argument(
"--hiresfix_height",
type=int,
default=768,
choices=range(128, 769, 8),
help="The height of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_width",
type=int,
default=768,
choices=range(128, 769, 8),
help="The width of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_strength",
type=float,
default=0.6,
help="The denoising strength to apply for the Hires Fix.",
)
p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
##############################################################################
# Stable Diffusion Training Params
##############################################################################
@@ -407,6 +458,14 @@ p.add_argument(
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
p.add_argument(
"--device_allocator_heap_key",
type=str,
default="",
help="Specify heap key for device caching allocator."
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
@@ -519,6 +578,20 @@ p.add_argument(
"in shark importer. Does nothing if import_mlir is false (the default).",
)
p.add_argument(
"--compile_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to toggle debug assert/verify flags for imported IR in the"
"iree-compiler. Default to false.",
)
p.add_argument(
"--iree_constant_folding",
default=True,
action=argparse.BooleanOptionalAction,
help="Controls constant folding in iree-compile for all SD models.",
)
##############################################################################
# Web UI flags
@@ -568,6 +641,13 @@ p.add_argument(
help="Flag for enabling rest API.",
)
p.add_argument(
"--debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling debugging log in WebUI.",
)
p.add_argument(
"--output_gallery",
default=True,

View File

@@ -18,14 +18,14 @@ import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
)
)
path = shark_module.save_module(
os.getcwd(), model_name, extra_args
os.getcwd(), model_name, extra_args, debug=args.compile_debug
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -154,8 +154,8 @@ def compile_through_fx(
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
)
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
@@ -168,6 +168,14 @@ def compile_through_fx(
mlir_module, extended_model_name, base_model_id
)
if not os.path.isdir(save_dir):
save_dir = ""
mlir_module = save_mlir(
mlir_module,
model_name=extended_model_name,
dir=save_dir,
)
shark_module = SharkInference(
mlir_module,
device=args.device if device is None else device,
@@ -179,17 +187,22 @@ def compile_through_fx(
mlir_module,
)
del mlir_module
gc.collect()
def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if args.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
@@ -470,12 +483,25 @@ def get_available_devices():
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
@@ -499,10 +525,15 @@ def get_opt_flags(model, precision="fp16"):
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "rocm" in args.device:
rocm_args = get_iree_rocm_args()
iree_flags.extend(rocm_args)
print(iree_flags)
if args.iree_constant_folding == False:
iree_flags.append("--iree-opt-const-expr-hoisting=False")
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
@@ -566,7 +597,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
)
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
checkpoint_path_or_dict=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
@@ -816,6 +847,8 @@ def clear_all():
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
if args.local_tank_cache != "":
shutil.rmtree(args.local_tank_cache)
def get_generated_imgs_path() -> Path:

View File

@@ -1,6 +1,7 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging
if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
@@ -37,10 +38,12 @@ def launch_app(address):
height=height,
text_select=True,
)
webview.start(private_mode=False)
webview.start(private_mode=False, storage_path=os.getcwd())
if __name__ == "__main__":
if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.ui.split(","):
@@ -115,7 +118,8 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
h2ogpt_web,
# h2ogpt_upload,
# h2ogpt_web,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
@@ -152,8 +156,9 @@ if __name__ == "__main__":
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
model_web,
# lora_train_web,
# model_web,
# model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
@@ -211,6 +216,15 @@ if __name__ == "__main__":
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
@@ -236,16 +250,22 @@ if __name__ == "__main__":
upscaler_status,
]
)
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="LoRA Training (Experimental)", id=8):
lora_train_web.render()
with gr.TabItem(label="Chat Bot (Experimental)", id=7):
# with gr.TabItem(label="Model Manager", id=6):
# model_web.render()
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
# lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=8):
stablelm_chat.render()
with gr.TabItem(label="MultiModal (Experimental)", id=9):
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
with gr.TabItem(label="DocuChat(Experimental)", id=10):
h2ogpt_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
# send to buttons
register_button_click(

View File

@@ -78,7 +78,7 @@ from apps.stable_diffusion.web.ui.stablelm_ui import (
stablelm_chat,
llm_chat_api,
)
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
from apps.stable_diffusion.web.ui.generate_config import model_config_web
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_web,

View File

@@ -0,0 +1,41 @@
import gradio as gr
import torch
from transformers import AutoTokenizer
from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel
from shark.shark_generate_model_config import GenerateConfigFile
def get_model_config():
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
return c.split_into_layers()
with gr.Blocks() as model_config_web:
with gr.Row():
hf_models = gr.Dropdown(
label="Model List",
choices=["Vicuna"],
value="Vicuna",
visible=True,
)
get_model_config_btn = gr.Button(value="Get Model Config")
json_view = gr.JSON()
get_model_config_btn.click(
fn=get_model_config,
inputs=[],
outputs=[json_view],
)

View File

@@ -12,6 +12,10 @@ from apps.language_models.langchain.enums import (
LangChainAction,
)
import apps.language_models.langchain.gen as gen
from gpt_langchain import (
path_to_docs,
create_or_update_db,
)
from apps.stable_diffusion.src import args
@@ -33,8 +37,15 @@ start_message = """
def create_prompt(history):
system_message = start_message
for item in history:
print("His item: ", item)
conversation = "".join(["".join([item[0], item[1]]) for item in history])
conversation = "<|endoftext|>".join(
[
"<|endoftext|><|answer|>".join([item[0], item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
@@ -44,10 +55,12 @@ def create_prompt(history):
def chat(curr_system_message, history, device, precision):
args.run_docuchat_web = True
global h2ogpt_model
global sharkModel
global h2ogpt_tokenizer
global model_state
global langchain
global userpath_selector
from apps.language_models.langchain.h2oai_pipeline import generate_token
if h2ogpt_model == 0:
if "cuda" in device:
@@ -102,9 +115,14 @@ def chat(curr_system_message, history, device, precision):
prompt_type=None,
prompt_dict=None,
)
from apps.language_models.langchain.h2oai_pipeline import (
H2OGPTSHARKModel,
)
sharkModel = H2OGPTSHARKModel()
prompt = create_prompt(history)
output = langchain.evaluate(
output_dict = langchain.evaluate(
model_state=model_state,
my_db_state=None,
instruction=prompt,
@@ -164,14 +182,22 @@ def chat(curr_system_message, history, device, precision):
model_lock=True,
user_path=userpath_selector.value,
)
for partial_text in output:
history[-1][1] = partial_text["response"]
yield history
output = generate_token(sharkModel, **output_dict)
for partial_text in output:
history[-1][1] = partial_text
yield history
return history
with gr.Blocks(title="H2OGPT") as h2ogpt_web:
userpath_selector = gr.Textbox(
label="Document Directory",
value=str(os.path.abspath("apps/language_models/langchain/user_path/")),
interactive=True,
container=True,
)
with gr.Blocks(title="DocuChat") as h2ogpt_web:
with gr.Row():
supported_devices = available_devices
enabled = len(supported_devices) > 0
@@ -186,6 +212,7 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
@@ -198,14 +225,6 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
],
visible=True,
)
userpath_selector = gr.Textbox(
label="Document Directory",
value=str(
os.path.abspath("apps/language_models/langchain/user_path/")
),
interactive=True,
container=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
@@ -249,3 +268,100 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)
with gr.Blocks(title="DocuChat Upload") as h2ogpt_upload:
import pathlib
upload_path = None
database = None
database_directory = os.path.abspath(
"apps/language_models/langchain/db_path/"
)
def read_path():
global upload_path
filenames = [
[f]
for f in os.listdir(upload_path)
if os.path.isfile(os.path.join(upload_path, f))
]
filenames.sort()
return filenames
def upload_file(f):
names = []
for tmpfile in f:
name = tmpfile.name.split("/")[-1]
basename = os.path.join(upload_path, name)
with open(basename, "wb") as w:
with open(tmpfile.name, "rb") as r:
w.write(r.read())
update_or_create_db()
return read_path()
def update_userpath(newpath):
global upload_path
upload_path = newpath
pathlib.Path(upload_path).mkdir(parents=True, exist_ok=True)
return read_path()
def update_or_create_db():
global database
global upload_path
sources = path_to_docs(
upload_path,
verbose=True,
fail_any_exception=False,
n_jobs=-1,
chunk=True,
chunk_size=512,
url=None,
enable_captions=False,
captions_model=None,
caption_loader=None,
enable_ocr=False,
)
pathlib.Path(database_directory).mkdir(parents=True, exist_ok=True)
database = create_or_update_db(
"chroma",
database_directory,
"UserData",
sources,
False,
True,
True,
"sentence-transformers/all-MiniLM-L6-v2",
)
def first_run():
global database
if database is None:
update_or_create_db()
update_userpath(
os.path.abspath("apps/language_models/langchain/user_path/")
)
h2ogpt_upload.load(fn=first_run)
h2ogpt_web.load(fn=first_run)
with gr.Column():
text = gr.DataFrame(
col_count=(1, "fixed"),
type="array",
label="Documents",
value=read_path(),
)
with gr.Row():
upload = gr.UploadButton(
label="Upload documents",
file_count="multiple",
)
upload.upload(fn=upload_file, inputs=upload, outputs=text)
userpath_selector.render()
userpath_selector.input(
fn=update_userpath, inputs=userpath_selector, outputs=text
).then(fn=update_or_create_db)

View File

@@ -3,6 +3,7 @@ import torch
import time
import gradio as gr
import PIL
from math import ceil
from PIL import Image
import base64
from io import BytesIO
@@ -67,6 +68,7 @@ def img2img_inf(
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -245,7 +247,7 @@ def img2img_inf(
batch_size,
height,
width,
steps,
ceil(steps / strength),
strength,
guidance_scale,
seeds[current_batch],
@@ -255,6 +257,7 @@ def img2img_inf(
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
resample_type=resample_type,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
@@ -348,6 +351,7 @@ def img2img_api(
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
resample_type="Lanczos",
)
# Converts generator type to subscriptable
@@ -392,6 +396,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
)
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -417,6 +422,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -432,7 +438,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lines=2,
elem_id="negative_prompt_box",
)
# TODO: make this import image prompt info if it exists
img2img_init_image = gr.Image(
label="Input Image",
source="upload",
@@ -448,6 +454,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Stencil model",
value="None",
choices=["None", "canny", "openpose", "scribble"],
allow_custom_value=True,
)
def show_canvas(choice):
@@ -508,6 +515,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
elem_id="lora_weights",
@@ -531,6 +539,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -550,15 +559,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
@@ -581,11 +581,36 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
label="Resample Type",
allow_custom_value=True,
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
@@ -629,6 +654,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
@@ -695,6 +721,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lora_hf_id,
ondemand,
repeatable_seeds,
resample_type,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",

View File

@@ -344,6 +344,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
allow_custom_value=True,
)
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -369,6 +370,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -406,6 +408,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -424,6 +427,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -527,6 +531,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")

View File

@@ -50,6 +50,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -73,6 +74,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -105,6 +107,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
label="Scheduler",
value=args.scheduler,
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Row():
height = gr.Slider(
@@ -177,6 +180,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
with gr.Column(scale=2):

View File

@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
gr.Markdown(description)
with gr.Row():
with gr.Column(scale=0.5):
with gr.Column():
image = gr.Image(type="pil")
upload_button = gr.Button(
value="Upload & Start Chat",
@@ -143,6 +143,7 @@ with gr.Blocks() as minigpt4_web:
# else "Only CUDA Supported for now",
choices=["cuda"],
interactive=False,
allow_custom_value=True,
)
with gr.Column():

View File

@@ -98,6 +98,7 @@ with gr.Blocks() as model_web:
choices=None,
value=None,
visible=False,
allow_custom_value=True,
)
# TODO: select and SendTo
civit_models = gr.Gallery(

View File

@@ -351,6 +351,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
allow_custom_value=True,
)
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -376,6 +377,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -411,6 +413,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -429,6 +432,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -555,6 +559,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")

View File

@@ -109,6 +109,7 @@ with gr.Blocks() as outputgallery_web:
value="",
interactive=True,
elem_classes="dropdown_no_container",
allow_custom_value=True,
)
with gr.Column(
scale=1,

View File

@@ -7,6 +7,8 @@ from transformers import (
)
from apps.stable_diffusion.web.ui.utils import available_devices
from datetime import datetime as dt
import json
import sys
def user(message, history):
@@ -22,11 +24,9 @@ past_key_values = None
model_map = {
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
}
# NOTE: Each `model_name` should have its own start message
@@ -40,6 +40,15 @@ start_message = {
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_13b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
@@ -49,54 +58,41 @@ start_message = {
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"StableLM": (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"vicuna1p3": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"codegen": "",
}
def create_prompt(model_name, history):
system_message = start_message[model_name]
def create_prompt(model_name, history, prompt_prefix):
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]
if model_name in [
"StableLM",
"vicuna",
"vicuna1p3",
"llama2_7b",
"llama2_70b",
]:
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
msg = system_message + conversation
msg = msg.strip()
return msg
@@ -105,84 +101,178 @@ def set_vicuna_model(model):
vicuna_model = model
def get_default_config():
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()
model_vmfb_key = ""
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision, cli=True):
def chat(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global past_key_values
global model_vmfb_key
global vicuna_model
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
if model_name in [
"vicuna",
"vicuna1p3",
"codegen",
"llama2_7b",
"llama2_70b",
]:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)
from apps.stable_diffusion.src import args
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
max_toks = 128 if model_name == "codegen" else 512
if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use target triple : {vulkan_target_triple}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)
prompt = create_prompt(model_name, history)
for partial_text in vicuna_model.generate(prompt, cli=cli):
history[-1][1] = partial_text
yield history
if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")
return history
# else Model is StableLM
global sharkModel
from apps.language_models.src.pipelines.stablelm_pipeline import (
SharkStableLM,
)
if sharkModel == 0:
# max_new_tokens=512
shark_slm = SharkStableLM(
model_name
) # pass elements from UI as required
# Construct the input message string for the model by concatenating the
# current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
prompt = create_prompt(model_name, history)
generate_kwargs = dict(prompt=prompt)
words_list = shark_slm.generate(**generate_kwargs)
prompt = create_prompt(model_name, history, prompt_prefix)
partial_text = ""
for new_text in words_list:
print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to clean up the message textbox and the updated
# conversation history
yield history
return words_list
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
partial_text += text + " "
history[-1][1] = partial_text
yield history, f"Prefill: {prefill_time:.2f}"
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)
return history, ""
def llm_chat_api(InputData: dict):
@@ -218,6 +308,7 @@ def llm_chat_api(InputData: dict):
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
@@ -226,6 +317,7 @@ def llm_chat_api(InputData: dict):
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
@@ -236,6 +328,9 @@ def llm_chat_api(InputData: dict):
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
@@ -300,13 +395,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
@@ -314,23 +409,39 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
# multiselect=True,
)
precision = gr.Radio(
label="Precision",
value="fp16",
value="int4",
choices=[
"int4",
"int8",
"fp16",
"fp32",
],
visible=True,
visible=False,
)
with gr.Row():
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
interactive=True,
)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(label="Upload sharding configuration")
json_view_button = gr.Button("View as JSON")
json_view = gr.JSON()
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
@@ -349,24 +460,47 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
system_msg = gr.Textbox(
start_message, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(

View File

@@ -4,6 +4,7 @@ import time
import sys
import gradio as gr
from PIL import Image
from math import ceil
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
@@ -26,6 +27,7 @@ from apps.stable_diffusion.src import (
utils,
save_output_img,
prompt_examples,
Image2ImagePipeline,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
@@ -62,6 +64,11 @@ def txt2img_inf(
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
use_hiresfix: bool,
hiresfix_height: int,
hiresfix_width: int,
hiresfix_strength: float,
resample_type: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -200,6 +207,81 @@ def txt2img_inf(
cpu_scheduling,
args.max_embeddings_multiples,
)
# TODO: allow user to save original image
# TODO: add option to let user keep both pipelines loaded, and unload
# either at will
# TODO: add custom step value slider
# TODO: add option to use secondary model for the img2img pass
if use_hiresfix is True:
new_config_obj = Config(
"img2img",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
1,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil="None",
ondemand=ondemand,
)
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
global_obj.set_sd_obj(
Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
1,
hiresfix_height,
hiresfix_width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(args.scheduler)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
out_imgs[0],
batch_size,
hiresfix_height,
hiresfix_width,
ceil(steps / hiresfix_strength),
hiresfix_strength,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil="None",
resample_type=resample_type,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
@@ -271,6 +353,11 @@ def txt2img_api(
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
use_hiresfix=False,
hiresfix_height=512,
hiresfix_width=512,
hiresfix_strength=0.6,
resample_type="Nearest Neighbor",
)
# Convert Generator to Subscriptable
@@ -319,6 +406,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -343,6 +431,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
else "None",
choices=["None"]
+ get_custom_model_files("vae"),
allow_custom_value=True,
)
with gr.Column(scale=1, min_width=170):
txt2img_png_info_img = gr.Image(
@@ -379,6 +468,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -397,6 +487,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Scheduler",
value=args.scheduler,
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
@@ -460,6 +551,50 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Low VRAM",
interactive=True,
)
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
label="Resample Type",
allow_custom_value=True,
)
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
@@ -494,17 +629,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
@@ -530,6 +656,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
show_label=False,
)
txt2img_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -565,6 +703,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
lora_hf_id,
ondemand,
repeatable_seeds,
use_hiresfix,
hiresfix_height,
hiresfix_width,
hiresfix_strength,
resample_type,
],
outputs=[txt2img_gallery, std_output, txt2img_status],
show_progress="minimal" if args.progress_bar else "none",

View File

@@ -365,6 +365,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
custom_checkpoint_type="upscaler"
)
+ predefined_upscaler_models,
allow_custom_value=True,
)
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -390,6 +391,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -425,6 +427,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -443,6 +446,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Scheduler",
value="DDIM",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -547,6 +551,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")

View File

@@ -25,7 +25,7 @@ class Config:
device: str
use_lora: str
use_stencil: str
ondemand: str
ondemand: str # should this be expecting a bool instead?
custom_model_filetypes = (

View File

@@ -24,13 +24,13 @@ def get_image(url, local_filename):
shutil.copyfileobj(res.raw, f)
def compare_images(new_filename, golden_filename):
def compare_images(new_filename, golden_filename, upload=False):
new = np.array(Image.open(new_filename)) / 255.0
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.1:
if os.name != "nt":
if os.name != "nt" and upload == True:
subprocess.run(
[
"gsutil",
@@ -39,7 +39,7 @@ def compare_images(new_filename, golden_filename):
"gs://shark_tank/testdata/builder/",
]
)
raise SystemExit("new and golden not close")
raise AssertionError("new and golden not close")
else:
print("SUCCESS")

View File

@@ -1,5 +1,6 @@
#!/bin/bash
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
IMPORTER=1 BENCHMARK=1 NO_BREVITAS=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python build_tools/stable_diffusion_testing.py --gen
python tank/generate_sharktank.py

View File

@@ -63,7 +63,14 @@ def get_inpaint_inputs():
open("./test_images/inputs/mask.png", "wb").write(mask.content)
def test_loop(device="vulkan", beta=False, extra_flags=[]):
def test_loop(
device="vulkan",
beta=False,
extra_flags=[],
upload_bool=True,
exit_on_fail=True,
do_gen=False,
):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
model_metrics = []
@@ -81,6 +88,8 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
if beta:
extra_flags.append("--beta_models=True")
extra_flags.append("--no-progress_bar")
if do_gen:
extra_flags.append("--import_debug")
to_skip = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
@@ -181,7 +190,14 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"./test_images/golden/" + model_name + "/*.png"
)
golden_file = glob(golden_path)[0]
compare_images(test_file, golden_file)
try:
compare_images(
test_file, golden_file, upload=upload_bool
)
except AssertionError as e:
print(e)
if exit_on_fail == True:
raise
else:
print(command)
print("failed to generate image for this configuration")
@@ -200,6 +216,9 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
if do_gen:
prepare_artifacts()
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)
@@ -218,15 +237,49 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
f.write(";".join(output) + "\n")
def prepare_artifacts():
gen_path = os.path.join(os.getcwd(), "gen_shark_tank")
if not os.path.isdir(gen_path):
os.mkdir(gen_path)
for dirname in os.listdir(os.getcwd()):
for modelname in ["clip", "unet", "vae"]:
if modelname in dirname and "vmfb" not in dirname:
if not os.path.isdir(os.path.join(gen_path, dirname)):
shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
print(f"Moved dir: {dirname} to {gen_path}.")
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument(
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument("-e", "--extra_args", type=str, default=None)
parser.add_argument(
"-u", "--upload", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-g", "--gen", action=argparse.BooleanOptionalAction, default=False
)
if __name__ == "__main__":
args = parser.parse_args()
print(args)
test_loop(args.device, args.beta, [])
extra_args = []
if args.extra_args:
for arg in args.extra_args.split(","):
extra_args.append(arg)
test_loop(
args.device,
args.beta,
extra_args,
args.upload,
args.exit_on_fail,
args.gen,
)
if args.gen:
prepare_artifacts()

View File

@@ -27,7 +27,7 @@ include(FetchContent)
FetchContent_Declare(
iree
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
GIT_REPOSITORY https://github.com/nod-ai/srt.git
GIT_TAG shark
GIT_SUBMODULES_RECURSE OFF
GIT_SHALLOW OFF

View File

@@ -40,7 +40,7 @@ cmake --build build/
*Prepare the model*
```bash
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux resnet50_tf.mlir -o resnet50_tf.vmfb
```
*Prepare the input*
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
```bash
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux stable_diff_tf.mlir -o stable_diff_tf.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
```
VAE and Autoencoder are also available
```bash
# VAE
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux vae.mlir -o vae.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
# CLIP Autoencoder
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux clip_autoencoder.mlir -o clip_autoencoder.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
```

View File

@@ -55,7 +55,7 @@ The command line for compilation will start something like this, where the `-` n
The `-o output_filename.vmfb` flag can be used to specify the location to save the compiled vmfb. Note that a dump of the
dispatches that can be compiled + run in isolation can be generated by adding `--iree-hal-dump-executable-benchmarks-to=/some/directory`. Say, if they are in the `benchmarks` directory, the following compile/run commands would work for Vulkan on RDNA3.
```
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-benchmark-module --module=benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb --function=forward --device=vulkan
```
@@ -63,8 +63,8 @@ Where `${NUM}` is the dispatch number that you want to benchmark/profile in isol
### Enabling Tracy for Vulkan profiling
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SHARK-Runtime/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime (SRT) builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SRT/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
```
$env:IREE_ENABLE_RUNTIME_TRACING="ON"
```

View File

@@ -1,192 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.17)
project(sharkbackend LANGUAGES C CXX)
#
# Options
#
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
#
# Dependencies
#
# FetchContent requires us to include the transitive closure of all
# repos that we depend on so that we can override the tags.
#
include(FetchContent)
FetchContent_Declare(
repo-common
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
GIT_TAG ${TRITON_COMMON_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-core
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
GIT_TAG ${TRITON_CORE_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-backend
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
#
# The backend must be built into a shared library. Use an ldscript to
# hide all symbols except for the TRITONBACKEND API.
#
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
add_library(
triton-dshark-backend SHARED
src/dshark.cc
#src/dshark_driver_module.c
)
add_library(
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
)
target_include_directories(
triton-dshark-backend
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src
)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
add_subdirectory(thirdparty/shark-runtime EXCLUDE_FROM_ALL)
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
iree_hal_hal
iree_hal_cuda_cuda
iree_hal_cuda_registration_registration
iree_hal_vmvx_registration_registration
iree_hal_dylib_registration_registration
iree_modules_hal_hal
iree_vm_vm
iree_vm_bytecode_module
iree_hal_local_loaders_system_library_loader
iree_hal_local_loaders_vmvx_module_loader
)
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
target_link_libraries(
triton-dshark-backend
PRIVATE
triton-core-serverapi # from repo-core
triton-core-backendapi # from repo-core
triton-core-serverstub # from repo-core
triton-backend-utils # from repo-backend
)
if(WIN32)
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
)
else()
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
)
endif()
#
# Install
#
include(GNUInstallDirs)
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
install(
TARGETS
triton-dshark-backend
EXPORT
triton-dshark-backend-targets
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
)
install(
EXPORT
triton-dshark-backend-targets
FILE
SharkBackendTargets.cmake
NAMESPACE
SharkBackend::
DESTINATION
${INSTALL_CONFIGDIR}
)
include(CMakePackageConfigHelpers)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
)
install(
FILES
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
DESTINATION ${INSTALL_CONFIGDIR}
)
#
# Export from build tree
#
export(
EXPORT triton-dshark-backend-targets
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
NAMESPACE SharkBackend::
)
export(PACKAGE SharkBackend)

View File

@@ -1,100 +0,0 @@
# SHARK Triton Backend
The triton backend for shark.
# Build
Install SHARK
```
git clone https://github.com/nod-ai/SHARK.git
# skip above step if dshark is already installed
cd SHARK/inference
```
install dependancies
```
apt-get install patchelf rapidjson-dev python3-dev
git submodule update --init
```
update the submodules of iree
```
cd thirdparty/shark-runtime
git submodule update --init
```
Next, make the backend and install it
```
cd ../..
mkdir build && cd build
cmake -DTRITON_ENABLE_GPU=ON \
-DIREE_HAL_DRIVER_CUDA=ON \
-DIREE_TARGET_BACKEND_CUDA=ON \
-DMLIR_ENABLE_CUDA_RUNNER=ON \
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
-DTRITON_BACKEND_REPO_TAG=r22.02 \
-DTRITON_CORE_REPO_TAG=r22.02 \
-DTRITON_COMMON_REPO_TAG=r22.02 ..
make install
```
# Incorporating into Triton
There are much more in depth explenations for the following steps in triton's documentation:
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
There should be a file at /build/install/backends/dshark/libtriton_dshark.so. You will need to copy it into your triton server image.
More documentation is in the link above, but to create the docker image, you need to run the compose.py command in the triton-backend server repo
To first build your image, clone the tritonserver repo.
```
git clone https://github.com/triton-inference-server/server.git
```
then run `compose.py` to build a docker compose file
```
cd server
python3 compose.py --repoagent checksum --dry-run
```
Because dshark is a third party backend, you will need to manually modify the `Dockerfile.compose` to include the dshark backend. To do this, in the Dockerfile.compose file produced, copy this line.
the dshark backend will be located in the build folder from earlier under `/build/install/backends`
```
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
```
Next run
```
docker build -t tritonserver_custom -f Dockerfile.compose .
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
where `path/to/model_repos` is where you are storing the models you want to run
if your not using gpus, omit `--gpus=1`
```
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
# Setting up a model
to include a model in your backend, add a directory with your model name to your model repository directory. examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
make sure to adjust the input correctly in the config.pbtxt file, and save a vmfb file under 1/model.vmfb
# CUDA
if you're having issues with cuda, make sure your correct drivers are installed, and that `nvidia-smi` works, and also make sure that the nvcc compiler is on the path.

View File

@@ -1,39 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include(CMakeFindDependencyMacro)
get_filename_component(
SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
)
list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})
if(NOT TARGET SharkBackend::triton-dshark-backend)
include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
endif()
set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)

File diff suppressed because it is too large Load Diff

View File

@@ -1,30 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
{
global:
TRITONBACKEND_*;
local: *;
};

View File

@@ -6,15 +6,15 @@ from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Temorary workaround for transformers/__init__.py.
path_to_tranformers_hook = Path(
# Temporary workaround for transformers/__init__.py.
path_to_transformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
if path_to_tranformers_hook.is_file():
if path_to_transformers_hook.is_file():
pass
else:
with open(path_to_tranformers_hook, "w") as f:
with open(path_to_transformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")

View File

@@ -5,7 +5,7 @@ requires = [
"packaging",
"numpy>=1.22.4",
"torch-mlir>=20221021.633",
"torch-mlir>=20230620.875",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
]

View File

@@ -8,19 +8,8 @@ torchvision
tqdm
#iree-compiler | iree-runtime should already be installed
#these dont work ok osx
#iree-tools-tflite
#iree-tools-xla
#iree-tools-tf
# TensorFlow and JAX.
gin-config
tensorflow-macos
tensorflow-metal
#tf-models-nightly
#tensorflow-text-nightly
transformers
tensorflow-probability
#jax[cpu]
# tflitehub dependencies.

View File

@@ -3,29 +3,19 @@
numpy>1.22.4
pytorch-triton
torchvision==0.16.0.dev20230322
torchvision
tabulate
tqdm
#iree-compiler | iree-runtime should already be installed
iree-tools-tflite
iree-tools-xla
iree-tools-tf
# TensorFlow and JAX.
# Modelling and JAX.
gin-config
tensorflow>2.11
keras
#tf-models-nightly
#tensorflow-text-nightly
transformers
diffusers
#tensorflow-probability
#jax[cpu]
# tflitehub dependencies.
Pillow
# Testing and support.

View File

@@ -1,3 +1,6 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
setuptools
wheel
@@ -15,16 +18,18 @@ Pillow
parameterized
# Add transformers, diffusers and scipy since it most commonly used
tokenizers==0.13.3
transformers
diffusers
#accelerate is now required for diffusers import from ckpt.
accelerate
scipy
ftfy
gradio
gradio==3.44.3
altair
omegaconf
safetensors
# 0.3.2 doesn't have binaries for arm64
safetensors==0.3.1
opencv-python
scikit-image
pytorch_lightning # for runwayml models
@@ -35,10 +40,11 @@ py-cpuinfo
tiktoken # for codegen
joblib # for langchain
timm # for MiniGPT4
langchain
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@dev
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea

View File

@@ -90,8 +90,8 @@ python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
Write-Host "Build and installation completed successfully"
Write-Host "Source your venv with ./shark.venv/Scripts/activate"

View File

@@ -86,6 +86,7 @@ $PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
if [ "$torch_mlir_bin" = true ]; then
if [[ $(uname -s) = 'Darwin' ]]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
@@ -103,7 +104,7 @@ else
fi
if [[ -z "${USE_IREE}" ]]; then
rm .use-iree
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
RUNTIME="https://nod-ai.github.io/SRT/pip-release-links.html"
else
touch ./.use-iree
RUNTIME="https://openxla.github.io/iree/pip-release-links.html"
@@ -128,16 +129,21 @@ if [[ ! -z "${IMPORTER}" ]]; then
fi
fi
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
if [[ $(uname -s) = 'Darwin' ]]; then
PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/
else
PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/
fi
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL}
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)
TORCH_VERSION=${T_VER:9:17}
T_VER_MIN=${T_VER:14:12}
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
TV_VER_MAJ=${TV_VER:9:6}
$PYTHON -m pip uninstall -y torchvision
$PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu118."
else
@@ -145,14 +151,8 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
fi
fi
if [[ ! -z "${ONNX}" ]]; then
echo "${Yellow}Installing ONNX and onnxruntime for benchmarks..."
$PYTHON -m pip install onnx onnxruntime psutil
if [ $? -eq 0 ];then
echo "Successfully installed ONNX and ONNX runtime."
else
echo "Could not install ONNX." >&2
fi
if [[ -z "${NO_BREVITAS}" ]]; then
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev
fi
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then

View File

@@ -43,9 +43,7 @@ if __name__ == "__main__":
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=False, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, mlir_dialect="linalg"
)
shark_module = SharkInference(minilm_mlir)
shark_module.compile()
token_logits = torch.tensor(shark_module.forward(inputs))
mask_id = torch.where(

View File

@@ -1,325 +0,0 @@
import torch
from torch.nn.utils import stateless
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_trainer import SharkTrainer
import argparse
import sys
import numpy as np
import torch.nn as nn
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
import torch_mlir
import extend_distributed as ext_dist
### define dlrm in PyTorch ###
class DLRM_Net(nn.Module):
def create_mlp(self, ln, sigmoid_layer):
# build MLP layer by layer
layers = nn.ModuleList()
for i in range(0, ln.size - 1):
n = ln[i]
m = ln[i + 1]
# construct fully connected operator
LL = nn.Linear(int(n), int(m), bias=True)
# initialize the weights
# with torch.no_grad():
# custom Xavier input, output or two-sided fill
mean = 0.0 # std_dev = np.sqrt(variance)
std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n)
W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1))
bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
LL.weight.data = torch.tensor(W, requires_grad=True)
LL.bias.data = torch.tensor(bt, requires_grad=True)
# approach 2
# LL.weight.data.copy_(torch.tensor(W))
# LL.bias.data.copy_(torch.tensor(bt))
# approach 3
# LL.weight = Parameter(torch.tensor(W),requires_grad=True)
# LL.bias = Parameter(torch.tensor(bt),requires_grad=True)
layers.append(LL)
# construct sigmoid or relu operator
if i == sigmoid_layer:
layers.append(nn.Sigmoid())
else:
layers.append(nn.ReLU())
# approach 1: use ModuleList
# return layers
# approach 2: use Sequential container to wrap all layers
return torch.nn.Sequential(*layers)
def create_emb(self, m, ln, weighted_pooling=None):
emb_l = nn.ModuleList()
v_W_l = []
for i in range(0, ln.size):
n = ln[i]
# construct embedding operator
EE = nn.EmbeddingBag(n, m, mode="sum")
# initialize embeddings
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
W = np.random.uniform(
low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m)
).astype(np.float32)
# approach 1
print(W)
EE.weight.data = torch.tensor(W, requires_grad=True)
# approach 2
# EE.weight.data.copy_(torch.tensor(W))
# approach 3
# EE.weight = Parameter(torch.tensor(W),requires_grad=True)
if weighted_pooling is None:
v_W_l.append(None)
else:
v_W_l.append(torch.ones(n, dtype=torch.float32))
emb_l.append(EE)
return emb_l, v_W_l
def __init__(
self,
m_spa=None,
ln_emb=None,
ln_bot=None,
ln_top=None,
arch_interaction_op=None,
arch_interaction_itself=False,
sigmoid_bot=-1,
sigmoid_top=-1,
weighted_pooling=None,
):
super(DLRM_Net, self).__init__()
if (
(m_spa is not None)
and (ln_emb is not None)
and (ln_bot is not None)
and (ln_top is not None)
and (arch_interaction_op is not None)
):
# save arguments
self.output_d = 0
self.arch_interaction_op = arch_interaction_op
self.arch_interaction_itself = arch_interaction_itself
if weighted_pooling is not None and weighted_pooling != "fixed":
self.weighted_pooling = "learned"
else:
self.weighted_pooling = weighted_pooling
# create operators
self.emb_l, w_list = self.create_emb(
m_spa, ln_emb, weighted_pooling
)
if self.weighted_pooling == "learned":
self.v_W_l = nn.ParameterList()
for w in w_list:
self.v_W_l.append(nn.Parameter(w))
else:
self.v_W_l = w_list
self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
self.top_l = self.create_mlp(ln_top, sigmoid_top)
def apply_mlp(self, x, layers):
return layers(x)
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
# WARNING: notice that we are processing the batch at once. We implicitly
# assume that the data is laid out such that:
# 1. each embedding is indexed with a group of sparse indices,
# corresponding to a single lookup
# 2. for each embedding the lookups are further organized into a batch
# 3. for a list of embedding tables there is a list of batched lookups
# TORCH-MLIR
# We are passing all the embeddings as arguments for easy parsing.
ly = []
for k, sparse_index_group_batch in enumerate(lS_i):
sparse_offset_group_batch = lS_o[k]
# embedding lookup
# We are using EmbeddingBag, which implicitly uses sum operator.
# The embeddings are represented as tall matrices, with sum
# happening vertically across 0 axis, resulting in a row vector
# E = emb_l[k]
if v_W_l[k] is not None:
per_sample_weights = v_W_l[k].gather(
0, sparse_index_group_batch
)
else:
per_sample_weights = None
E = emb_l[k]
V = E(
sparse_index_group_batch,
sparse_offset_group_batch,
per_sample_weights=per_sample_weights,
)
ly.append(V)
return ly
def interact_features(self, x, ly):
if self.arch_interaction_op == "dot":
# concatenate dense and sparse features
(batch_size, d) = x.shape
T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
# perform a dot product
Z = torch.bmm(T, torch.transpose(T, 1, 2))
# append dense feature with the interactions (into a row vector)
# approach 1: all
# Zflat = Z.view((batch_size, -1))
# approach 2: unique
_, ni, nj = Z.shape
# approach 1: tril_indices
# offset = 0 if self.arch_interaction_itself else -1
# li, lj = torch.tril_indices(ni, nj, offset=offset)
# approach 2: custom
offset = 1 if self.arch_interaction_itself else 0
li = torch.tensor(
[i for i in range(ni) for j in range(i + offset)]
)
lj = torch.tensor(
[j for i in range(nj) for j in range(i + offset)]
)
Zflat = Z[:, li, lj]
# concatenate dense features and interactions
R = torch.cat([x] + [Zflat], dim=1)
elif self.arch_interaction_op == "cat":
# concatenation features (into a row vector)
R = torch.cat([x] + ly, dim=1)
else:
sys.exit(
"ERROR: --arch-interaction-op="
+ self.arch_interaction_op
+ " is not supported"
)
return R
def forward(self, dense_x, lS_o, *lS_i):
return self.sequential_forward(dense_x, lS_o, lS_i)
def sequential_forward(self, dense_x, lS_o, lS_i):
# process dense features (using bottom mlp), resulting in a row vector
x = self.apply_mlp(dense_x, self.bot_l)
# debug prints
# print("intermediate")
# print(x.detach().cpu().numpy())
# process sparse features(using embeddings), resulting in a list of row vectors
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
# for y in ly:
# print(y.detach().cpu().numpy())
# interact features (dense and sparse)
z = self.interact_features(x, ly)
# print(z.detach().cpu().numpy())
# obtain probability of a click (using top mlp)
p = self.apply_mlp(z, self.top_l)
# # clamp output if needed
# if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
# z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
# else:
# z = p
return p
def dash_separated_ints(value):
vals = value.split("-")
for val in vals:
try:
int(val)
except ValueError:
raise argparse.ArgumentTypeError(
"%s is not a valid dash separated list of ints" % value
)
return value
# model related parameters
parser = argparse.ArgumentParser(
description="Train Deep Learning Recommendation Model (DLRM)"
)
parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
parser.add_argument(
"--arch-embedding-size", type=dash_separated_ints, default="4-3-2"
)
# j will be replaced with the table number
parser.add_argument(
"--arch-mlp-bot", type=dash_separated_ints, default="4-3-2"
)
parser.add_argument(
"--arch-mlp-top", type=dash_separated_ints, default="8-2-1"
)
parser.add_argument(
"--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot"
)
parser.add_argument(
"--arch-interaction-itself", action="store_true", default=False
)
parser.add_argument("--weighted-pooling", type=str, default=None)
args = parser.parse_args()
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
ln_top = np.fromstring(args.arch_mlp_top, dtype=int, sep="-")
m_den = ln_bot[0]
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
m_spa = args.arch_sparse_feature_size
ln_emb = np.asarray(ln_emb)
num_fea = ln_emb.size + 1 # num sparse + num dense features
# Initialize the model.
dlrm_model = DLRM_Net(
m_spa=m_spa,
ln_emb=ln_emb,
ln_bot=ln_bot,
ln_top=ln_top,
arch_interaction_op=args.arch_interaction_op,
)
def get_sorted_params(named_params):
return [i[1] for i in sorted(named_params.items())]
dense_inp = torch.tensor([[0.6965, 0.2861, 0.2269, 0.5513]])
vs0 = torch.tensor([[0], [0], [0]], dtype=torch.int64)
vsi = torch.tensor([1, 2, 3]), torch.tensor([1]), torch.tensor([1])
input_dlrm = (dense_inp, vs0, *vsi)
mlir_importer = SharkImporter(
dlrm_model,
input_dlrm,
frontend="torch",
)
dlrm_mlir = torch_mlir.compile(dlrm_model, input_dlrm, torch_mlir.OutputType.LINALG_ON_TENSORS, use_tracing=True)
print(dlrm_mlir)
def forward(params, buffers, args):
params_and_buffers = {**params, **buffers}
stateless.functional_call(
dlrm_model, params_and_buffers, args, {}
).sum().backward()
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
# optim.load_state_dict(optim_state)
optim.step()
return params, buffers
shark_module = SharkTrainer(dlrm_model, input_dlrm)
print("________________________________________________________________________")
shark_module.compile(forward)
print("________________________________________________________________________")
shark_module.train(num_iters=2)
print("training done")

View File

@@ -13,7 +13,7 @@
# limitations under the License.
## Common utilities to be shared by iree utilities.
import functools
import os
import sys
import subprocess
@@ -52,6 +52,8 @@ def iree_device_map(device):
)
if len(uri_parts) == 1:
return iree_driver
elif "rocm" in uri_parts:
return "rocm"
else:
return f"{iree_driver}://{uri_parts[1]}"
@@ -63,7 +65,6 @@ def get_supported_device_list():
_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"AMD-AIE": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
@@ -82,7 +83,6 @@ def iree_target_map(device):
_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
"AMD-AIE": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
@@ -93,6 +93,7 @@ _IREE_TARGET_MAP = {
# Finds whether the required drivers are installed for the given device.
@functools.cache
def check_device_drivers(device):
"""Checks necessary drivers present for gpu and vulkan devices"""
if "://" in device:
@@ -120,7 +121,10 @@ def check_device_drivers(device):
return False
elif device == "rocm":
try:
subprocess.check_output("rocminfo")
if sys.platform == "win32":
subprocess.check_output("hipinfo")
else:
subprocess.check_output("rocminfo")
except Exception:
return True

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
from shark.iree_utils._common import run_cmd, iree_device_map
from shark.iree_utils.cpu_utils import get_cpu_count
import numpy as np
@@ -62,16 +61,12 @@ def build_benchmark_args(
and whether it is training or not.
Outputs: string that execute benchmark-module on target model.
"""
path = benchmark_module.__path__[0]
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
if platform.system() == "Windows":
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module.exe"
)
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
time_extractor = None
else:
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module"
)
benchmarker_path = os.path.join(path, "iree-benchmark-module")
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
# TODO: The function named can be passed as one of the args.
@@ -106,15 +101,13 @@ def build_benchmark_args_non_tensor_input(
and whether it is training or not.
Outputs: string that execute benchmark-module on target model.
"""
path = benchmark_module.__path__[0]
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
if platform.system() == "Windows":
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module.exe"
)
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
time_extractor = None
else:
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module"
)
benchmarker_path = os.path.join(path, "iree-benchmark-module")
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
# TODO: The function named can be passed as one of the args.
if function_name:
@@ -139,7 +132,7 @@ def run_benchmark_module(benchmark_cl):
benchmark_path = benchmark_cl[0]
assert os.path.exists(
benchmark_path
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
), "Cannot find iree_benchmark_module, Please contact SHARK maintainer on discord."
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
try:
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")

View File

@@ -11,18 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import functools
import numpy as np
import os
import re
import tempfile
import time
from pathlib import Path
import iree.runtime as ireert
import iree.compiler as ireec
from shark.parser import shark_args
from .trace import DetailLogger
from ._common import iree_device_map, iree_target_map
from .cpu_utils import get_iree_cpu_rt_args
from .benchmark_utils import *
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
@@ -41,7 +46,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
data_tiling_flag = ["--iree-opt-data-tiling"]
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
@@ -79,7 +84,7 @@ def get_iree_frontend_args(frontend):
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
return [
"--iree-llvmcpu-target-cpu-features=host",
"--iree-flow-demote-i64-to-i32",
"--iree-input-demote-i64-to-i32",
]
else:
# Frontend not found.
@@ -87,13 +92,27 @@ def get_iree_frontend_args(frontend):
# Common args to be used given any frontend or device.
def get_iree_common_args():
return [
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
def get_iree_common_args(debug=False):
common_args = [
"--iree-stream-resource-max-allocation-size=4294967295",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
]
if debug == True:
common_args.extend(
[
"--iree-opt-strip-assertions=false",
"--verify=true",
]
)
else:
common_args.extend(
[
"--iree-opt-strip-assertions=true",
"--verify=false",
]
)
return common_args
# Args that are suitable only for certain models or groups of models.
@@ -272,14 +291,17 @@ def compile_module_to_flatbuffer(
model_config_path,
extra_args,
model_name="None",
debug=False,
compile_str=False,
):
# Setup Compile arguments wrt to frontends.
input_type = ""
input_type = "auto"
args = get_iree_frontend_args(frontend)
args += get_iree_device_args(device, extra_args)
args += get_iree_common_args()
args += get_iree_common_args(debug=debug)
args += get_model_specific_args()
args += extra_args
args += shark_args.additional_compile_args
if frontend in ["tensorflow", "tf"]:
input_type = "auto"
@@ -290,10 +312,7 @@ def compile_module_to_flatbuffer(
elif frontend in ["tm_tensor"]:
input_type = ireec.InputType.TM_TENSOR
# TODO: make it simpler.
# Compile according to the input type, else just try compiling.
if input_type != "":
# Currently for MHLO/TOSA.
if compile_str:
flatbuffer_blob = ireec.compile_str(
module,
target_backends=[iree_target_map(device)],
@@ -301,9 +320,10 @@ def compile_module_to_flatbuffer(
input_type=input_type,
)
else:
# Currently for Torch.
flatbuffer_blob = ireec.compile_str(
assert os.path.isfile(module)
flatbuffer_blob = ireec.compile_file(
module,
input_type=input_type,
target_backends=[iree_target_map(device)],
extra_args=args,
)
@@ -317,7 +337,6 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
@@ -337,58 +356,70 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
):
instance = ireert.VmInstance()
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=[],
)
# First get configs.
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
print(f"Loading module {flatbuffer_blob_or_path}...")
if "rocm" in device:
device = "rocm"
with DetailLogger(timeout=2.5) as dl:
# First get configs.
if device_idx is not None:
dl.log(f"Mapping device id: {device_idx}")
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
dl.log(f"ireert.get_driver()")
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(mmaped_vmfb)
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
return mmaped_vmfb, config, temp_file_to_unlink
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)
dl.log("get_iree_runtime_config")
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(
config.vm_instance, flatbuffer_blob_or_path
)
dl.log(f"mmap {flatbuffer_blob_or_path}")
ctx = ireert.SystemContext(config=config)
dl.log(f"ireert.SystemContext created")
if "vulkan" in device:
# Vulkan pipeline creation consumes significant amount of time.
print(
"\tCompiling Vulkan shaders. This may take a few minutes."
)
ctx.add_vm_module(mmaped_vmfb)
dl.log(f"module initialized")
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
dl.log(f"mmap temp {vmfb_file_path}")
return mmaped_vmfb, config, temp_file_to_unlink
def get_iree_compiled_module(
@@ -399,10 +430,18 @@ def get_iree_compiled_module(
extra_args: list = [],
device_idx: int = None,
mmap: bool = False,
debug: bool = False,
compile_str: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args
module,
device,
frontend,
model_config_path,
extra_args,
debug,
compile_str,
)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
@@ -410,7 +449,6 @@ def get_iree_compiled_module(
# we're setting delete=False when creating NamedTemporaryFile. That's why
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
if mmap:
print(f"Will load the compiled module as a mmapped temporary file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx
)
@@ -434,7 +472,6 @@ def load_flatbuffer(
):
temp_file_to_unlink = None
if mmap:
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_path, device, device_idx
)
@@ -460,10 +497,18 @@ def export_iree_module_to_vmfb(
model_config_path: str = None,
module_name: str = None,
extra_args: list = [],
debug: bool = False,
compile_str: bool = False,
):
# Compiles the module given specs and saves it as .vmfb file.
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, mlir_dialect, model_config_path, extra_args
module,
device,
mlir_dialect,
model_config_path,
extra_args,
debug,
compile_str,
)
if module_name is None:
device_name = (
@@ -471,9 +516,9 @@ def export_iree_module_to_vmfb(
)
module_name = f"{mlir_dialect}_{device_name}"
filename = os.path.join(directory, module_name + ".vmfb")
print(f"Saved vmfb in {filename}.")
with open(filename, "wb") as f:
f.write(flatbuffer_blob)
print(f"Saved vmfb in {filename}.")
return filename
@@ -498,37 +543,56 @@ def get_results(
config,
frontend="torch",
send_to_host=True,
debug_timeout: float = 5.0,
):
"""Runs a .vmfb file given inputs and config and returns output."""
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm[function_name](*device_inputs)
result_tensors = []
if isinstance(result, tuple):
if send_to_host:
for val in result:
result_tensors.append(np.asarray(val, val.dtype))
with DetailLogger(debug_timeout) as dl:
device_inputs = []
for input_array in input:
dl.log(f"Load to device: {input_array.shape}")
device_inputs.append(
ireert.asdevicearray(config.device, input_array)
)
dl.log(f"Invoke function: {function_name}")
result = compiled_vm[function_name](*device_inputs)
dl.log(f"Invoke complete")
result_tensors = []
if isinstance(result, tuple):
if send_to_host:
for val in result:
dl.log(f"Result to host: {val.shape}")
result_tensors.append(np.asarray(val, val.dtype))
else:
for val in result:
result_tensors.append(val)
return result_tensors
elif isinstance(result, dict):
data = list(result.items())
if send_to_host:
res = np.array(data, dtype=object)
return np.copy(res)
return data
else:
for val in result:
result_tensors.append(val)
return result_tensors
elif isinstance(result, dict):
data = list(result.items())
if send_to_host:
res = np.array(data, dtype=object)
return np.copy(res)
return data
else:
if send_to_host and result is not None:
return result.to_host()
return result
if send_to_host and result is not None:
dl.log("Result to host")
return result.to_host()
return result
dl.log("Execution complete")
@functools.cache
def get_iree_runtime_config(device):
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
if device == "metal" and shark_args.device_allocator == "caching":
print(
"[WARNING] metal devices can not have a `caching` allocator."
"\nUsing default allocator `None`"
)
haldevice = haldriver.create_device_by_uri(
device,
allocators=shark_args.device_allocator,
# metal devices have a failure with caching allocators atm. blcking this util it gets fixed upstream.
allocators=shark_args.device_allocator if device != "metal" else None,
)
config = ireert.Config(device=haldevice)
return config

View File

@@ -14,6 +14,7 @@
# All the iree_cpu related functionalities go here.
import functools
import subprocess
import platform
from shark.parser import shark_args
@@ -30,6 +31,7 @@ def get_cpu_count():
# Get the default cpu args.
@functools.cache
def get_iree_cpu_args():
uname = platform.uname()
os_name, proc_name = uname.system, uname.machine
@@ -51,6 +53,7 @@ def get_iree_cpu_args():
# Get iree runtime flags for cpu
@functools.cache
def get_iree_cpu_rt_args():
default = get_cpu_count()
default = default if default <= 8 else default - 2

View File

@@ -14,12 +14,15 @@
# All the iree_gpu related functionalities go here.
import functools
import iree.runtime as ireert
import ctypes
import sys
from shark.parser import shark_args
# Get the default gpu args given the architecture.
@functools.cache
def get_iree_gpu_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
ireert.flags.parse_flags("--cuda_allow_inline_execution")
@@ -37,23 +40,54 @@ def get_iree_gpu_args():
# Get the default gpu args given the architecture.
@functools.cache
def get_iree_rocm_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
# get arch from rocminfo.
# get arch from hipinfo.
import os
import re
import subprocess
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"rocminfo | grep -i 'gfx'", shell=True, text=True
),
).group(1)
print(f"Found rocm arch {rocm_arch}...")
if sys.platform == "win32":
if "HIP_PATH" in os.environ:
rocm_path = os.environ["HIP_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
rocm_path = "C:\\AMD\\ROCM\\5.5"
else:
if "ROCM_PATH" in os.environ:
rocm_path = os.environ["ROCM_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
rocm_path = "/opt/rocm/"
try:
if sys.platform == "win32":
rocm_arch = re.search(
r"gfx\d{3,}",
subprocess.check_output("hipinfo", shell=True, text=True),
).group(0)
else:
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"rocminfo | grep -i 'gfx'", shell=True, text=True
),
).group(1)
print(f"Found rocm arch {rocm_arch}...")
except:
print(
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
)
rocm_arch = "gfx1100"
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
return [
f"--iree-rocm-target-chip={rocm_arch}",
"--iree-rocm-link-bc=true",
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
f"--iree-rocm-bc-dir={bc_path}",
]
@@ -65,6 +99,7 @@ CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
@functools.cache
def get_cuda_sm_cc():
libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
for libname in libnames:

View File

@@ -14,12 +14,15 @@
# All the iree_vulkan related functionalities go here.
import functools
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
@functools.cache
def get_metal_device_name(device_num=0):
iree_device_dump = run_cmd("iree-run-module --dump_devices")
iree_device_dump = iree_device_dump[0].split("\n\n")
@@ -86,24 +89,10 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
def get_iree_metal_args(device_num=0, extra_args=[]):
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
# Add any metal spefic compilation flags here
res_metal_flag = []
metal_triple_flag = None
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
metal_triple_flag = arg
break
if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)
if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(
"-iree-vulkan-target-triple=m1-moltenvk-macos"
)
res_metal_flag.append(vulkan_target_env)
if len(extra_args) > 0:
res_metal_flag.extend(extra_args)
return res_metal_flag

76
shark/iree_utils/trace.py Normal file
View File

@@ -0,0 +1,76 @@
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import os
import threading
import time
def _enable_detail_trace() -> bool:
return os.getenv("SHARK_DETAIL_TRACE", "0") == "1"
class DetailLogger:
"""Context manager which can accumulate detailed log messages.
Detailed log is only emitted if the operation takes a long time
or errors.
"""
def __init__(self, timeout: float):
self._timeout = timeout
self._messages: List[Tuple[float, str]] = []
self._start_time = time.time()
self._active = not _enable_detail_trace()
self._lock = threading.RLock()
self._cond = threading.Condition(self._lock)
self._thread = None
def __enter__(self):
self._thread = threading.Thread(target=self._run)
self._thread.start()
return self
def __exit__(self, type, value, traceback):
with self._lock:
self._active = False
self._cond.notify()
if traceback:
self.dump_on_error(f"exception")
def _run(self):
with self._lock:
timed_out = not self._cond.wait(self._timeout)
if timed_out:
self.dump_on_error(f"took longer than {self._timeout}s")
def log(self, msg):
with self._lock:
timestamp = time.time()
if self._active:
self._messages.append((timestamp, msg))
else:
print(f" +{(timestamp - self._start_time) * 1000}ms: {msg}")
def dump_on_error(self, summary: str):
with self._lock:
if self._active:
print(f"::: Detailed report ({summary}):")
for timestamp, msg in self._messages:
print(
f" +{(timestamp - self._start_time) * 1000}ms: {msg}"
)
self._active = False

View File

@@ -13,8 +13,10 @@
# limitations under the License.
from collections import OrderedDict
import functools
@functools.cache
def get_vulkan_target_env(vulkan_target_triple):
arch, product, os = vulkan_target_triple.split("=")[1].split("-")
triple = (arch, product, os)
@@ -52,13 +54,11 @@ def get_version(triple):
return "v1.3"
@functools.cache
def get_extensions(triple):
def make_ext_list(ext_list):
res = ""
for e in ext_list:
res += e + ", "
res = f"[{res[:-2]}]"
return res
res = ", ".join(ext_list)
return f"[{res}]"
arch, product, os = triple
if arch == "m1":
@@ -116,12 +116,13 @@ def get_extensions(triple):
]
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
ext.append("VK_NV_cooperative_matrix")
ext.append("VK_KHR_cooperative_matrix")
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
ext.append("VK_KHR_shader_integer_dot_product")
return make_ext_list(ext_list=ext)
@functools.cache
def get_vendor(triple):
arch, product, os = triple
if arch == "unknown":
@@ -146,6 +147,7 @@ def get_vendor(triple):
return "Unknown"
@functools.cache
def get_device_type(triple):
arch, product, _ = triple
if arch == "unknown":
@@ -166,6 +168,7 @@ def get_device_type(triple):
# get all the capabilities for the device
# TODO: make a dataclass for capabilites and init using vulkaninfo
@functools.cache
def get_vulkan_target_capabilities(triple):
def get_subgroup_val(l):
return int(sum([subgroup_feature[sgf] for sgf in l]))
@@ -241,7 +244,7 @@ def get_vulkan_target_capabilities(triple):
if arch == "rdna3":
# TODO: Get scope value
cap["coopmatCases"] = [
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
]
if product == "rx5700xt":
@@ -462,9 +465,9 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointersStorageBuffer"] = True
cap["coopmatCases"] = [
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
]
elif arch == "adreno":
@@ -525,7 +528,7 @@ def get_vulkan_target_capabilities(triple):
cmc = ""
for case in v:
cmc += f"#vk.coop_matrix_props<{case}>, "
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
else:
res += f"{k} = {get_comma_sep_str(v)}, "
else:

View File

@@ -14,6 +14,7 @@
# All the iree_vulkan related functionalities go here.
import functools
from os import linesep
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
@@ -22,10 +23,19 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
from shark.parser import shark_args
@functools.cache
def get_all_vulkan_devices():
from iree.runtime import get_driver
driver = get_driver("vulkan")
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return [d["name"] for d in device_list_src]
@functools.cache
def get_vulkan_device_name(device_num=0):
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
@@ -48,6 +58,7 @@ def get_os_name():
return "linux"
@functools.cache
def get_vulkan_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.
@@ -108,6 +119,8 @@ def get_vulkan_target_triple(device_name):
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif all(x in device_name for x in ("Radeon", "780M")):
triple = f"rdna3-780m-{system_os}"
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
triple = f"rdna3-w7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
@@ -172,11 +185,10 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
return res_vulkan_flag
@functools.cache
def get_iree_vulkan_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={shark_args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
f"--vulkan_vma_allocator={'true' if shark_args.vulkan_vma_allocator else 'false'}",
]
return vulkan_runtime_flags

View File

@@ -14,8 +14,21 @@
import argparse
import os
import shlex
import subprocess
class SplitStrToListAction(argparse.Action):
def __init__(self, option_strings, dest, *args, **kwargs):
super(SplitStrToListAction, self).__init__(
option_strings=option_strings, dest=dest, *args, **kwargs
)
def __call__(self, parser, namespace, values, option_string=None):
del parser, option_string
setattr(namespace, self.dest, shlex.split(values[0]))
parser = argparse.ArgumentParser(description="SHARK runner.")
parser.add_argument(
@@ -24,6 +37,13 @@ parser.add_argument(
default="cpu",
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
)
parser.add_argument(
"--additional_compile_args",
default=list(),
nargs=1,
action=SplitStrToListAction,
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
)
parser.add_argument(
"--enable_tf32",
type=bool,
@@ -114,7 +134,7 @@ parser.add_argument(
"--device_allocator",
type=str,
nargs="*",
default=[],
default=["caching"],
help="Specifies one or more HAL device allocator specs "
"to augment the base device allocator",
choices=["debug", "caching"],
@@ -133,13 +153,6 @@ parser.add_argument(
help="Profiles vulkan device and collects the .rdc info.",
)
parser.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="Flag for setting VMA preferredLargeHeapBlockSize for "
"vulkan device, default is 4G.",
)
parser.add_argument(
"--vulkan_validation_layers",
default=False,
@@ -147,11 +160,4 @@ parser.add_argument(
help="Flag for disabling vulkan validation layers when benchmarking.",
)
parser.add_argument(
"--vulkan_vma_allocator",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for enabling / disabling Vulkan VMA Allocator.",
)
shark_args, unknown = parser.parse_known_args()

View File

@@ -13,7 +13,11 @@
# limitations under the License.
from shark.shark_runner import SharkRunner
from shark.iree_utils.compile_utils import export_iree_module_to_vmfb
from shark.iree_utils.compile_utils import (
export_iree_module_to_vmfb,
load_flatbuffer,
get_iree_runtime_config,
)
from shark.iree_utils.benchmark_utils import (
build_benchmark_args,
run_benchmark_module,
@@ -79,22 +83,39 @@ class SharkBenchmarkRunner(SharkRunner):
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.import_args = {}
self.temp_file_to_unlink = None
if not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
SharkRunner.__init__(
self,
mlir_module,
device,
self.mlir_dialect,
self.extra_args,
compile_vmfb=True,
compile_vmfb=False,
)
if self.vmfb_file == None:
self.vmfb_file = export_iree_module_to_vmfb(
mlir_module,
device,
".",
self.mlir_dialect,
extra_args=self.extra_args,
)
self.vmfb_file = export_iree_module_to_vmfb(
mlir_module,
device,
".",
self.mlir_dialect,
extra_args=self.extra_args,
compile_str=self.compile_str,
)
params = load_flatbuffer(
self.vmfb_file,
device,
mmap=True,
)
self.iree_compilation_module = params["vmfb"]
self.iree_config = params["config"]
self.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
def setup_cl(self, input_tensors):
self.benchmark_cl = build_benchmark_args(
@@ -111,42 +132,41 @@ class SharkBenchmarkRunner(SharkRunner):
elif self.mlir_dialect in ["mhlo", "tf"]:
return self.benchmark_tf(modelname)
def benchmark_torch(self, modelname):
def benchmark_torch(self, modelname, device="cpu"):
import torch
from tank.model_utils import get_torch_model
if self.device == "cuda":
torch.set_default_tensor_type(torch.cuda.FloatTensor)
if self.enable_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# TODO: Pass this as an arg. currently the best way is to setup with BENCHMARK=1 if we want to use torch+cuda, else use cpu.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
torch.set_default_device("cuda:0")
# if self.enable_tf32:
# torch.backends.cuda.matmul.allow_tf32 = True
else:
torch.set_default_tensor_type(torch.FloatTensor)
torch_device = torch.device(
"cuda:0" if self.device == "cuda" else "cpu"
)
torch.set_default_dtype(torch.float32)
torch.set_default_device("cpu")
torch_device = torch.device("cuda:0" if device == "cuda" else "cpu")
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
frontend_model = HFmodel.model
frontend_model.to(torch_device)
input.to(torch_device)
# TODO: re-enable as soon as pytorch CUDA context issues are resolved
try:
frontend_model = torch.compile(
frontend_model, mode="max-autotune", backend="inductor"
)
except RuntimeError:
frontend_model = HFmodel.model
if device == "cuda":
frontend_model.cuda()
input.to(torch.device("cuda:0"))
print(input)
else:
frontend_model.cpu()
input.cpu()
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(input)
if self.device == "cuda":
if device == "cuda":
torch.cuda.reset_peak_memory_stats()
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(input)
end = time.time()
if self.device == "cuda":
if device == "cuda":
stats = torch.cuda.memory_stats()
device_peak_b = stats["allocated_bytes.all.peak"]
frontend_model.to(torch.device("cpu"))
@@ -158,7 +178,7 @@ class SharkBenchmarkRunner(SharkRunner):
print(
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
if self.device == "cuda":
if device == "cuda":
# Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
torch_device = torch.device("cpu")
return [

View File

@@ -1,7 +1,7 @@
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import torch
import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -11,14 +11,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
def brevitasmatmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -27,30 +21,21 @@ def brevitasmatmul_rhs_group_quant〡shape(
raise ValueError("Input shapes not supported.")
def brevitasmatmul_rhs_group_quant〡dtype(
lhs_rank_dtype: Tuple[int, int],
rhs_rank_dtype: Tuple[int, int],
rhs_scale_rank_dtype: Tuple[int, int],
rhs_zero_point_rank_dtype: Tuple[int, int],
rhs_bit_width: int,
rhs_group_size: int,
) -> int:
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def brevitasmatmul_rhs_group_quant〡has_value_semantics(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics,
]
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
@@ -122,7 +107,7 @@ def compile_int_precision(
torchscript_module,
inputs,
output_type="torch",
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -130,7 +115,7 @@ def compile_int_precision(
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
mlir_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
from contextlib import redirect_stdout
@@ -145,10 +130,17 @@ def compile_int_precision(
mlir_module = mlir_module.encode("UTF-8")
mlir_module = BytesIO(mlir_module)
bytecode = mlir_module.read()
bytecode_path = os.path.join(
os.getcwd(), f"{extended_model_name}_linalg.mlirbc"
)
with open(bytecode_path, "wb") as f:
f.write(bytecode)
del bytecode
del mlir_module
print(f"Elided IR written for {extended_model_name}")
return bytecode
return bytecode_path
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
mlir_module=bytecode_path, device=device, mlir_dialect="tm_tensor"
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",
@@ -163,7 +155,7 @@ def compile_int_precision(
generate_vmfb=generate_vmfb,
extra_args=extra_args,
),
bytecode,
bytecode_path,
)
@@ -216,7 +208,7 @@ def shark_compile_through_fx(
]
else:
(
mlir_module,
bytecode,
_,
) = import_with_fx(
model=model,
@@ -227,6 +219,11 @@ def shark_compile_through_fx(
model_name=extended_model_name,
save_dir=save_dir,
)
mlir_module = save_mlir(
mlir_module=bytecode,
model_name=extended_model_name,
mlir_dialect=mlir_dialect,
)
shark_module = SharkInference(
mlir_module,

View File

@@ -111,22 +111,20 @@ os.makedirs(WORKDIR, exist_ok=True)
def check_dir_exists(model_name, frontend="torch", dynamic=""):
model_dir = os.path.join(WORKDIR, model_name)
# Remove the _tf keyword from end.
if frontend in ["tf", "tensorflow"]:
model_name = model_name[:-3]
elif frontend in ["tflite"]:
model_name = model_name[:-7]
elif frontend in ["torch", "pytorch"]:
model_name = model_name[:-6]
# Remove the _tf keyword from end only for non-SD models.
if not any(model in model_name for model in ["clip", "unet", "vae"]):
if frontend in ["tf", "tensorflow"]:
model_name = model_name[:-3]
elif frontend in ["tflite"]:
model_name = model_name[:-7]
elif frontend in ["torch", "pytorch"]:
model_name = model_name[:-6]
model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir"
if os.path.isdir(model_dir):
if (
os.path.isfile(
os.path.join(
model_dir,
model_name + dynamic + "_" + str(frontend) + ".mlir",
)
)
os.path.isfile(os.path.join(model_dir, model_mlir_file_name))
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
@@ -277,11 +275,11 @@ def download_model(
model_dir = os.path.join(WORKDIR, model_dir_name)
tuned_str = "" if tuned is None else "_" + tuned
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
filename = os.path.join(model_dir, model_name + suffix)
mlir_filename = os.path.join(model_dir, model_name + suffix)
print(
f"Verifying that model artifacts were downloaded successfully to {filename}..."
f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
)
if not os.path.exists(filename):
if not os.path.exists(mlir_filename):
from tank.generate_sharktank import gen_shark_files
print(
@@ -289,13 +287,11 @@ def download_model(
)
gen_shark_files(model_name, frontend, WORKDIR, import_args)
assert os.path.exists(filename), f"MLIR not found at {filename}"
with open(filename, mode="rb") as f:
mlir_file = f.read()
assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
inputs_tuple = tuple([inputs[key] for key in inputs])
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
return mlir_file, function_name, inputs_tuple, golden_out_tuple
return mlir_filename, function_name, inputs_tuple, golden_out_tuple

View File

@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Tuple
from collections import defaultdict
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import torchvision.models as models
import copy
import io
@@ -20,10 +20,16 @@ def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
bytecode_path = save_mlir(
bytecode,
model_name="shark_eager_module",
frontend="torch",
mlir_dialect="tm_tensor",
)
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode,
mlir_module=bytecode_path,
device=device,
mlir_dialect="tm_tensor",
)

View File

@@ -1,8 +1,10 @@
import re
import json
import numpy as np
import torch_mlir
from iree.compiler import compile_str
from shark.shark_importer import import_with_fx, get_f16_inputs
from iree.compiler import compile_file
from shark.shark_importer import import_with_fx, get_f16_inputs, save_mlir
class GenerateConfigFile:
@@ -11,6 +13,7 @@ class GenerateConfigFile:
model,
num_sharding_stages: int,
sharding_stages_id: list[str],
units_in_each_stage: list[int],
model_input=None,
config_file_path="model_config.json",
):
@@ -22,13 +25,16 @@ class GenerateConfigFile:
), "Number of sharding stages should be equal to the list of their ID"
self.model_input = model_input
self.config_file_path = config_file_path
# (Nithin) this is a quick fix - revisit and rewrite
self.units_in_each_stage = np.array(units_in_each_stage)
self.track_loop = np.zeros(len(self.sharding_stages_id)).astype(int)
def split_into_dispatches(
self,
backend,
fx_tracing_required=True,
fx_tracing_required=False,
f16_model=False,
torch_mlir_tracing=False,
torch_mlir_tracing=True,
):
graph_for_compilation = self.model
if fx_tracing_required:
@@ -48,9 +54,15 @@ class GenerateConfigFile:
verbose=False,
)
module = module.operation.get_asm(large_elements_limit=4)
module_file = save_mlir(
module,
model_name="module_pre_split",
frontend="torch",
mlir_dialect="linalg",
)
compiled_module_str = str(
compile_str(
str(module),
compile_file(
module_file,
target_backends=[backend],
extra_args=[
"--compile-to=flow",
@@ -95,7 +107,17 @@ class GenerateConfigFile:
if substring_before_final_period in model_dictionary:
del model_dictionary[substring_before_final_period]
layer_dict = {n: "None" for n in self.sharding_stages_id}
# layer_dict = {n: "None" for n in self.sharding_stages_id}
# By default embed increasing device id's for each layer
increasing_wraparound_idx_list = (
self.track_loop % self.units_in_each_stage
)
layer_dict = {
n: int(increasing_wraparound_idx_list[idx][0][0])
for idx, n in enumerate(self.sharding_stages_id)
}
self.track_loop += 1
model_dictionary[name] = layer_dict
self.generate_json(model_dictionary)
@@ -103,3 +125,29 @@ class GenerateConfigFile:
def generate_json(self, artifacts):
with open(self.config_file_path, "w") as outfile:
json.dump(artifacts, outfile)
if __name__ == "__main__":
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna7B,
CombinedModel,
)
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()

View File

@@ -451,6 +451,108 @@ def transform_fx(fx_g, quantized=False):
fx_g.graph.lint()
def gptq_transforms(fx_g):
import torch
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.ones,
torch.ops.aten._to_copy,
]:
if node.kwargs.get("device") == torch.device(device="cuda:0"):
updated_kwargs = node.kwargs.copy()
updated_kwargs["device"] = torch.device(device="cpu")
node.kwargs = updated_kwargs
if node.target in [
torch.ops.aten._to_copy,
]:
if node.kwargs.get("dtype") == torch.bfloat16:
updated_kwargs = node.kwargs.copy()
updated_kwargs["dtype"] = torch.float16
node.kwargs = updated_kwargs
# Inputs of aten.native_layer_norm should be upcasted to fp32.
if node.target in [torch.ops.aten.native_layer_norm]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
node.args = (
new_node_arg0,
node.args[1],
node.args[2],
node.args[3],
node.args[4],
)
# Inputs of aten.mm should be upcasted to fp32.
if node.target in [torch.ops.aten.mm]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
new_node_arg1 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[1], torch.float32),
kwargs={},
)
node.args = (new_node_arg0, new_node_arg1)
# Outputs of aten.mm should be downcasted to fp16.
if type(node.args[0]) == torch.fx.node.Node and node.args[
0
].target in [torch.ops.aten.mm]:
with fx_g.graph.inserting_before(node):
tmp = node.args[0]
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node.args[0],),
kwargs={"dtype": torch.float16},
)
node.args[0].append(new_node)
node.args[0].replace_all_uses_with(new_node)
new_node.args = (tmp,)
new_node.kwargs = {"dtype": torch.float16}
# Inputs of aten._softmax should be upcasted to fp32.
if node.target in [torch.ops.aten._softmax]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
node.args = (new_node_arg0, node.args[1], node.args[2])
# Outputs of aten._softmax should be downcasted to fp16.
if (
type(node.args[0]) == torch.fx.node.Node
and node.args[0].target in [torch.ops.aten._softmax]
and node.target in [torch.ops.aten.expand]
):
with fx_g.graph.inserting_before(node):
tmp = node.args[0]
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node.args[0],),
kwargs={"dtype": torch.float16},
)
node.args[0].append(new_node)
node.args[0].replace_all_uses_with(new_node)
new_node.args = (tmp,)
new_node.kwargs = {"dtype": torch.float16}
fx_g.graph.lint()
# Doesn't replace the None type.
def change_fx_graph_return_to_tuple(fx_g):
for node in fx_g.graph.nodes:
@@ -504,27 +606,12 @@ def import_with_fx(
is_dynamic=False,
tracing_required=False,
precision="fp32",
is_gptq=False,
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
golden_values = None
if debug:
@@ -596,8 +683,30 @@ def import_with_fx(
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten.index_add,
torch.ops.aten.index_add_,
]
if precision in ["int4", "int8"]:
if precision in ["int4", "int8"] and not is_gptq:
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import (
replace_call_fn_target,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
export_context_manager = brevitas_layer_export_mode
export_class = block_quant_layer_level_manager(
export_handlers=[LinearWeightBlockQuantHandlerFwd]
@@ -612,7 +721,7 @@ def import_with_fx(
replace_call_fn_target(
fx_g,
src=matmul_rhs_group_quant_placeholder,
target=torch.ops.brevitas.matmul_rhs_group_quant,
target=torch.ops.quant.matmul_rhs_group_quant,
)
fx_g.recompile()
@@ -647,6 +756,10 @@ def import_with_fx(
add_upcast(fx_g)
fx_g.recompile()
if is_gptq:
gptq_transforms(fx_g)
fx_g.recompile()
if mlir_type == "fx":
return fx_g
@@ -677,5 +790,27 @@ def import_with_fx(
)
return mlir_module, func_name
mlir_module, func_name = mlir_importer.import_mlir()
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
return mlir_module, func_name
# Saves a .mlir module python object to the directory 'dir' with 'model_name' and returns a path to the saved file.
def save_mlir(
mlir_module,
model_name,
mlir_dialect="linalg",
frontend="torch",
dir=tempfile.gettempdir(),
):
model_name_mlir = (
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = tempfile.gettempdir()
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if frontend == "torch":
with open(mlir_path, "wb") as mlir_file:
mlir_file.write(mlir_module)
return mlir_path

View File

@@ -39,7 +39,7 @@ class SharkInference:
Attributes
----------
mlir_module : str
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
mlir_module or path represented in string; modules from torch-mlir are serialized in bytecode format.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -65,7 +65,7 @@ class SharkInference:
def __init__(
self,
mlir_module: bytes,
mlir_module,
device: str = "none",
mlir_dialect: str = "linalg",
is_benchmark: bool = False,
@@ -75,6 +75,14 @@ class SharkInference:
mmap: bool = True,
):
self.mlir_module = mlir_module
if mlir_module is not None:
if mlir_module and not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkInference with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.is_benchmark = is_benchmark
@@ -141,6 +149,10 @@ class SharkInference:
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
return self.shark_runner.run(function_name, inputs, send_to_host)
# forward function.
def forward(self, inputs: tuple, send_to_host=True):
return self.shark_runner.run("forward", inputs, send_to_host)
# Get all function names defined within the compiled module.
def get_functions_in_module(self):
return self.shark_runner.get_functions_in_module()
@@ -188,7 +200,9 @@ class SharkInference:
# TODO: Instead of passing directory and having names decided by the module
# , user may want to save the module with manual names.
def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
def save_module(
self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False
):
return export_iree_module_to_vmfb(
self.mlir_module,
self.device,
@@ -196,6 +210,8 @@ class SharkInference:
self.mlir_dialect,
module_name=module_name,
extra_args=extra_args,
debug=debug,
compile_str=self.compile_str,
)
# load and return the module.

View File

@@ -45,7 +45,7 @@ class SharkRunner:
Attributes
----------
mlir_module : str
mlir_module represented in string.
mlir_module path, string, or bytecode.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -74,6 +74,14 @@ class SharkRunner:
device_idx: int = None,
):
self.mlir_module = mlir_module
if self.mlir_module is not None:
if not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
@@ -91,6 +99,7 @@ class SharkRunner:
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
compile_str=self.compile_str,
)
self.iree_compilation_module = params["vmfb"]
self.iree_config = params["config"]

View File

@@ -15,7 +15,7 @@
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
from shark.backward_makefx import MakeFxModule
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import numpy as np
from tqdm import tqdm
import sys
@@ -69,7 +69,7 @@ class SharkTrainer:
self.frontend = frontend
# Training function is needed in the case of torch_fn.
def compile(self, training_fn=None, extra_args=[]):
def compile(self, training_fn=None, mlir_type="linalg", extra_args=[]):
if self.frontend in ["torch", "pytorch"]:
packed_inputs = (
dict(self.model.named_parameters()),
@@ -77,7 +77,18 @@ class SharkTrainer:
tuple(self.input),
)
mlir_module, func_name = import_with_fx(
training_fn, packed_inputs, False, [], training=True
training_fn,
packed_inputs,
False,
[],
training=True,
mlir_type=mlir_type,
)
mlir_module = save_mlir(
mlir_module,
model_name="shark_model",
frontend="torch",
mlir_dialect=mlir_type,
)
self.shark_runner = SharkRunner(
mlir_module,

View File

@@ -1,25 +1,6 @@
resnet50,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
albert-base-v2,stablehlo,tf,1e-2,1e-2,default,None,False,False,False,"",""
roberta-base,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
bert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","enabled_windows"
camembert-base,stablehlo,tf,1e-2,1e-3,default,None,True,True,True,"",""
dbmdz/convbert-base-turkish-cased,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/iree-org/iree/issues/9971",""
distilbert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
facebook/convnext-tiny-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342","macos"
funnel-transformer/small,stablehlo,tf,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/201",""
google/electra-small-discriminator,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile","macos"
google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,False,"",""
microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir",""
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-large-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
@@ -30,18 +11,11 @@ nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"",""
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b0,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
gpt2,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
t5-base,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
t5-large,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"","macos"
1 resnet50 bert-base-uncased stablehlo linalg tf torch 1e-2 1e-3 default nhcw-nhwc None False True False False macos
resnet50 stablehlo tf 1e-2 1e-3 default nhcw-nhwc False False False macos
albert-base-v2 stablehlo tf 1e-2 1e-2 default None False False False
roberta-base stablehlo tf 1e-02 1e-3 default nhcw-nhwc True True True macos
bert-base-uncased stablehlo tf 1e-2 1e-3 default None False False False enabled_windows
camembert-base stablehlo tf 1e-2 1e-3 default None True True True
dbmdz/convbert-base-turkish-cased stablehlo tf 1e-2 1e-3 default nhcw-nhwc True False True https://github.com/iree-org/iree/issues/9971
distilbert-base-uncased stablehlo tf 1e-2 1e-3 default None False False False
facebook/convnext-tiny-224 stablehlo tf 1e-2 1e-3 tf_vit nhcw-nhwc True False True https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342 macos
funnel-transformer/small stablehlo tf 1e-2 1e-3 default None True False True https://github.com/nod-ai/SHARK/issues/201
google/electra-small-discriminator stablehlo tf 1e-2 1e-3 default None False False False
google/mobilebert-uncased stablehlo tf 1e-2 1e-3 default None True False False Fails during iree-compile macos
google/vit-base-patch16-224 stablehlo tf 1e-2 1e-3 tf_vit nhcw-nhwc False False False
microsoft/MiniLM-L12-H384-uncased stablehlo tf 1e-2 1e-3 tf_hf None True False False Fails during iree-compile.
microsoft/layoutlm-base-uncased stablehlo tf 1e-2 1e-3 default None False False False
microsoft/mpnet-base stablehlo tf 1e-2 1e-2 default None True True True
albert-base-v2 linalg torch 1e-2 1e-3 default None True True True issue with aten.tanh in torch-mlir
alexnet linalg torch 1e-2 1e-3 default None True False True https://github.com/nod-ai/SHARK/issues/879
bert-base-cased linalg torch 1e-2 1e-3 default None False False True
1 bert-base-uncased bert-base-uncased linalg linalg torch torch 1e-2 1e-3 default None None False True False True False
2 bert-base-uncased_fp16 bert-base-uncased_fp16 linalg linalg torch torch 1e-1 1e-1 default None None True True True True
3 bert-large-uncased bert-large-uncased linalg linalg torch torch 1e-2 1e-3 default None None False True False True False
bert-large-uncased stablehlo tf 1e-2 1e-3 default None False False False
4 facebook/deit-small-distilled-patch16-224 facebook/deit-small-distilled-patch16-224 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc False True False True False Fails during iree-compile.
5 google/vit-base-patch16-224 google/vit-base-patch16-224 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc False True False True False https://github.com/nod-ai/SHARK/issues/311
6 microsoft/beit-base-patch16-224-pt22k-ft22k microsoft/beit-base-patch16-224-pt22k-ft22k linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc False True False True False https://github.com/nod-ai/SHARK/issues/390 macos macos
11 resnet101 resnet101 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc/img2col nhcw-nhwc/img2col True False False False macos macos
12 resnet18 resnet18 linalg linalg torch torch 1e-2 1e-3 default None None True True False True False macos macos
13 resnet50 resnet50 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc False False False False macos macos
14 resnet50_fp16 resnet50_fp16 linalg linalg torch torch 1e-2 1e-2 default nhcw-nhwc/img2col nhcw-nhwc/img2col True True True False True Numerics issues, awaiting cuda-independent fp16 integration
15 squeezenet1_0 squeezenet1_0 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc False False False False macos macos
16 wide_resnet50_2 wide_resnet50_2 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc/img2col nhcw-nhwc/img2col True False False False macos macos
efficientnet-v2-s stablehlo tf 1e-02 1e-3 default nhcw-nhwc False False False macos
17 mnasnet1_0 mnasnet1_0 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc True True True True macos macos
18 efficientnet_b0 efficientnet_b0 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc True True True True https://github.com/nod-ai/SHARK/issues/1487 macos macos
19 efficientnet_b7 efficientnet_b7 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc True True True True https://github.com/nod-ai/SHARK/issues/1487 macos macos
efficientnet_b0 stablehlo tf 1e-2 1e-3 default nhcw-nhwc False False False
efficientnet_b7 stablehlo tf 1e-2 1e-3 default nhcw-nhwc False False False Fails on MacOS builder, VK device lost macos
gpt2 stablehlo tf 1e-2 1e-3 default None True False False macos
20 t5-base t5-base linalg linalg torch torch 1e-2 1e-3 default None None True True True True Inputs for seq2seq models in torch currently unsupported. macos macos
t5-base stablehlo tf 1e-2 1e-3 default None False False False macos
21 t5-large t5-large linalg linalg torch torch 1e-2 1e-3 default None None True True True True Inputs for seq2seq models in torch currently unsupported macos macos
t5-large stablehlo tf 1e-2 1e-3 default None False False False macos
stabilityai/stable-diffusion-2-1-base linalg torch 1e-3 1e-3 default None True False False macos

View File

@@ -85,8 +85,6 @@ if __name__ == "__main__":
args = [
"--iree-llvmcpu-target-cpu-features=host",
"--iree-mhlo-demote-i64-to-i32=false",
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
]
backend_config = "dylib"
# backend = "cuda"

View File

@@ -1,3 +1,26 @@
# Running Different OPT Variants
# Run OPT for sentence completion through SHARK
To run different sizes of OPT, change the string `OPT_MODEL` string in `opt_torch_test.py`. The default is 350m parameters. 66b cases also exist in the file, simply uncomment the test cases.
From base SHARK directory, follow instructions to set up a virtual environment with SHARK. (`./setup_venv.sh` or `./setup_venv.ps1`)
Then, you may run opt_causallm.py to get a very simple sentence completion application running through SHARK
```
python opt_causallm.py
```
# Run OPT performance comparison on SHARK vs. PyTorch
```
python opt_perf_comparison.py --max-seq-len=512 --model-name=facebook/opt-1.3b \
--platform=shark
```
Any OPT model from huggingface should work with this script, and you can choose between `--platform=shark` or `--platform=huggingface` to generate benchmarks of OPT inference on SHARK / PyTorch.
# Run a small suite of OPT models through the benchmark script
```
python opt_perf_comparison_batch.py
```
This script will run benchmarks from a suite of OPT configurations:
- Sequence Lengths: 32, 128, 256, 512
- Parameter Counts: 125m, 350m, 1.3b
note: Most of these scripts are written for use on CPU, as perf comparisons against pytorch can be problematic across platforms otherwise.

View File

@@ -36,9 +36,7 @@ def create_module(model_name, tokenizer, device):
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
if os.path.isfile(mlir_path):
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
print(f"Found .mlir from {mlir_path}")
else:
(model_mlir, func_name) = import_with_fx(
model=opt_model,
@@ -50,16 +48,17 @@ def create_module(model_name, tokenizer, device):
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
del model_mlir
shark_module = SharkInference(
model_mlir,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=False,
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
shark_module.save_module(module_name=vmfb_name)
shark_module.save_module(module_name=vmfb_name, debug=False)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path

View File

@@ -6,7 +6,7 @@ import numpy as np
from shark_opt_wrapper import OPTForCausalLMModel
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "facebook/opt-1.3b"
@@ -57,9 +57,10 @@ class OPTModuleTester:
with open(mlir_path, "w") as f:
f.write(mlir_module)
print(f"Saved mlir at {mlir_path}")
del mlir_module
shark_module = SharkInference(
mlir_module,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=self.benchmark,

View File

@@ -1,18 +1,45 @@
"""
Script for comparing OPT model performance between SHARK and Huggingface
PyTorch.
Usage Example:
python opt_perf_comparison.py --max-seq-len=32 --model-name=facebook/opt-125m \
--platform=shark
python opt_perf_comparison.py --max-seq-len=512 --model-name=facebook/opt-1.3b \
--platform=shark
See parse_args() below for command line argument usage.
"""
import argparse
import collections
import json
import time
import os
import psutil
import time
from typing import Tuple
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
from shark_opt_wrapper import OPTForCausalLMModel
MODEL_NAME = "facebook/opt-1.3b"
OPT_MODELNAME = "opt-1.3b"
OPT_FS_NAME = "opt_1-3b"
MAX_SEQUENCE_LENGTH = 512
DEVICE = "cpu"
PLATFORM_SHARK = "shark"
PLATFORM_HUGGINGFACE = "huggingface"
# Dict keys for reports.
REPORT_PLATFORM = "platform"
REPORT_MODEL_NAME = "model"
REPORT_MAX_SEQ_LEN = "max_seq_len"
REPORT_LOAD_TIME = "load_time_sec"
REPORT_RUN_TIME = "run_time_sec"
REPORT_LOAD_PHYSICAL_MEMORY_MB = "load_physical_MB"
REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"
PROMPTS = [
"What is the meaning of life?",
@@ -30,15 +57,27 @@ PROMPTS = [
ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])
def create_vmfb_module(model_name, tokenizer, device):
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
def get_memory_info():
pid = os.getpid()
process = psutil.Process(pid)
return process.memory_info()
def create_vmfb_module(
model_name: str,
tokenizer,
device: str,
max_seq_len: int,
recompile_shark: bool,
):
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
"What is the meaning of life?",
PROMPTS[0],
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_seq_len,
return_tensors="pt",
)
inputs = (
@@ -48,8 +87,16 @@ def create_vmfb_module(model_name, tokenizer, device):
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
if os.path.isfile(mlir_path):
opt_fs_name = get_opt_fs_name(model_name)
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
# If MLIR has already been loaded and recompilation is not requested, use
# the loaded MLIR file.
has_mlir = os.path.isfile(mlir_path)
# The purpose of recompile_shark is to measure compilation time; the
# compilation time can be correctly measured only when MLIR has already been
# loaded.
assert not recompile_shark or has_mlir
if has_mlir:
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
@@ -58,7 +105,7 @@ def create_vmfb_module(model_name, tokenizer, device):
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=OPT_FS_NAME,
model_name=opt_fs_name,
return_str=True,
)
with open(mlir_path, "w") as f:
@@ -72,18 +119,25 @@ def create_vmfb_module(model_name, tokenizer, device):
is_benchmark=False,
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{DEVICE}_tiled_ukernels"
vmfb_name = (
f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels"
)
shark_module.save_module(module_name=vmfb_name)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path
def load_shark_model() -> ModelWrapper:
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{DEVICE}_tiled_ukernels.vmfb"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if not os.path.isfile(vmfb_name):
def load_shark_model(
model_name: str, max_seq_len: int, recompile_shark: bool
) -> ModelWrapper:
opt_fs_name = get_opt_fs_name(model_name)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
if recompile_shark or not os.path.isfile(vmfb_name):
print(f"vmfb not found. compiling and saving to {vmfb_name}")
create_vmfb_module(OPT_MODELNAME, tokenizer, DEVICE)
create_vmfb_module(
model_name, tokenizer, DEVICE, max_seq_len, recompile_shark
)
shark_module = SharkInference(mlir_module=None, device="cpu-task")
shark_module.load_module(vmfb_name)
return ModelWrapper(model=shark_module, tokenizer=tokenizer)
@@ -94,20 +148,10 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
return model_wrapper.model("forward", tokens)
def run_shark():
model_wrapper = load_shark_model()
prompt = "What is the meaning of life?"
logits = run_shark_model(model_wrapper, prompt)
# Print output logits to validate vs. pytorch + base transformers
print(logits[0])
def load_huggingface_model() -> ModelWrapper:
def load_huggingface_model(model_name: str) -> ModelWrapper:
return ModelWrapper(
model=OPTForCausalLM.from_pretrained(MODEL_NAME),
tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
model=OPTForCausalLM.from_pretrained(model_name),
tokenizer=AutoTokenizer.from_pretrained(model_name),
)
@@ -117,47 +161,71 @@ def run_huggingface_model(model_wrapper: ModelWrapper, tokens):
)
def run_huggingface():
model_wrapper = load_huggingface_model()
prompt = "What is the meaning of life?"
logits = run_huggingface_model(model_wrapper, prompt)
print(logits[0])
def save_json(data, filename):
with open(filename, "w") as file:
json.dump(data, file)
def collect_huggingface_logits():
def collect_huggingface_logits(
model_name: str, max_seq_len: int, to_save_json: bool
) -> Tuple[float, float]:
# Load
t0 = time.time()
model_wrapper = load_huggingface_model()
print("--- Took {} seconds to load Huggingface.".format(time.time() - t0))
model_wrapper = load_huggingface_model(model_name)
load_time = time.time() - t0
print("--- Took {} seconds to load Huggingface.".format(load_time))
load_memory_info = get_memory_info()
results = []
tokenized_prompts = []
for prompt in PROMPTS:
tokens = model_wrapper.tokenizer(
prompt,
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_seq_len,
truncation=True,
return_tensors="pt",
)
tokenized_prompts.append(tokens)
# Run
t0 = time.time()
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_huggingface_model(model_wrapper, tokens)
results.append([PROMPTS[idx], logits[0].tolist()])
print("--- Took {} seconds to run Huggingface.".format(time.time() - t0))
save_json(results, "/tmp/huggingface.json")
if to_save_json:
results.append([PROMPTS[idx], logits[0].tolist()])
run_time = time.time() - t0
print("--- Took {} seconds to run Huggingface.".format(run_time))
if to_save_json:
save_json(results, "/tmp/huggingface.json")
run_memory_info = get_memory_info()
return {
REPORT_PLATFORM: PLATFORM_HUGGINGFACE,
REPORT_MODEL_NAME: model_name,
REPORT_MAX_SEQ_LEN: max_seq_len,
REPORT_LOAD_TIME: load_time,
REPORT_RUN_TIME: run_time / len(PROMPTS),
REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
}
def collect_shark_logits():
def collect_shark_logits(
model_name: str,
max_seq_len: int,
recompile_shark: bool,
to_save_json: bool,
) -> Tuple[float, float]:
# Load
t0 = time.time()
model_wrapper = load_shark_model()
print("--- Took {} seconds to load Shark.".format(time.time() - t0))
model_wrapper = load_shark_model(model_name, max_seq_len, recompile_shark)
load_time = time.time() - t0
print("--- Took {} seconds to load Shark.".format(load_time))
load_memory_info = get_memory_info()
results = []
tokenized_prompts = []
for prompt in PROMPTS:
@@ -165,7 +233,7 @@ def collect_shark_logits():
prompt,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_seq_len,
return_tensors="pt",
)
inputs = (
@@ -173,16 +241,100 @@ def collect_shark_logits():
tokens["attention_mask"],
)
tokenized_prompts.append(inputs)
# Run
t0 = time.time()
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_shark_model(model_wrapper, tokens)
lst = [e.tolist() for e in logits]
results.append([PROMPTS[idx], lst])
print("--- Took {} seconds to run Shark.".format(time.time() - t0))
save_json(results, "/tmp/shark.json")
if to_save_json:
results.append([PROMPTS[idx], lst])
run_time = time.time() - t0
print("--- Took {} seconds to run Shark.".format(run_time))
if to_save_json:
save_json(results, "/tmp/shark.json")
platform_postfix = "-compile" if recompile_shark else "-precompiled"
run_memory_info = get_memory_info()
return {
REPORT_PLATFORM: PLATFORM_SHARK + platform_postfix,
REPORT_MODEL_NAME: model_name,
REPORT_MAX_SEQ_LEN: max_seq_len,
REPORT_LOAD_TIME: load_time,
REPORT_RUN_TIME: run_time / len(PROMPTS),
REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
}
def get_opt_fs_name(model_name: str) -> str:
"""Cleanses the model name ino a file system-friendly name.
Example: get_opt_fs_name('facebook/opt-1.3b') == 'opt_1-3b'
"""
slash_split = model_name.split("/")
assert 1 <= len(slash_split) <= 2, "There should be at most one slash."
model_name = slash_split[-1]
for src_pattern, dest_pattern in (("-", "_"), (".", "-")):
model_name = model_name.replace(src_pattern, dest_pattern)
return model_name
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--save-json",
help="If set, saves output JSON.",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--max-seq-len", help="Max sequence length", type=int, default=32
)
parser.add_argument(
"--model-name",
help="Model name",
type=str,
choices=[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
],
default="facebook/opt-1.3b",
)
parser.add_argument(
"--recompile-shark",
help="If set, recompiles MLIR",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--platform",
help="Either shark or huggingface",
type=str,
choices=[PLATFORM_SHARK, PLATFORM_HUGGINGFACE],
default=PLATFORM_SHARK,
)
args = parser.parse_args()
print("args={}".format(args))
return args
if __name__ == "__main__":
collect_shark_logits()
collect_huggingface_logits()
args = parse_args()
if args.platform == PLATFORM_SHARK:
shark_report = collect_shark_logits(
args.model_name,
args.max_seq_len,
args.recompile_shark,
args.save_json,
)
print("# Summary: {}".format(json.dumps(shark_report)))
else:
huggingface_report = collect_huggingface_logits(
args.model_name, args.max_seq_len, args.save_json
)
print("# Summary: {}".format(json.dumps(huggingface_report)))

View File

@@ -0,0 +1,30 @@
"""
Script for running opt_perf_comparison.py in batch with a series of arguments.
Usage: python opt_perf_comparison_batch.py
"""
from typing import Iterable, List
import shlex
import subprocess
def make_commands() -> Iterable[List[str]]:
command = shlex.split("python opt_perf_comparison.py --no-save-json")
max_seq_lens = [32, 128, 256, 512]
model_names = ["facebook/opt-" + e for e in ["125m", "350m", "1.3b"]]
for max_seq_len in max_seq_lens:
for model_name in model_names:
yield command + [
f"--max-seq-len={max_seq_len}",
f"--model-name={model_name}",
]
def main():
for command in make_commands():
result = subprocess.run(command, check=True)
if __name__ == "__main__":
main()

View File

@@ -2,7 +2,7 @@ import os
import torch
from transformers import AutoTokenizer, OPTForCausalLM
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark_opt_wrapper import OPTForCausalLMModel
model_name = "facebook/opt-1.3b"
@@ -25,11 +25,13 @@ inputs = (
model=model,
inputs=inputs,
is_f16=False,
debug=True,
model_name=model_name.split("/")[1],
save_dir=".",
)
mlir_module = save_mlir(
mlir_module,
model_name=model_name.split("/")[1],
frontend="torch",
mlir_dialect="linalg",
)
shark_module = SharkInference(
mlir_module,
device="cpu-sync",

View File

@@ -16,12 +16,6 @@ import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
from apps.stable_diffusion.src.models import (
model_wrappers as mw,
)
from apps.stable_diffusion.src.utils.stable_args import (
args,
)
def create_hash(file_name):
@@ -42,7 +36,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
get_hf_img_cls_model,
get_fp16_model,
)
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
with open(torch_model_list) as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")
@@ -60,31 +54,6 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
print("generating artifacts for: " + torch_model_name)
model = None
input = None
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.local_tank_cache = local_tank_cache
precision_values = ["fp16"]
seq_lengths = [64, 77]
for precision_value in precision_values:
args.precision = precision_value
for length in seq_lengths:
model = mw.SharkifyStableDiffusionModel(
model_id=torch_model_name,
custom_weights="",
precision=precision_value,
max_len=length,
width=512,
height=512,
use_base_vae=False,
custom_vae="",
debug=True,
sharktank_dir=local_tank_cache,
generate_vmfb=False,
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(
torch_model_name, import_args
@@ -103,10 +72,11 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
model, input, _ = get_hf_img_cls_model(
torch_model_name, import_args
)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name, import_args)
torch_model_name = torch_model_name.replace("/", "_")
if import_args["batch_size"] != 1:
if import_args["batch_size"] > 1:
print(
f"Batch size for this model set to {import_args['batch_size']}"
)
torch_model_dir = os.path.join(
local_tank_cache,
str(torch_model_name)
@@ -160,133 +130,6 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
)
def save_tf_model(tf_model_list, local_tank_cache, import_args):
from tank.model_utils_tf import (
get_causal_image_model,
get_masked_lm_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
get_tfhf_seq2seq_model,
)
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
with open(tf_model_list) as csvfile:
tf_reader = csv.reader(csvfile, delimiter=",")
fields = next(tf_reader)
for row in tf_reader:
tf_model_name = row[0]
model_type = row[1]
model = None
input = None
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_masked_lm_model(
tf_model_name, import_args
)
elif model_type == "img":
model, input, _ = get_causal_image_model(
tf_model_name, import_args
)
elif model_type == "keras":
model, input, _ = get_keras_model(tf_model_name, import_args)
elif model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name, import_args)
elif model_type == "tfhf_seq2seq":
model, input, _ = get_tfhf_seq2seq_model(
tf_model_name, import_args
)
elif model_type == "hf_causallm":
model, input, _ = get_causal_lm_model(
tf_model_name, import_args
)
tf_model_name = tf_model_name.replace("/", "_")
if import_args["batch_size"] != 1:
tf_model_dir = os.path.join(
local_tank_cache,
str(tf_model_name)
+ "_tf"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
tf_model_dir = os.path.join(
local_tank_cache, str(tf_model_name) + "_tf"
)
os.makedirs(tf_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
inputs=input,
frontend="tf",
)
mlir_importer.import_debug(
is_dynamic=False,
dir=tf_model_dir,
model_name=tf_model_name,
)
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
from shark.tflite_utils import TFLitePreprocessor
with open(tflite_model_list) as csvfile:
tflite_reader = csv.reader(csvfile, delimiter=",")
for row in tflite_reader:
print("\n")
tflite_model_name = row[0]
tflite_model_link = row[1]
print("tflite_model_name", tflite_model_name)
print("tflite_model_link", tflite_model_link)
tflite_model_name_dir = os.path.join(
local_tank_cache, str(tflite_model_name) + "_tflite"
)
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
# Preprocess to get SharkImporter input import_args
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()
# Use SharkImporter to get SharkInference input import_args
my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
frontend="tflite",
raw_model_file=raw_model_file_path,
)
my_shark_importer.import_debug(
dir=tflite_model_name_dir,
model_name=tflite_model_name,
func_name="main",
)
mlir_hash = create_hash(
os.path.join(
tflite_model_name_dir,
tflite_model_name + "_tflite" + ".mlir",
)
)
np.save(
os.path.join(tflite_model_name_dir, "hash"),
np.array(mlir_hash),
)
def check_requirements(frontend):
import importlib
@@ -295,10 +138,6 @@ def check_requirements(frontend):
tv_spec = importlib.util.find_spec("torchvision")
has_pkgs = tv_spec is not None
elif frontend in ["tensorflow", "tf"]:
tf_spec = importlib.util.find_spec("tensorflow")
has_pkgs = tf_spec is not None
return has_pkgs
@@ -317,27 +156,11 @@ def gen_shark_files(modelname, frontend, tank_dir, importer_args):
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tf_model_list.csv"
)
custom_model_csv = tempfile.NamedTemporaryFile(
dir=os.path.dirname(__file__),
delete=True,
)
# Create a temporary .csv with only the desired entry.
if frontend == "tf":
with open(tf_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_tf_model(custom_model_csv.name, tank_dir, import_args)
elif frontend == "torch":
if frontend == "torch":
with open(torch_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
@@ -371,18 +194,6 @@ if __name__ == "__main__":
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
# )
# parser.add_argument(
# "--tf_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tf_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--tflite_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tflite/tflite_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--ci_tank_dir",
# type=bool,
# default=False,
@@ -391,7 +202,7 @@ if __name__ == "__main__":
# old_import_args = parser.parse_import_args()
import_args = {
"batch_size": "1",
"batch_size": 1,
}
print(import_args)
home = str(Path.home())
@@ -399,16 +210,5 @@ if __name__ == "__main__":
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
)
save_torch_model(
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
WORKDIR,
import_args,
)
save_torch_model(torch_model_csv, WORKDIR, import_args)
save_tf_model(tf_model_csv, WORKDIR, import_args)
save_tflite_model(tflite_model_csv, WORKDIR, import_args)

View File

@@ -278,7 +278,7 @@ def get_vision_model(torch_model, import_args):
int(import_args["batch_size"]), 3, *input_image_size
)
actual_out = model(test_input)
if fp16_model is not None:
if fp16_model == True:
test_input_fp16 = test_input.to(
device=torch.device("cuda"), dtype=torch.half
)

View File

@@ -145,6 +145,7 @@ class SharkModuleTester:
shark_args.shark_prefix = self.shark_tank_prefix
shark_args.local_tank_cache = self.local_tank_cache
shark_args.dispatch_benchmarks = self.benchmark_dispatches
shark_args.enable_tf32 = self.tf32
if self.benchmark_dispatches is not None:
_m = self.config["model_name"].split("/")
@@ -216,10 +217,12 @@ class SharkModuleTester:
result = shark_module(func_name, inputs)
golden_out, result = self.postprocess_outputs(golden_out, result)
if self.tf32 == "true":
print("Validating with relaxed tolerances.")
atol = 1e-02
rtol = 1e-03
if self.tf32 == True:
print(
"Validating with relaxed tolerances for TensorFloat32 calculations."
)
self.config["atol"] = 1e-01
self.config["rtol"] = 1e-02
try:
np.testing.assert_allclose(
golden_out,
@@ -254,9 +257,6 @@ class SharkModuleTester:
model_config = {
"batch_size": self.batch_size,
}
shark_args.enable_tf32 = self.tf32
if shark_args.enable_tf32 == True:
shark_module.compile()
shark_args.onnx_bench = self.onnx_bench
shark_module.shark_runner.benchmark_all_csv(
@@ -287,6 +287,9 @@ class SharkModuleTester:
repro_path = os.path.join("reproducers", self.tmp_prefix, "*")
bashCommand = f"gsutil cp -r {repro_path} gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
print(
f"Uploading reproducer {repro_path} to gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
)
process = subprocess.run(bashCommand.split())
def postprocess_outputs(self, golden_out, result):

View File

@@ -1,28 +0,0 @@
model_name, model_type
albert-base-v2,hf
bert-base-uncased,hf
camembert-base,hf
dbmdz/convbert-base-turkish-cased,hf
distilbert-base-uncased,hf
google/electra-small-discriminator,hf
funnel-transformer/small,hf
microsoft/layoutlm-base-uncased,hf
google/mobilebert-uncased,hf
microsoft/mpnet-base,hf
roberta-base,hf
resnet50,keras
xlm-roberta-base,hf
microsoft/MiniLM-L12-H384-uncased,TFhf
funnel-transformer/small,hf
microsoft/mpnet-base,hf
facebook/convnext-tiny-224,img
google/vit-base-patch16-224,img
efficientnet-v2-s,keras
bert-large-uncased,hf
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq
efficientnet_b0,keras
efficientnet_b7,keras
gpt2,hf_causallm
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq
1 model_name model_type
2 albert-base-v2 hf
3 bert-base-uncased hf
4 camembert-base hf
5 dbmdz/convbert-base-turkish-cased hf
6 distilbert-base-uncased hf
7 google/electra-small-discriminator hf
8 funnel-transformer/small hf
9 microsoft/layoutlm-base-uncased hf
10 google/mobilebert-uncased hf
11 microsoft/mpnet-base hf
12 roberta-base hf
13 resnet50 keras
14 xlm-roberta-base hf
15 microsoft/MiniLM-L12-H384-uncased TFhf
16 funnel-transformer/small hf
17 microsoft/mpnet-base hf
18 facebook/convnext-tiny-224 img
19 google/vit-base-patch16-224 img
20 efficientnet-v2-s keras
21 bert-large-uncased hf
22 t5-base tfhf_seq2seq
23 t5-large tfhf_seq2seq
24 efficientnet_b0 keras
25 efficientnet_b7 keras
26 gpt2 hf_causallm
27 t5-base tfhf_seq2seq
28 t5-large tfhf_seq2seq

View File

@@ -5,7 +5,6 @@ microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,False,66M,"nlp;bert-varian
bert-base-uncased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,linalg,False,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,linalg,False,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,linalg,False,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,linalg,False,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
@@ -18,11 +17,9 @@ facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,False,22M
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,linalg,False,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,linalg,False,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
1 model_name use_tracing model_type dynamic mlir_type decompose param_count tags notes
5 bert-base-uncased True hf True linalg False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
6 bert-base-cased True hf True linalg False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
7 google/mobilebert-uncased True hf True linalg False 25M nlp,bert-variant,transformer-encoder,mobile 24 layers, 512 hidden size, 128 embedding
alexnet False vision True linalg False 61M cnn,parallel-layers The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod.
8 resnet18 False vision True linalg False 11M cnn,image-classification,residuals,resnet-variant 1 7x7 conv2d and the rest are 3x3 conv2d
9 resnet50 False vision True linalg False 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
10 resnet101 False vision True linalg False 29M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
17 microsoft/beit-base-patch16-224-pt22k-ft22k True hf_img_cls False linalg False 86M image-classification,transformer-encoder,bert-variant,vision-transformer N/A
18 nvidia/mit-b0 True hf_img_cls False linalg False 3.7M image-classification,transformer-encoder SegFormer
19 mnasnet1_0 False vision True linalg False - cnn, torchvision, mobile, architecture-search Outperforms other mobile CNNs on Accuracy vs. Latency
resnet50_fp16 False vision True linalg False 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
bert-base-uncased_fp16 True fp16 False linalg False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
20 bert-large-uncased True hf True linalg False 330M nlp;bert-variant;transformer-encoder 24 layers, 1024 hidden units, 16 attention heads
21 bert-base-uncased True hf False stablehlo False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
22 gpt2 True hf_causallm False stablehlo True 125M nlp;transformer-encoder -
23 facebook/opt-125m True hf False stablehlo True 125M nlp;transformer-encoder -
24 distilgpt2 True hf False stablehlo True 88M nlp;transformer-encoder -
25 microsoft/deberta-v3-base True hf False stablehlo True 88M nlp;transformer-encoder -