Compare commits

...

107 Commits

Author SHA1 Message Date
Ean Garvey
d7399c8ee7 Update nightly.yml 2023-09-28 11:45:54 -05:00
Ean Garvey
b6f8993dcc Temporarily disable sharktank gen. 2023-09-28 11:44:38 -05:00
PhaneeshB
94594542a9 remove use of vulkaninfo 2023-09-28 21:57:00 +05:30
Gaurav Shukla
82f833e87d [vulkan] Update vmfb naming
Update vmfb naming for vulkan devices in order to resolve naming
conflicts in the presence of multiple vulkan devices.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-28 14:52:11 +05:30
Vivek Khandelwal
c9d6870105 Modify falcon pipeline for 180b support 2023-09-28 12:39:35 +05:30
Jakub Kuderski
4fec03a6cc [vulkan] Switch from coop matrix NV to KHR (#1848) 2023-09-27 21:43:37 -04:00
harsh-nod
9a27f51378 Deprecate inference directory
This patch removes the inference directory that was no longer being used.
2023-09-27 14:29:00 -07:00
Abhishek Varma
ad1a0f35ff Fix misdirection while saving vmfb
-- Currently SHARK reports that a vmfb has been saved even when that is
   not the case and no vmfb is generated. This is misleading for larger
   IRs/vmfbs.
-- This commit therefore fixes that misdirection.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-27 16:25:29 +05:30
Nelson Sharpe
6773278ec2 Fix checkpoint_path unexpected argument (#1832) 2023-09-24 14:17:52 -07:00
Abhishek Varma
9a0efffcca [Llama2] Fix wrong Vulkan device ID + Add Vulkan compile flags
-- This commit fixes the wrong Vulkan device being selected during
   runtime.
-- It also adds a couple of IREE compilation flags to target a specific
   Vulkan device.
-- It also changes the Vulkan device listing to be more in tune with
   lowering control flow.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-22 22:24:18 +05:30
gpetters94
61c6f153d9 Switch to keras-nightly to fix a Linux issue (#1835) 2023-09-21 12:33:45 -04:00
Phaneesh Barwaria
effd42e8f5 pin gradio to v3.44.3 2023-09-21 17:33:43 +05:30
Sungsoon Cho
b5fbb1a8a0 Rename the func arg save_json to avoid name collision. (#1837)
* Rename the func arg save_json to avoid name collision.

* black formatted.
2023-09-19 17:29:27 -05:00
Quinn Dawkins
ded74d09cd [vicuna.py] Keep past key values on device (#1836)
The past key values are only used within the models themselves and can
be kept on device. For vulkan int4, this gives 44 tok/s (for the first
prompt) and settles at around 26 tok/s on 7900xtx.
2023-09-19 18:17:41 -04:00
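A minimal sketch of the idea in plain PyTorch, not SHARK's actual runtime code (the `decode_step` function and tensor shapes are placeholders): the past key/value tensors are created on the accelerator and stay there across decode steps, avoiding a host round-trip per token.

```python
# Illustrative only: keep the KV cache on the device between decode steps.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

def decode_step(token, past_kv):
    # Placeholder for the real model call; shapes are arbitrary here.
    new_kv = (torch.randn(1, 8, 1, 64, device=device),
              torch.randn(1, 8, 1, 64, device=device))
    next_token = torch.randint(0, 32000, (1, 1), device=device)
    return next_token, new_kv

token = torch.randint(0, 32000, (1, 1), device=device)
past_kv = None
for _ in range(4):
    # past_kv never leaves `device`: no .cpu()/.numpy() copies per step.
    token, past_kv = decode_step(token, past_kv)
```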
Boian Petkantchin
79267931c1 Add argument --additional_compile_args (#1119)
This allows passing more arguments to the IREE compiler.
Example:
python my-app.py --additional_compile_args="--mlir-pretty-debuginfo --mlir-timing"

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-09-19 11:26:03 -05:00
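A hedged sketch of how such a flag can be forwarded (the argument handling below is an assumption, not SHARK's exact code): the string is split with `shlex` and appended to the flags already passed to the compiler.

```python
# Illustrative: split --additional_compile_args and append the pieces to an
# existing compiler-flag list. Variable names and defaults are assumptions.
import argparse
import shlex

parser = argparse.ArgumentParser()
parser.add_argument("--additional_compile_args", type=str, default="")
args = parser.parse_args(
    ["--additional_compile_args=--mlir-pretty-debuginfo --mlir-timing"]
)

base_flags = ["--iree-hal-target-backends=llvm-cpu"]  # hypothetical default
extra_flags = shlex.split(args.additional_compile_args)
print(base_flags + extra_flags)
```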
zjgarvey
9eceba69b7 local_tank_cache included into clear_all (#1833) 2023-09-18 00:27:23 -05:00
Ean Garvey
ca609afb6a Update README.md (#1830) 2023-09-14 10:33:57 -05:00
Gaurav Shukla
11bdce9790 [flags] Fix vulkan runtime flags as vma is dropped from iree (#1831) 2023-09-14 08:58:59 -05:00
Ean Garvey
684943a4a6 (SD) Fix tokenizers imports in pyinstaller builds. (#1828)
* Fix tokenizers metadata.

* (SD) Disable VAE lowering configs (rdna3) and add versioned tunings.

* Update sd_annotation.py

* (SD) Add cv2 to spec.

* Update stencil pipeline with the new img2img arg.
2023-09-12 12:23:48 -05:00
PhaneeshB
b817bb8455 add roles for llama2 2023-09-12 10:59:28 +05:30
Ean Garvey
780f520f02 Fix vk.target_env extensions and remove redundant SD imports. (#1826)
* Remove redundant IREE runtime imports.

* Fix vulkan target env extensions.
2023-09-11 13:42:52 -05:00
Dom
c61b6f8d65 Code refactoring (#1817)
* use join

* fix bug

* further code optimizations

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
2023-09-11 11:30:56 -05:00
Abhishek Varma
c854208d49 [Llama2] Prefetch llama2 tokenizer configs (#1824)
-- This commit prefetches llama2 tokenizer configs from shark_tank.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-08 11:29:54 -07:00
Gaurav Shukla
c5dcfc1f13 [vicuna] Exit when mlir is not present in shark tank (#1825)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-08 10:30:29 -07:00
Abhishek Varma
bde63ee8ae Add logging feature in WebUI (#1821) 2023-09-08 05:48:05 -07:00
Vivek Khandelwal
9681d494eb Update decomp list and shark trainer for DLRM 2023-09-06 21:24:50 +05:30
Gaurav Shukla
ede6bf83e2 [vicuna] Disabling the IR generation path
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-06 20:13:17 +05:30
Ean Garvey
2c2693fb7d Fix torchvision versioning in Linux importer setup. (#1809) 2023-09-05 12:57:03 -05:00
Vivek Khandelwal
1d31b2b2c6 Fix StableHLO Compilation flag 2023-09-05 21:32:33 +05:30
Gaurav Shukla
d2f64eefa3 [chatbot] Remove few outdated models from list (#1814) 2023-09-04 09:26:32 -07:00
Abhishek Varma
87ae14b6ff [SD] Add sdpfa decomposition + update IREE flag
-- This commit adds Scaled Dot Product Flash Attention's decomposition
   in shark_importer.
-- It also updates `iree-flow-enable-data-tiling` to `iree-opt-data-tiling`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-04 18:03:53 +05:30
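For reference, a tiny sketch of what the flag rename looks like in an extra-args list; the chosen value is an assumption, not SHARK's actual setting.

```python
# Illustrative only: the renamed IREE option in a compile-args list.
iree_extra_args = [
    "--iree-opt-data-tiling=false",  # formerly --iree-flow-enable-data-tiling
]
print(iree_extra_args)
```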
Phaneesh Barwaria
1ccafa1fc1 fix llama2-70b rewrite tensor dim 2023-09-01 17:27:06 +05:30
jinchen62
4c3d8a0a7f Enable downloading vmfb/mlir for webui (#1807) 2023-08-31 11:05:47 -07:00
jinchen62
3601dc7c3b Fix llama2 13b combined ir (#1803) 2023-08-28 11:34:44 -07:00
Daniel Garvey
671881cf87 Llama2 70b (#1783)
* llama2 70b IR gen

* fix IR sec llama2 + debug

* llama270b

---------

Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
2023-08-25 23:04:28 -07:00
Gaurav Shukla
4e9be6be59 [chatbot] Add debug as class attribute (#1799)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-25 21:46:29 -07:00
Ean Garvey
9c8cbaf498 Add support for ROCM (Windows) in Studio + compile utils (#1770)
* WIP: MSVC ROCM support for SHARK Studio

* Make get_iree_rocm_args platform-agnostic.

* Update stable_args.py

* Update rocm arg handling in SD utils

* Guard quantization imports.

Co-authored-by: jam https://github.com/jammm
2023-08-25 20:56:05 -07:00
Ean Garvey
9e348a114e Revert changes process_skipfiles.py (#1798)
Keeps a small typo fix but reverts the rest of the changes to this file from 450c231171
2023-08-25 15:31:49 -07:00
jinchen62
51f90a4d56 Update conversion passes for brevitas quant op (#1795) 2023-08-25 17:28:07 -05:00
Abhishek Varma
310d5d0a49 Fix llama2 13b crashing + add spec file for CLI execution of Llama (#1797)
* [Llama2] Add a fix for Llama2 13B downloading/crashing

-- This commit fixes llama2 13B downloading/crashing due to a wrong
   .mlir file.
-- Also adds support for downloading the vmfb from shark_tank in the CLI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [llama2] Add a spec file to run Llama/Vicuna CLI exe

-- This commit adds a spec file to run Llama/Vicuna CLI exe.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-08-25 09:36:09 -05:00
Ean Garvey
9697981004 Pipe through a debug option to iree compile utils. (#1796)
* Update compile_utils.py

* Pipe through a flag to toggle debug options in compile utils.

* Update SharkLLMBase.py
2023-08-25 07:11:11 -07:00
Ean Garvey
450c231171 Add tokenizers to requirements.txt (#1790)
* Add tokenizers to requirements and pin version

* Update process_skipfiles.py
2023-08-24 19:44:04 -05:00
Ean Garvey
07f6f4a2f7 Add a short README for the OPT examples and small tweaks. (#1793)
* Small changes to OPT example.

* Update opt README.

* Add a few modes to batch script.

* Update README.md
2023-08-24 17:26:11 -07:00
jinchen62
610813c72f Add iree flag to strip assertions (#1791) 2023-08-24 10:51:19 -07:00
Ean Garvey
8e3860c9e6 Remove flags that are default in upstream IREE (#1785)
* Remove index bits flags now set by default

* Update shark_studio_imports.py
2023-08-24 11:57:54 -05:00
xzuyn
e37d6720eb Add Hires Fix (#1787)
* improper test hiresfix

* add sliders & use `clear_cache`

* add resample choices & fix step adjustment

* add step adjustment to img2img

* add resample options to img2img

* simplify hiresfix
- import `img2img_inf` from `img2img_ui.py` instead of just copying it into `txt2img_ui.py`

* set `hri` to None after using

* add more resample types, and don't show output until hiresfix is done

* cleaner implementation

* ran black

* ran black again with jupyter dependencies
2023-08-24 09:01:41 -07:00
Vivek Khandelwal
16160d9a7d Fix combine mlir script 2023-08-24 19:10:49 +05:30
Sungsoon Cho
79075a1a07 Opt perf (#1786)
* Define command line args, model-name, max-seq-len, platform, etc.

* Add usage example.

* Add opt_perf_comparision_batch.py.

* Use shlex instead.
2023-08-24 08:33:12 -05:00
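A small sketch of the command-line surface such a perf script exposes (option names follow the bullets above; defaults and the example command are assumptions), including the `shlex`-based splitting mentioned in the last bullet.

```python
# Illustrative sketch of the perf-comparison CLI, not the actual file contents.
import argparse
import shlex

parser = argparse.ArgumentParser(description="OPT perf comparison (sketch)")
parser.add_argument("--model-name", default="facebook/opt-1.3b")  # assumption
parser.add_argument("--max-seq-len", type=int, default=128)       # assumption
parser.add_argument("--platform", default="cpu")                  # assumption

# shlex splits a command string safely instead of naive str.split().
cmd = "--model-name facebook/opt-125m --max-seq-len 64 --platform cpu"
args = parser.parse_args(shlex.split(cmd))
print(args.model_name, args.max_seq_len, args.platform)
```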
Abhishek Varma
db990826d3 Add Llama2 13B int4 fp16 support (#1784)
Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-08-23 10:00:32 -07:00
gpetters94
7ee3e4ba5d Add stencil_unet_512 support (#1778)
This should fix any remaining issues with stencils and long prompts.
2023-08-22 12:23:46 -04:00
Vivek Khandelwal
05889a8fe1 Add LLaMa2-int4-fp16 support (#1782) 2023-08-22 07:45:50 -07:00
jinchen62
b87efe7686 Fix venv setup for brevitas (#1779) 2023-08-21 11:58:51 -07:00
gpetters94
82b462de3a Fix stencils for long prompts (#1777) 2023-08-19 00:26:51 -07:00
Daniel Garvey
d8f0f7bade replace public with private (#1776)
unload footguns
2023-08-18 14:22:46 -07:00
gpetters94
79bd0b84a1 Fix an issue with diffusers>0.19.3 (#1775) 2023-08-18 14:06:06 -04:00
jinchen62
8738571d1e Adapt the change of brevitas custom op name (#1772) 2023-08-17 14:24:43 -07:00
Gaurav Shukla
a4c354ce54 [version] Pin diffusers==0.19.3
Once the latest works with LORA train, unpin it.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
cc53efa89f [cli] Fix chatbot cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
9ae8bc921e [chatbot] Fix chatbot cli and webview warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
32eb78f0f9 [chatbot] Fix switching parameters in chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 19:14:17 +05:30
Ean Garvey
cb509343d9 Fix pytest benchmarks and shark_tank generation. (#1632)
- fix setup_venv.sh for benchmarks/imports etc.
- fix torch benchmarks in SharkBenchmarkRunner
- generate SD artifacts using build_tools/stable_diffusion_testing.py and --import_mlir
- decouple SD gen from tank/generate_sharktank for now
2023-08-16 17:48:47 -05:00
powderluv
6da391c9b1 update signtool to use /fd certHash 2023-08-15 15:11:40 -07:00
Ean Garvey
9dee7ae652 fix tkinter window (#1766) 2023-08-15 13:23:09 -07:00
Ean Garvey
343dfd901c Update SHARK-Runtime links to SRT (#1765)
* Update nightly.yml

* Update setup_venv.ps1

* Update CMakeLists.txt

* Update shark_iree_profiling.md

* Update setup_venv.sh

* Update README.md

* Update .gitmodules

* Update CMakeLists.txt

* Update README.md

* fix signtool flags

* Update nightly.yml

* Update benchmark_utils.py

* uncomment tkinter launch
2023-08-15 12:40:44 -07:00
Ean Garvey
57260b9c37 (Studio) Add hf-hub to pyinstaller metadata (#1761) 2023-08-14 23:01:50 -05:00
Ean Garvey
18e7d2d061 Enable vae tunings for rdna3. (#1764) 2023-08-14 21:00:14 -07:00
Stanley Winata
51a1009796 Add Forward method to SHARKRunner and fix examples. (#1756) 2023-08-14 19:20:37 -07:00
Daniel Garvey
045c3c3852 enable iree-opt-const-expr-hoisting in vicuna (#1742)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-08-14 18:43:42 -07:00
Ean Garvey
0139dd58d9 Specify max allocation size in IREE compile args. (#1760) 2023-08-14 15:43:09 -05:00
Ean Garvey
c96571855a prevents recompiles for cuda benchmarks + update benchmark_module path (#1759)
* xfail resnet50_fp16

* Fix cuda benchmarks and prevent recompilation.
2023-08-14 15:30:32 -05:00
PhaneeshB
4f61d69d86 add support passing iree flags for LLMs 2023-08-15 00:22:56 +05:30
Phaneesh Barwaria
531d447768 set default allocator for metal device creation (#1755) 2023-08-14 06:17:52 -07:00
Vivek Khandelwal
16f46f8de9 Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
c4723f469f Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
d804f45a61 Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
d22177f936 Update requirements.txt 2023-08-14 14:32:19 +05:30
George Petterson
75e68f02f4 Remove CUDNN 2023-08-14 14:32:19 +05:30
Gaurav Shukla
4dc9c59611 [chatbot] Add tokens generated per second (#1753) 2023-08-13 11:25:41 -07:00
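The metric itself is straightforward; a hedged stand-alone version (the sleep stands in for actual token generation):

```python
# Illustrative tokens-per-second calculation, not the chatbot's exact code.
import time

start = time.monotonic()
time.sleep(0.25)   # stand-in for generating `num_tokens` tokens
num_tokens = 16
elapsed = time.monotonic() - start
print(f"{num_tokens / elapsed:.1f} tokens/s")
```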
Gaurav Shukla
18801dcabc [chat] Update chatbot ui
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-13 18:39:22 +05:30
Gaurav Shukla
3c577f7168 [vicuna] fix shard config generator script (#1747)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-10 11:26:03 -07:00
Stefan Kapusniak
f5e4fa6ffe UI/Web - Revert tab order (#1724)
* Revert ui tab order

* Reverts the tab order so that SD, LLM, and Experimental are grouped
together again, as far as possible.
* Labelled "Generate Sharding Config" as experimental, since pressing the
'Get Model Config' button errors for me.

* Fix formatting in index.py
2023-08-10 11:25:36 -07:00
powderluv
48de445325 Enable caching and disable vma (#1746)
* Enable caching allocator by default

Going to toggle VMA off too, and this is required for performance. Will have to monitor in-the-wild reports.

* Disable VMA
2023-08-10 10:49:44 -07:00
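A hedged sketch of collecting such runtime flags in one place; whether `--device_allocator=caching` is the exact flag SHARK passes here is an assumption based on IREE's tooling, not confirmed by this diff.

```python
# Sketch only: gather device/runtime flags behind a single helper.
def get_runtime_flags(enable_caching_allocator=True):
    flags = []
    if enable_caching_allocator:
        # Assumed flag spelling; treat as illustrative, not authoritative.
        flags.append("--device_allocator=caching")
    return flags

print(get_runtime_flags())
```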
Gaurav Shukla
8e90f1b81a [vicuna] add default config in case of sharded vicuna
Signed-Off-by: Gaurav Shukla<gaurav@nod-labs.com>
2023-08-10 21:28:08 +05:30
Vivek Khandelwal
e8c1203be2 Fix vicuna script (#1745) 2023-08-10 06:11:14 -07:00
Vivek Khandelwal
e4d7abb519 Final patch for fixing Langchain token streaming issue (#1744) 2023-08-09 10:09:41 -07:00
powderluv
96185c9dc1 pin safetensors to 0.3.1 (#1740) 2023-08-08 19:24:44 -07:00
powderluv
bc22a81925 re-enable constant folding (#1739)
Tested and works well. (modulo unrelated driver issue)
2023-08-08 17:17:38 -07:00
Eliasj42
5203679f1f Bandaid fix 2 (#1728)
* download all mlirs

* fixed install method

* download all mlirs (#1727)

Co-authored-by: Elias Joseph <elias@nod-labs.com>

* added taggs

* fix name check for file existence

* Remove SD from all_models.csv (#1706)

Removes SD from pytests as it has its own test suite.

* gpt_langchain.py fixes for pydantic (#1722)

* removed dead code

---------

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>
2023-08-08 12:14:57 -05:00
Vivek Khandelwal
bf073f8f37 [Langchain] Expand pipelines to fix token streaming issue 2023-08-08 10:27:23 +05:30
Stella Laurenzo
cec6eda6b4 Optimize device enumeration overhead and log details on long operations. (#1734)
* Optimize device enumeration overhead and log details on long operations.

* Various fixes to add `@functools.cache` to what should be one-time, expensive device enumeration and setup activities. Cuts several seconds off initialization on my machine.
* Add detailed tracing to actual invocations if they exceed a certain timeout or have an exception.
* Add detailed tracing to loading status.
* By default, detailed logging is only printed if an operation takes an excessive amount of time. All logging/timing can be printed by setting the environment variable `$env:SHARK_DETAIL_TRACE = "1"`

* Remove cache from unhashable functions
2023-08-07 17:20:53 -07:00
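A minimal sketch of the two mechanisms described above; function and variable names are illustrative, not the actual SHARK Studio helpers.

```python
# Illustrative: cache one-time device enumeration and log timing details only
# when an operation is slow or SHARK_DETAIL_TRACE=1 is set.
import functools
import logging
import os
import time

DETAIL_TRACE = os.environ.get("SHARK_DETAIL_TRACE") == "1"

@functools.cache
def enumerate_devices():
    time.sleep(0.1)  # stand-in for expensive driver/device queries
    return ("cpu-task", "vulkan://0")

def timed(name, fn, budget_s=1.0):
    start = time.monotonic()
    try:
        return fn()
    finally:
        elapsed = time.monotonic() - start
        if DETAIL_TRACE or elapsed > budget_s:
            logging.warning("%s took %.2fs", name, elapsed)

print(timed("enumerate_devices", enumerate_devices))
print(timed("enumerate_devices", enumerate_devices))  # cached, near-instant
```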
Stella Laurenzo
9e37e03741 Clearly differentiate phases of loading modules to better understand if things are taking a long time. (#1733) 2023-08-07 14:03:12 -07:00
Stefan Kapusniak
9b8c4401b5 gpt_langchain.py fixes for pydantic (#1722) 2023-08-07 00:55:38 -07:00
Ean Garvey
a9f95a218b Remove SD from all_models.csv (#1706)
Removes SD from pytests as it has its own test suite.
2023-08-05 15:55:52 -05:00
PhaneeshB
872bd72d0b fix name check for file existence 2023-08-05 21:33:53 +05:30
Eliasj42
fd1c4db5d0 download all mlirs (#1727)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-08-04 18:22:06 -05:00
Daniel Garvey
759664bb48 add py files to pyinstaller for shark (#1723) 2023-08-04 14:10:43 -07:00
Daniel Garvey
14fd0cdd87 add missing subprocess import (#1721) 2023-08-04 15:15:22 -05:00
Daniel Garvey
a57eccc997 fix lint (#1720) 2023-08-04 14:54:33 -05:00
Daniel Garvey
a686d7d89f temporarily disable langchain stuff in webui (#1719)
it's breaking the exe
2023-08-04 12:48:06 -07:00
Eliasj42
ed484b8253 added functionality for int8 vicuna and 4 shards (#1712)
combined vicuna_4_shards.py and vicuna.py to reduce code duplication

Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-08-04 14:05:05 -05:00
gpetters94
7fe57ebaaf Add vector database and add support on the web UI (#1699) 2023-08-04 13:47:19 -04:00
Nithin Meganathan
c287fd2be8 Add GPU ID's in model_confg.json by default for manual annotation (#1718) 2023-08-04 12:46:27 -05:00
Gaurav Shukla
51ec1a1360 [vicuna] Integrate sharded vicuna in web (#1717)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-04 11:46:53 -05:00
Gaurav Shukla
bd30044c0b [Shard] Add sharding generation in shark studio
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-04 21:51:14 +05:30
Ean Garvey
c9de2729b2 Add flag for toggling constant folding. (#1714) 2023-08-04 04:55:52 -07:00
Vivek Khandelwal
a5b13fcc2f [Langchain] Patch for fixing streaming of tokens (#1709) 2023-08-03 10:06:49 -07:00
Stefan Kapusniak
6bb329c4af Unsharded Vicuna: Fix Memory Error compiling mlir for lmsys/vicuna-7b-v1.3 fp16 with 64 GiB (#1702) 2023-08-01 06:07:56 -07:00
84 changed files with 5129 additions and 3869 deletions


@@ -51,11 +51,11 @@ jobs:
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets
@@ -104,7 +104,7 @@ jobs:
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -134,6 +134,8 @@ jobs:
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
rm -rf ./wheelhouse/nodai*
- name: Build and validate the SHARK Runtime package
@@ -144,7 +146,7 @@ jobs:
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models

.gitignore (vendored): 6 changes

@@ -193,3 +193,9 @@ stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData
# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/

.gitmodules (vendored): 2 changes

@@ -1,4 +1,4 @@
[submodule "inference/thirdparty/shark-runtime"]
path = inference/thirdparty/shark-runtime
url =https://github.com/nod-ai/SHARK-Runtime.git
url =https://github.com/nod-ai/SRT.git
branch = shark-06032022


@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
@@ -170,7 +170,7 @@ python -m pip install --upgrade pip
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
### Run shark tank model tests.


@@ -1,4 +1,3 @@
"""Load question answering chains."""
from __future__ import annotations
from typing import (
Any,
@@ -11,23 +10,34 @@ from typing import (
Union,
Protocol,
)
import inspect
import json
import warnings
from pathlib import Path
import yaml
from abc import ABC, abstractmethod
import langchain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.question_answering import stuff_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.docstore.document import Document
from abc import ABC, abstractmethod
from langchain.chains.base import Chain
from langchain.callbacks.manager import (
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
from pydantic import Extra, Field, root_validator
from pydantic import Extra, Field, root_validator, validator
def _get_verbosity() -> bool:
return langchain.verbose
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
@@ -48,6 +58,413 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
return prompt.format(**document_info)
class Chain(Serializable, ABC):
"""Base interface that all chains should implement."""
memory: Optional[BaseMemory] = None
callbacks: Callbacks = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(
default=None, exclude=True
)
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
tags: Optional[List[str]] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Input keys this chain expects."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
"""Output keys this chain expects."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@abstractmethod
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
input_docs = inputs["input_documents"]
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
if "is_first" in inputs.keys() and not inputs["is_first"]:
run_manager_ = run_manager
input_list = [inputs]
stop = None
prompts = []
for inputs in input_list:
selected_inputs = {
k: inputs[k] for k in self.prompt.input_variables
}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager_:
run_manager_.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
prompt_strings = [p.to_string() for p in prompts]
prompts = prompt_strings
callbacks = run_manager_.get_child() if run_manager_ else None
tags = None
"""Run the LLM on the given prompt and input."""
# If string is passed in directly no errors will be raised but outputs will
# not make sense.
if not isinstance(prompts, list):
raise ValueError(
"Argument 'prompts' is expected to be of type List[str], received"
f" argument of type {type(prompts)}."
)
params = self.llm.dict()
params["stop"] = stop
options = {"stop": stop}
disregard_cache = self.llm.cache is not None and not self.llm.cache
callback_manager = CallbackManager.configure(
callbacks,
self.llm.callbacks,
self.llm.verbose,
tags,
self.llm.tags,
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.llm.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager_ = callback_manager.on_llm_start(
dumpd(self),
prompts,
invocation_params=params,
options=options,
)
generations = []
for prompt in prompts:
inputs_ = prompt
num_workers = None
batch_size = None
if num_workers is None:
if self.llm.pipeline._num_workers is None:
num_workers = 0
else:
num_workers = self.llm.pipeline._num_workers
if batch_size is None:
if self.llm.pipeline._batch_size is None:
batch_size = 1
else:
batch_size = self.llm.pipeline._batch_size
preprocess_params = {}
generate_kwargs = {}
preprocess_params.update(generate_kwargs)
forward_params = generate_kwargs
postprocess_params = {}
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
preprocess_params = {
**self.llm.pipeline._preprocess_params,
**preprocess_params,
}
forward_params = {
**self.llm.pipeline._forward_params,
**forward_params,
}
postprocess_params = {
**self.llm.pipeline._postprocess_params,
**postprocess_params,
}
self.llm.pipeline.call_count += 1
if (
self.llm.pipeline.call_count > 10
and self.llm.pipeline.framework == "pt"
and self.llm.pipeline.device.type == "cuda"
):
warnings.warn(
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
" dataset",
UserWarning,
)
model_inputs = self.llm.pipeline.preprocess(
inputs_, **preprocess_params
)
model_outputs = self.llm.pipeline.forward(
model_inputs, **forward_params
)
model_outputs["process"] = False
return model_outputs
output = LLMResult(generations=generations)
run_manager_.on_llm_end(output)
if run_manager_:
output.run = RunInfo(run_id=run_manager_.run_id)
response = output
outputs = [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
][0]
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
else:
_run_manager = (
run_manager or CallbackManagerForChainRun.get_noop_manager()
)
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {
k: v for k, v in inputs.items() if k != self.input_key
}
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in other_keys.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
inputs["is_first"] = False
inputs["input_documents"] = input_docs
# Call predict on the LLM.
output = self.llm_chain(inputs, callbacks=_run_manager.get_child())
if "process" in output.keys() and not output["process"]:
return output
output = output[self.llm_chain.output_key]
extra_return_dict = {}
extra_return_dict[self.output_key] = output
outputs = extra_return_dict
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(
self, inputs: Union[Dict[str, Any], Any]
) -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(
self.memory.memory_variables
)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
def run(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
if args and not kwargs:
if len(args) != 1:
raise ValueError(
"`run` supports only one positional argument."
)
return self(args[0], callbacks=callbacks, tags=tags)[
self.output_keys[0]
]
if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags)[
self.output_keys[0]
]
if not kwargs and not args:
raise ValueError(
"`run` supported with either positional arguments or keyword arguments,"
" but none were provided."
)
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of chain."""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict()
_dict["_type"] = self._chain_type
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the chain.
Args:
file_path: Path to file to save the chain to.
Example:
.. code-block:: python
chain.save(file_path="path/chain.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
chain_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(chain_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents."""
@@ -79,12 +496,6 @@ class BaseCombineDocumentsChain(Chain, ABC):
"""
return None
@abstractmethod
def combine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents into a single string."""
def _call(
self,
inputs: Dict[str, List[Document]],
@@ -96,13 +507,49 @@ class BaseCombineDocumentsChain(Chain, ABC):
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
output, extra_return_dict = self.combine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in other_keys.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
# Call predict on the LLM.
output, extra_return_dict = (
self.llm_chain(inputs, callbacks=_run_manager.get_child())[
self.llm_chain.output_key
],
{},
)
extra_return_dict[self.output_key] = output
return extra_return_dict
from pydantic import BaseModel
class Generation(Serializable):
"""Output of a single generation."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw generation info response from the provider"""
"""May include things like reason for finishing (e.g. in OpenAI)"""
# TODO: add log probs
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
class LLMChain(Chain):
"""Chain to run queries against LLMs.
@@ -153,21 +600,13 @@ class LLMChain(Chain):
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts, stop = self.prep_prompts([inputs], run_manager=run_manager)
response = self.llm.generate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
)
return self.create_outputs(response)[0]
def prep_prompts(
self,
@@ -223,23 +662,6 @@ class LLMChain(Chain):
for generation in response.generations
]
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return self(kwargs, callbacks=callbacks)[self.output_key]
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
@@ -350,14 +772,6 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt)
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
@property
def _chain_type(self) -> str:
return "stuff_documents_chain"


@@ -1129,7 +1129,7 @@ class Langchain:
max_time=max_time,
num_return_sequences=num_return_sequences,
)
for r in run_qa_db(
out = run_qa_db(
query=instruction,
iinput=iinput,
context=context,
@@ -1170,689 +1170,8 @@ class Langchain:
auto_reduce_chunks=auto_reduce_chunks,
max_chunks=max_chunks,
device=self.device,
):
(
outr,
extra,
) = r # doesn't accumulate, new answer every yield, so only save that full answer
yield dict(response=outr, sources=extra)
if save_dir:
extra_dict = gen_hyper_langchain.copy()
extra_dict.update(
prompt_type=prompt_type,
inference_server=inference_server,
langchain_mode=langchain_mode,
langchain_action=langchain_action,
document_choice=document_choice,
num_prompt_tokens=num_prompt_tokens,
instruction=instruction,
iinput=iinput,
context=context,
)
save_generate_output(
prompt=prompt,
output=outr,
base_model=base_model,
save_dir=save_dir,
where_from="run_qa_db",
extra_dict=extra_dict,
)
if verbose:
print(
"Post-Generate Langchain: %s decoded_output: %s"
% (str(datetime.now()), len(outr) if outr else -1),
flush=True,
)
if outr or base_model in non_hf_types:
# if got no response (e.g. not showing sources and got no sources,
# so nothing to give to LLM), then slip through and ask LLM
# Or if llama/gptj, then just return since they had no response and can't go down below code path
# clear before return, since .then() never done if from API
clear_torch_cache()
return
if inference_server.startswith(
"openai"
) or inference_server.startswith("http"):
if inference_server.startswith("openai"):
import openai
where_from = "openai_client"
openai.api_key = os.getenv("OPENAI_API_KEY")
stop_sequences = list(
set(prompter.terminate_response + [prompter.PreResponse])
)
stop_sequences = [x for x in stop_sequences if x]
# OpenAI will complain if ask for too many new tokens, takes it as min in some sense, wrongly so.
max_new_tokens_openai = min(
max_new_tokens, model_max_length - num_prompt_tokens
)
gen_server_kwargs = dict(
temperature=temperature if do_sample else 0,
max_tokens=max_new_tokens_openai,
top_p=top_p if do_sample else 1,
frequency_penalty=0,
n=num_return_sequences,
presence_penalty=1.07
- repetition_penalty
+ 0.6, # so good default
)
if inference_server == "openai":
response = openai.Completion.create(
model=base_model,
prompt=prompt,
**gen_server_kwargs,
stop=stop_sequences,
stream=stream_output,
)
if not stream_output:
text = response["choices"][0]["text"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
collected_events = []
text = ""
for event in response:
collected_events.append(
event
) # save the event response
event_text = event["choices"][0][
"text"
] # extract the text
text += event_text # append the text
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
elif inference_server == "openai_chat":
response = openai.ChatCompletion.create(
model=base_model,
messages=[
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": prompt,
},
],
stream=stream_output,
**gen_server_kwargs,
)
if not stream_output:
text = response["choices"][0]["message"]["content"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
text = ""
for chunk in response:
delta = chunk["choices"][0]["delta"]
if "content" in delta:
text += delta["content"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
raise RuntimeError(
"No such OpenAI mode: %s" % inference_server
)
elif inference_server.startswith("http"):
inference_server, headers = get_hf_server(inference_server)
from gradio_utils.grclient import GradioClient
from text_generation import Client as HFClient
if isinstance(model, GradioClient):
gr_client = model
hf_client = None
elif isinstance(model, HFClient):
gr_client = None
hf_client = model
else:
(
inference_server,
gr_client,
hf_client,
) = self.get_client_from_inference_server(
inference_server, base_model=base_model
)
# quick sanity check to avoid long timeouts, just see if can reach server
requests.get(
inference_server,
timeout=int(os.getenv("REQUEST_TIMEOUT_FAST", "10")),
)
if gr_client is not None:
# Note: h2oGPT gradio server could handle input token size issues for prompt,
# but best to handle here so send less data to server
chat_client = False
where_from = "gr_client"
client_langchain_mode = "Disabled"
client_langchain_action = LangChainAction.QUERY.value
gen_server_kwargs = dict(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
early_stopping=early_stopping,
max_time=max_time,
repetition_penalty=repetition_penalty,
num_return_sequences=num_return_sequences,
do_sample=do_sample,
chat=chat_client,
)
# account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection
if prompt_type in [
None,
"",
PromptType.plain.name,
PromptType.plain.value,
str(PromptType.plain.value),
]:
# if our prompt is plain, assume either correct or gradio server knows different prompt type,
# so pass empty prompt_Type
gr_prompt_type = ""
gr_prompt_dict = ""
gr_prompt = prompt # already prepared prompt
gr_context = ""
gr_iinput = ""
else:
# if already have prompt_type that is not plain, None, or '', then already applied some prompting
# But assume server can handle prompting, and need to avoid double-up.
# Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle
# So avoid "prompt" and let gradio server reconstruct from prompt_type we passed
# Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed,
# because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter
# since those won't appear
gr_context = context
gr_prompt = instruction
gr_iinput = iinput
gr_prompt_type = prompt_type
gr_prompt_dict = prompt_dict
client_kwargs = dict(
instruction=gr_prompt
if chat_client
else "", # only for chat=True
iinput=gr_iinput, # only for chat=True
context=gr_context,
# streaming output is supported, loops over and outputs each generation in streaming mode
# but leave stream_output=False for simple input/output mode
stream_output=stream_output,
**gen_server_kwargs,
prompt_type=gr_prompt_type,
prompt_dict=gr_prompt_dict,
instruction_nochat=gr_prompt
if not chat_client
else "",
iinput_nochat=gr_iinput, # only for chat=False
langchain_mode=client_langchain_mode,
langchain_action=client_langchain_action,
top_k_docs=top_k_docs,
chunk=chunk,
chunk_size=chunk_size,
document_choice=[DocumentChoices.All_Relevant.name],
)
api_name = "/submit_nochat_api" # NOTE: like submit_nochat but stable API for string dict passing
if not stream_output:
res = gr_client.predict(
str(dict(client_kwargs)), api_name=api_name
)
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
else:
job = gr_client.submit(
str(dict(client_kwargs)), api_name=api_name
)
text = ""
sources = ""
res_dict = dict(response=text, sources=sources)
while not job.done():
outputs_list = job.communicator.job.outputs
if outputs_list:
res = job.communicator.job.outputs[-1]
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
if gr_prompt_type == "plain":
# then gradio server passes back full prompt + text
prompt_and_text = text
else:
prompt_and_text = prompt + text
yield dict(
response=prompter.get_response(
prompt_and_text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
time.sleep(0.01)
# ensure get last output to avoid race
res_all = job.outputs()
if len(res_all) > 0:
res = res_all[-1]
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
else:
# go with old text if last call didn't work
e = job.future._exception
if e is not None:
stre = str(e)
strex = "".join(
traceback.format_tb(e.__traceback__)
)
else:
stre = ""
strex = ""
print(
"Bad final response: %s %s %s %s %s: %s %s"
% (
base_model,
inference_server,
res_all,
prompt,
text,
stre,
strex,
),
flush=True,
)
if gr_prompt_type == "plain":
# then gradio server passes back full prompt + text
prompt_and_text = text
else:
prompt_and_text = prompt + text
yield dict(
response=prompter.get_response(
prompt_and_text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
elif hf_client:
# HF inference server needs control over input tokens
where_from = "hf_client"
# prompt must include all human-bot like tokens, already added by prompt
# https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types
stop_sequences = list(
set(
prompter.terminate_response
+ [prompter.PreResponse]
)
)
stop_sequences = [x for x in stop_sequences if x]
gen_server_kwargs = dict(
do_sample=do_sample,
max_new_tokens=max_new_tokens,
# best_of=None,
repetition_penalty=repetition_penalty,
return_full_text=True,
seed=SEED,
stop_sequences=stop_sequences,
temperature=temperature,
top_k=top_k,
top_p=top_p,
# truncate=False, # behaves oddly
# typical_p=top_p,
# watermark=False,
# decoder_input_details=False,
)
# work-around for timeout at constructor time, will be issue if multi-threading,
# so just do something reasonable or max_time if larger
# lower bound because client is re-used if multi-threading
hf_client.timeout = max(300, max_time)
if not stream_output:
text = hf_client.generate(
prompt, **gen_server_kwargs
).generated_text
yield dict(
response=prompter.get_response(
text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
text = ""
for response in hf_client.generate_stream(
prompt, **gen_server_kwargs
):
if not response.token.special:
# stop_sequences
text_chunk = response.token.text
text += text_chunk
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
raise RuntimeError(
"Failed to get client: %s" % inference_server
)
else:
raise RuntimeError(
"No such inference_server %s" % inference_server
)
if save_dir and text:
# save prompt + new text
extra_dict = gen_server_kwargs.copy()
extra_dict.update(
dict(
inference_server=inference_server,
num_prompt_tokens=num_prompt_tokens,
)
)
save_generate_output(
prompt=prompt,
output=text,
base_model=base_model,
save_dir=save_dir,
where_from=where_from,
extra_dict=extra_dict,
)
return
else:
assert not inference_server, (
"inferene_server=%s not supported" % inference_server
)
if isinstance(tokenizer, str):
# pipeline
if tokenizer == "summarization":
key = "summary_text"
else:
raise RuntimeError("No such task type %s" % tokenizer)
# NOTE: uses max_length only
yield dict(
response=model(prompt, max_length=max_new_tokens)[0][key],
sources="",
)
if "mbart-" in base_model.lower():
assert src_lang is not None
tokenizer.src_lang = self.languages_covered()[src_lang]
stopping_criteria = get_stopping(
prompt_type,
prompt_dict,
tokenizer,
self.device,
model_max_length=tokenizer.model_max_length,
)
print(prompt)
# exit(0)
inputs = tokenizer(prompt, return_tensors="pt")
if debug and len(inputs["input_ids"]) > 0:
print("input_ids length", len(inputs["input_ids"][0]), flush=True)
input_ids = inputs["input_ids"].to(self.device)
# CRITICAL LIMIT else will fail
max_max_tokens = tokenizer.model_max_length
max_input_tokens = max_max_tokens - min_new_tokens
# NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py
input_ids = input_ids[:, -max_input_tokens:]
# required for falcon if multiple threads or asyncio accesses to model during generation
if use_cache is None:
use_cache = False if "falcon" in base_model else True
gen_config_kwargs = dict(
temperature=float(temperature),
top_p=float(top_p),
top_k=top_k,
num_beams=num_beams,
do_sample=do_sample,
repetition_penalty=float(repetition_penalty),
num_return_sequences=num_return_sequences,
renormalize_logits=True,
remove_invalid_values=True,
use_cache=use_cache,
)
token_ids = [
"eos_token_id",
"pad_token_id",
"bos_token_id",
"cls_token_id",
"sep_token_id",
]
for token_id in token_ids:
if (
hasattr(tokenizer, token_id)
and getattr(tokenizer, token_id) is not None
):
gen_config_kwargs.update(
{token_id: getattr(tokenizer, token_id)}
)
generation_config = GenerationConfig(**gen_config_kwargs)
gen_kwargs = dict(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens, # prompt + new
min_new_tokens=min_new_tokens, # prompt + new
early_stopping=early_stopping, # False, True, "never"
max_time=max_time,
stopping_criteria=stopping_criteria,
)
if "gpt2" in base_model.lower():
gen_kwargs.update(
dict(
bos_token_id=tokenizer.bos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
)
elif "mbart-" in base_model.lower():
assert tgt_lang is not None
tgt_lang = self.languages_covered()[tgt_lang]
gen_kwargs.update(
dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])
)
else:
token_ids = ["eos_token_id", "bos_token_id", "pad_token_id"]
for token_id in token_ids:
if (
hasattr(tokenizer, token_id)
and getattr(tokenizer, token_id) is not None
):
gen_kwargs.update({token_id: getattr(tokenizer, token_id)})
decoder_kwargs = dict(
skip_special_tokens=True, clean_up_tokenization_spaces=True
)
decoder = functools.partial(tokenizer.decode, **decoder_kwargs)
decoder_raw_kwargs = dict(
skip_special_tokens=False, clean_up_tokenization_spaces=True
)
decoder_raw = functools.partial(tokenizer.decode, **decoder_raw_kwargs)
with torch.no_grad():
have_lora_weights = lora_weights not in [no_lora_str, "", None]
context_class_cast = (
NullContext
if self.device == "cpu"
or have_lora_weights
or self.device == "mps"
else torch.autocast
)
with context_class_cast(self.device):
# protection for gradio not keeping track of closed users,
# else hit bitsandbytes lack of thread safety:
# https://github.com/h2oai/h2ogpt/issues/104
# but only makes sense if concurrency_count == 1
context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
if verbose:
print("Pre-Generate: %s" % str(datetime.now()), flush=True)
decoded_output = None
with context_class("generate.lock"):
if verbose:
print("Generate: %s" % str(datetime.now()), flush=True)
# decoded tokenized prompt can deviate from prompt due to special characters
inputs_decoded = decoder(input_ids[0])
inputs_decoded_raw = decoder_raw(input_ids[0])
if inputs_decoded == prompt:
# normal
pass
elif inputs_decoded.lstrip() == prompt.lstrip():
# sometimes extra space in front, make prompt same for prompt removal
prompt = inputs_decoded
elif inputs_decoded_raw == prompt:
# some models specify special tokens that are part of normal prompt, so can't skip them
inputs_decoded = prompt = inputs_decoded_raw
decoder = decoder_raw
decoder_kwargs = decoder_raw_kwargs
elif inputs_decoded_raw.replace("<unk> ", "").replace(
"<unk>", ""
).replace("\n", " ").replace(" ", "") == prompt.replace(
"\n", " "
).replace(
" ", ""
):
inputs_decoded = prompt = inputs_decoded_raw
decoder = decoder_raw
decoder_kwargs = decoder_raw_kwargs
else:
if verbose:
print(
"WARNING: Special characters in prompt",
flush=True,
)
if stream_output:
skip_prompt = False
streamer = H2OTextIteratorStreamer(
tokenizer,
skip_prompt=skip_prompt,
block=False,
**decoder_kwargs,
)
gen_kwargs.update(dict(streamer=streamer))
target = wrapped_partial(
self.generate_with_exceptions,
model.generate,
prompt=prompt,
inputs_decoded=inputs_decoded,
raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
**gen_kwargs,
)
bucket = queue.Queue()
thread = EThread(
target=target, streamer=streamer, bucket=bucket
)
thread.start()
outputs = ""
try:
for new_text in streamer:
if bucket.qsize() > 0 or thread.exc:
thread.join()
outputs += new_text
yield dict(
response=prompter.get_response(
outputs,
prompt=inputs_decoded,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
except BaseException:
# if any exception, raise that exception if was from thread, first
if thread.exc:
raise thread.exc
raise
finally:
# clear before return, since .then() never done if from API
clear_torch_cache()
# in case no exception and didn't join with thread yet, then join
if not thread.exc:
thread.join()
# in case raise StopIteration or broke queue loop in streamer, but still have exception
if thread.exc:
raise thread.exc
decoded_output = outputs
else:
try:
outputs = model.generate(**gen_kwargs)
finally:
clear_torch_cache() # has to be here for API submit_nochat_api since.then() not called
outputs = [decoder(s) for s in outputs.sequences]
yield dict(
response=prompter.get_response(
outputs,
prompt=inputs_decoded,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
if outputs and len(outputs) >= 1:
decoded_output = prompt + outputs[0]
if save_dir and decoded_output:
extra_dict = gen_config_kwargs.copy()
extra_dict.update(
dict(num_prompt_tokens=num_prompt_tokens)
)
save_generate_output(
prompt=prompt,
output=decoded_output,
base_model=base_model,
save_dir=save_dir,
where_from="evaluate_%s" % str(stream_output),
extra_dict=gen_config_kwargs,
)
if verbose:
print(
"Post-Generate: %s decoded_output: %s"
% (
str(datetime.now()),
len(decoded_output) if decoded_output else -1,
),
flush=True,
)
return outputs[0]
return out
inputs_list_names = list(inspect.signature(evaluate).parameters)
global inputs_kwargs_list


@@ -436,7 +436,7 @@ class GradioInference(LLM):
chat_client: bool = False
return_full_text: bool = True
stream: bool = False
stream_output: bool = Field(False, alias="stream")
sanitize_bot_response: bool = False
prompter: Any = None
@@ -481,7 +481,7 @@ class GradioInference(LLM):
# so server should get prompt_type or '', not plain
# This is good, so gradio server can also handle stopping.py conditions
# this is different than TGI server that uses prompter to inject prompt_type prompting
stream_output = self.stream
stream_output = self.stream_output
gr_client = self.client
client_langchain_mode = "Disabled"
client_langchain_action = LangChainAction.QUERY.value
@@ -596,7 +596,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
inference_server_url: str = ""
timeout: int = 300
headers: dict = None
stream: bool = False
stream_output: bool = Field(False, alias="stream")
sanitize_bot_response: bool = False
prompter: Any = None
tokenizer: Any = None
@@ -663,7 +663,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
# lower bound because client is re-used if multi-threading
self.client.timeout = max(300, self.timeout)
if not self.stream:
if not self.stream_output:
res = self.client.generate(
prompt,
**gen_server_kwargs,
@@ -852,7 +852,7 @@ def get_llm(
top_p=top_p,
# typical_p=top_p,
callbacks=callbacks if stream_output else None,
stream=stream_output,
stream_output=stream_output,
prompter=prompter,
tokenizer=tokenizer,
client=hf_client,
@@ -2510,8 +2510,7 @@ def _run_qa_db(
formatted_doc_chunks = "\n\n".join(
[get_url(x) + "\n\n" + x.page_content for x in docs]
)
yield formatted_doc_chunks, ""
return
return formatted_doc_chunks, ""
if not docs and langchain_action in [
LangChainAction.SUMMARIZE_MAP.value,
LangChainAction.SUMMARIZE_ALL.value,
@@ -2523,8 +2522,7 @@ def _run_qa_db(
else "No documents to summarize."
)
extra = ""
yield ret, extra
return
return ret, extra
if not docs and langchain_mode not in [
LangChainMode.DISABLED.value,
LangChainMode.CHAT_LLM.value,
@@ -2536,8 +2534,7 @@ def _run_qa_db(
else "No documents to query."
)
extra = ""
yield ret, extra
return
return ret, extra
if chain is None and model_name not in non_hf_types:
# here if no docs at all and not HF type
@@ -2557,22 +2554,7 @@ def _run_qa_db(
)
with context_class_cast(args.device):
answer = chain()
if not use_context:
ret = answer["output_text"]
extra = ""
yield ret, extra
elif answer is not None:
ret, extra = get_sources_answer(
query,
answer,
scores,
show_rank,
answer_with_sources,
verbose=verbose,
)
yield ret, extra
return
return answer
def get_similarity_chain(


@@ -3,13 +3,11 @@ from apps.stable_diffusion.src.utils.utils import _compile_module
from io import BytesIO
import torch_mlir
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from stopping import get_stopping
from prompter import Prompter, PromptType
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
@@ -31,14 +29,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
def brevitasmatmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -47,30 +39,21 @@ def brevitasmatmul_rhs_group_quant〡shape(
raise ValueError("Input shapes not supported.")
def brevitasmatmul_rhs_group_quant〡dtype(
lhs_rank_dtype: Tuple[int, int],
rhs_rank_dtype: Tuple[int, int],
rhs_scale_rank_dtype: Tuple[int, int],
rhs_zero_point_rank_dtype: Tuple[int, int],
rhs_bit_width: int,
rhs_group_size: int,
) -> int:
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def brevitasmatmul_rhs_group_quant〡has_value_semantics(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics,
]
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
global_device = "cuda"
global_precision = "fp16"
@@ -246,7 +229,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
ts_graph,
[*h2ogptCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -254,7 +237,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
@@ -285,7 +268,215 @@ class H2OGPTSHARKModel(torch.nn.Module):
return result
h2ogpt_model = H2OGPTSHARKModel()
def decode_tokens(tokenizer, res_tokens):
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
res_tokens[i] = int(res_tokens[i][0])
res_str = tokenizer.decode(res_tokens, skip_special_tokens=True)
return res_str
def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs):
del generate_kwargs["max_time"]
generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to(
device=tensor_device
)
generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to(
device=tensor_device
)
truncated_input_ids = []
stopping_criteria = generate_kwargs["stopping_criteria"]
generation_config_ = GenerationConfig.from_model_config(model.config)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
input_ids_seq_length = input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
logits_warper = model._get_logits_warper(generation_config)
(
input_ids,
model_kwargs,
) = model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=tensor_device)
if eos_token_id is not None
else None
)
pad_token_id = generation_config.pad_token_id
eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(
input_ids.shape[0],
dtype=torch.long,
device=input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
res_tokens = []
while True:
model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs
)
outputs = h2ogpt_shark_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = next_token * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
model_kwargs["past_key_values"] = None
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
truncated_input_ids.append(input_ids[:, 0])
input_ids = input_ids[:, 1:]
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 1:]
new_word = tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
res_tokens.append(next_token)
if new_word == "<0x0A>":
print("\n", end="", flush=True)
else:
print(f"{new_word}", end=" ", flush=True)
part_str = decode_tokens(tokenizer, res_tokens)
yield part_str
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_token.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if unfinished_sequences.max() == 0 or stopping_criteria(
input_ids, scores
):
break
timesRan = timesRan + 1
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
torch.cuda.empty_cache()
gc.collect()
res_str = decode_tokens(tokenizer, res_tokens)
yield res_str
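The loop above re-implements Hugging Face's sampling step by hand: the raw logits are run through logits_processor and logits_warper, softmaxed, sampled with torch.multinomial, and finished sequences are forced to the pad token. A minimal sketch of one such step, with illustrative names and shapes (not taken from the diff):

```python
# Minimal sketch of one sampling step, assuming `next_token_logits` has shape
# [batch, vocab] and `unfinished` is a 0/1 LongTensor of shape [batch];
# names and the vocab size are illustrative, not taken from the diff.
import torch

def sample_step(next_token_logits, unfinished, pad_token_id, temperature=1.0):
    # the code above delegates warping (temperature/top-k/top-p) to
    # logits_processor and logits_warper; a bare temperature stands in here
    scores = next_token_logits / temperature
    probs = torch.nn.functional.softmax(scores, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
    # sequences that already hit EOS keep emitting the pad token
    return next_token * unfinished + pad_token_id * (1 - unfinished)

logits = torch.randn(1, 32000)                 # llama-sized vocab, batch of 1
unfinished = torch.ones(1, dtype=torch.long)
token = sample_step(logits, unfinished, pad_token_id=0)
print(token.shape)  # torch.Size([1])
```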
def pad_or_truncate_inputs(
@@ -498,233 +689,6 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
)
return records
def generate_new_token(self):
model_inputs = self.model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = h2ogpt_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.truncated_input_ids.append(self.input_ids[:, 0])
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
def generate_token(self, **generate_kwargs):
del generate_kwargs["max_time"]
self.truncated_input_ids = []
generation_config_ = GenerationConfig.from_model_config(
self.model.config
)
generation_config = copy.deepcopy(generation_config_)
self.model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
self.stopping_criteria = (
self.stopping_criteria
if self.stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
self.model_kwargs,
) = self.model._prepare_model_inputs(
None, generation_config.bos_token_id, self.model_kwargs
)
batch_size = inputs_tensor.shape[0]
self.model_kwargs[
"output_attentions"
] = generation_config.output_attentions
self.model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
self.model_kwargs["use_cache"] = generation_config.use_cache
self.input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else self.model_kwargs.pop("input_ids")
)
input_ids_seq_length = self.input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
self.logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=self.stopping_criteria,
)
self.logits_warper = self.model._get_logits_warper(generation_config)
(
self.input_ids,
self.model_kwargs,
) = self.model._expand_inputs_for_generation(
input_ids=self.input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.model.config.is_encoder_decoder, # False
**self.model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=tensor_device)
if eos_token_id is not None
else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
self.input_ids.shape[0],
dtype=torch.long,
device=self.input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
while True:
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
print(f"{new_word}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(self.input_ids, self.scores)
):
break
timesRan = timesRan + 1
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
self.input_ids = torch.cat(
[
torch.tensor(self.truncated_input_ids)
.to(device=tensor_device)
.unsqueeze(dim=0),
self.input_ids,
],
dim=-1,
)
torch.cuda.empty_cache()
gc.collect()
return self.input_ids
def _forward(self, model_inputs, **generate_kwargs):
if self.can_stop:
stopping_criteria = get_stopping(
@@ -784,19 +748,13 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
input_ids, attention_mask = pad_or_truncate_inputs(
input_ids, attention_mask, max_padding_length=max_padding_length
)
self.stopping_criteria = generate_kwargs["stopping_criteria"]
generated_sequence = self.generate_token(
input_ids=input_ids,
attention_mask=attention_mask,
**generate_kwargs,
)
out_b = generated_sequence.shape[0]
generated_sequence = generated_sequence.reshape(
in_b, out_b // in_b, *generated_sequence.shape[1:]
)
return {
"generated_sequence": generated_sequence,
return_dict = {
"model": self.model,
"tokenizer": self.tokenizer,
"input_ids": input_ids,
"prompt_text": prompt_text,
"attention_mask": attention_mask,
"attention_mask": attention_mask,
}
return_dict = {**return_dict, **generate_kwargs}
return return_dict


@@ -1,5 +1,4 @@
import os
import fire
from gpt_langchain import (
path_to_docs,
@@ -202,7 +201,3 @@ def make_db_main(
if verbose:
print("DONE", flush=True)
return db, collection_name
if __name__ == "__main__":
fire.Fire(make_db_main)


@@ -46,6 +46,7 @@ def compile_stableLM(
model_vmfb_name,
device="cuda",
precision="fp32",
debug=False,
):
from shark.shark_inference import SharkInference
@@ -92,7 +93,7 @@ def compile_stableLM(
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
)
print("Saved vmfb at ", str(path))

File diff suppressed because it is too large.


@@ -0,0 +1,94 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('huggingface-hub')
datas += copy_metadata('sentencepiece')
datas += copy_metadata("pyyaml")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('py-cpuinfo')
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['scripts/vicuna.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_llama_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
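This PyInstaller spec collects the data files and hidden imports needed to freeze scripts/vicuna.py into a single console executable named shark_llama_cli. A build is usually driven straight from the spec; a minimal sketch via the PyInstaller Python API, assuming the spec is saved as shark_llama_cli.spec (the filename is not stated in the diff):

```python
# Hypothetical build driver for the spec above; the spec filename is an
# assumption, not stated in the diff.
import PyInstaller.__main__

PyInstaller.__main__.run([
    "shark_llama_cli.spec",  # the spec shown above, saved to disk
    "--clean",               # drop cached build artifacts first
    "--noconfirm",           # overwrite dist/ without prompting
])
```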


@@ -0,0 +1,877 @@
import argparse
import json
import re
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from typing import List, Optional, Tuple, Union
import numpy as np
import iree.runtime
import itertools
import subprocess
import torch
import torch_mlir
from torch_mlir import TensorPlaceholder
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LlamaPreTrainedModel,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
CompiledVicunaLayer,
ShardedVicunaModel,
LMHead,
LMHeadCompiled,
VicunaEmbedding,
VicunaEmbeddingCompiled,
VicunaNorm,
VicunaNormCompiled,
)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna7B,
)
from apps.language_models.utils import (
get_vmfb_from_path,
)
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaRMSNorm,
_make_causal_mask,
_expand_mask,
)
from torch import nn
from time import time
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(config)
for _ in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
t1 = time()
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = (
use_cache if use_cache is not None else self.config.use_cache
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = (
seq_length_with_past + past_key_values_length
)
if position_ids is None:
device = (
input_ids.device
if input_ids is not None
else inputs_embeds.device
)
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.compressedlayers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[8 * idx : 8 * (idx + 1)]
if past_key_values is not None
else None
)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer.forward(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[1:],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
try:
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except Exception:
pass  # hidden_states may already be a torch tensor; nothing to convert
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = tuple(itertools.chain.from_iterable(next_cache))
print(f"Token generated in {time() - t1} seconds")
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class EightLayerLayerSV(torch.nn.Module):
def __init__(self, layers):
super().__init__()
assert len(layers) == 8
self.layers = layers
def forward(
self,
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
):
pkvs = [
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
]
new_pkvs = []
for layer, pkv in zip(self.layers, pkvs):
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
pkv[0],
pkv[1],
),
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
return (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
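EightLayerLayerSV and EightLayerLayerFV flatten each shard's eight (key, value) pairs into sixteen positional tensors because the compiled modules take a flat argument list. Illustrative helpers for that convention (the names are not from the diff):

```python
# Illustrative helpers for the flattening convention used by the 8-layer
# shards above; names are not from the diff.
from typing import List, Tuple

import torch

def flatten_pkv(pkvs: List[Tuple[torch.Tensor, torch.Tensor]]) -> List[torch.Tensor]:
    # [(k0, v0), (k1, v1), ...] -> [k0, v0, k1, v1, ...]
    return [t for pair in pkvs for t in pair]

def unflatten_pkv(flat: List[torch.Tensor]) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    return [(flat[i], flat[i + 1]) for i in range(0, len(flat), 2)]

pkvs = [(torch.zeros(1), torch.ones(1)) for _ in range(8)]
flat = flatten_pkv(pkvs)
assert len(flat) == 16
regrouped = unflatten_pkv(flat)
assert all(k is a and v is b for (k, v), (a, b) in zip(pkvs, regrouped))
```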
class EightLayerLayerFV(torch.nn.Module):
def __init__(self, layers):
super().__init__()
assert len(layers) == 8
self.layers = layers
def forward(self, hidden_states, attention_mask, position_ids):
new_pkvs = []
for layer in self.layers:
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=None,
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
return (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
class CompiledEightLayerLayerSV(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_value
pkv00 = pkv00.detach()
pkv01 = pkv01.detach()
pkv10 = pkv10.detach()
pkv11 = pkv11.detach()
pkv20 = pkv20.detach()
pkv21 = pkv21.detach()
pkv30 = pkv30.detach()
pkv31 = pkv31.detach()
pkv40 = pkv40.detach()
pkv41 = pkv41.detach()
pkv50 = pkv50.detach()
pkv51 = pkv51.detach()
pkv60 = pkv60.detach()
pkv61 = pkv61.detach()
pkv70 = pkv70.detach()
pkv71 = pkv71.detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
),
send_to_host=False,
)
return (
output[0],
(output[1][0], output[1][1]),
(output[2][0], output[2][1]),
(output[3][0], output[3][1]),
(output[4][0], output[4][1]),
(output[5][0], output[5][1]),
(output[6][0], output[6][1]),
(output[7][0], output[7][1]),
(output[8][0], output[8][1]),
)
def forward_compressed(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = (
input_ids.device if input_ids is not None else inputs_embeds.device
)
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.compressedlayers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[8 * idx : 8 * (idx + 1)]
if past_key_values is not None
else None
)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1],
)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class CompiledEightLayerLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
t2 = time()
if past_key_value is None:
try:
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except Exception:
pass  # leave hidden_states as-is if it cannot be coerced to a NumPy array
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
t1 = time()
output = self.model(
"first_vicuna_forward",
(hidden_states, attention_mask, position_ids),
send_to_host=False,
)
output2 = (
output[0],
(
output[1],
output[2],
),
(
output[3],
output[4],
),
(
output[5],
output[6],
),
(
output[7],
output[8],
),
(
output[9],
output[10],
),
(
output[11],
output[12],
),
(
output[13],
output[14],
),
(
output[15],
output[16],
),
)
return output2
else:
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_value
try:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv00 = pkv00.detach()
pkv01 = pkv01.detach()
pkv10 = pkv10.detach()
pkv11 = pkv11.detach()
pkv20 = pkv20.detach()
pkv21 = pkv21.detach()
pkv30 = pkv30.detach()
pkv31 = pkv31.detach()
pkv40 = pkv40.detach()
pkv41 = pkv41.detach()
pkv50 = pkv50.detach()
pkv51 = pkv51.detach()
pkv60 = pkv60.detach()
pkv61 = pkv61.detach()
pkv70 = pkv70.detach()
pkv71 = pkv71.detach()
except Exception:
pass  # inputs may already be detached plain tensors
t1 = time()
if type(hidden_states) == iree.runtime.array_interop.DeviceArray:
hidden_states = np.array(hidden_states, hidden_states.dtype)
hidden_states = torch.tensor(hidden_states)
hidden_states = hidden_states.detach()
output = self.model(
"second_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
),
send_to_host=False,
)
print(f"{time() - t1}")
del pkv00
del pkv01
del pkv10
del pkv11
del pkv20
del pkv21
del pkv30
del pkv31
del pkv40
del pkv41
del pkv50
del pkv51
del pkv60
del pkv61
del pkv70
del pkv71
output2 = (
output[0],
(
output[1],
output[2],
),
(
output[3],
output[4],
),
(
output[5],
output[6],
),
(
output[7],
output[8],
),
(
output[9],
output[10],
),
(
output[11],
output[12],
),
(
output[13],
output[14],
),
(
output[15],
output[16],
),
)
return output2
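CompiledEightLayerLayer does the inverse on the way out: index 0 of the flat module output is the hidden states and indices 1-16 are the eight (key, value) pairs, with IREE runtime arrays coerced back to host tensors when needed. A hedged sketch of that regrouping, assuming np.asarray can read the runtime arrays as the diff itself does:

```python
# Sketch of the regrouping used above: index 0 is the hidden states, indices
# 1..16 are the eight (key, value) pairs; helper names are illustrative.
import numpy as np
import torch

def to_torch(x):
    # the diff relies on np.asarray being able to read IREE runtime arrays;
    # plain torch tensors pass through unchanged
    return x if isinstance(x, torch.Tensor) else torch.from_numpy(np.asarray(x))

def regroup_output(flat_output):
    hidden = to_torch(flat_output[0])
    pairs = tuple((flat_output[2 * i + 1], flat_output[2 * i + 2]) for i in range(8))
    return (hidden,) + pairs

flat = [torch.zeros(1) for _ in range(17)]  # hidden states + 8 * (key, value)
hidden, *pkv_pairs = regroup_output(flat)
assert len(pkv_pairs) == 8
```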


@@ -1,9 +1,6 @@
import torch
from transformers import AutoModelForCausalLM
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class FirstVicuna(torch.nn.Module):
def __init__(
@@ -21,12 +18,18 @@ class FirstVicuna(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("First Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
dtype=torch.float16 if precision == "int4" else torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -48,7 +51,7 @@ class FirstVicuna(torch.nn.Module):
return tuple(return_vals)
class SecondVicuna(torch.nn.Module):
class SecondVicuna7B(torch.nn.Module):
def __init__(
self,
model_path,
@@ -64,12 +67,18 @@ class SecondVicuna(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
dtype=torch.float16 if precision == "int4" else torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -148,8 +157,6 @@ class SecondVicuna(torch.nn.Module):
i63,
i64,
):
# input_ids = input_tuple[0]
# input_tuple = torch.unbind(pkv, dim=0)
token = i0
past_key_values = (
(i1, i2),
@@ -290,6 +297,833 @@ class SecondVicuna(torch.nn.Module):
return tuple(return_vals)
class SecondVicuna13B(torch.nn.Module):
def __init__(
self,
model_path,
precision="int8",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float16 if precision == "int4" else torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
i65,
i66,
i67,
i68,
i69,
i70,
i71,
i72,
i73,
i74,
i75,
i76,
i77,
i78,
i79,
i80,
):
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
(
i65,
i66,
),
(
i67,
i68,
),
(
i69,
i70,
),
(
i71,
i72,
),
(
i73,
i74,
),
(
i75,
i76,
),
(
i77,
i78,
),
(
i79,
i80,
),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
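The wide positional signatures follow one rule: one token tensor plus num_layers × 2 past-key-value tensors, so SecondVicuna7B takes 65 inputs (32 layers), SecondVicuna13B takes 81 (40 layers), and SecondVicuna70B below takes 161 (80 layers). A generic illustration of regrouping such a flat argument list (strings stand in for tensors):

```python
# Generic illustration of the 1 + num_layers * 2 positional-argument
# convention used by the SecondVicuna wrappers above.
def group_pkv(flat_inputs):
    token, *kv = flat_inputs
    assert len(kv) % 2 == 0
    past_key_values = tuple((kv[i], kv[i + 1]) for i in range(0, len(kv), 2))
    return token, past_key_values

flat = ["tok"] + [f"kv{i}" for i in range(80)]  # 13B case: 1 + 40 * 2 = 81 inputs
token, pkv = group_pkv(flat)
assert len(pkv) == 40
```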
class SecondVicuna70B(torch.nn.Module):
def __init__(
self,
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float16,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
i65,
i66,
i67,
i68,
i69,
i70,
i71,
i72,
i73,
i74,
i75,
i76,
i77,
i78,
i79,
i80,
i81,
i82,
i83,
i84,
i85,
i86,
i87,
i88,
i89,
i90,
i91,
i92,
i93,
i94,
i95,
i96,
i97,
i98,
i99,
i100,
i101,
i102,
i103,
i104,
i105,
i106,
i107,
i108,
i109,
i110,
i111,
i112,
i113,
i114,
i115,
i116,
i117,
i118,
i119,
i120,
i121,
i122,
i123,
i124,
i125,
i126,
i127,
i128,
i129,
i130,
i131,
i132,
i133,
i134,
i135,
i136,
i137,
i138,
i139,
i140,
i141,
i142,
i143,
i144,
i145,
i146,
i147,
i148,
i149,
i150,
i151,
i152,
i153,
i154,
i155,
i156,
i157,
i158,
i159,
i160,
):
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
(
i65,
i66,
),
(
i67,
i68,
),
(
i69,
i70,
),
(
i71,
i72,
),
(
i73,
i74,
),
(
i75,
i76,
),
(
i77,
i78,
),
(
i79,
i80,
),
(
i81,
i82,
),
(
i83,
i84,
),
(
i85,
i86,
),
(
i87,
i88,
),
(
i89,
i90,
),
(
i91,
i92,
),
(
i93,
i94,
),
(
i95,
i96,
),
(
i97,
i98,
),
(
i99,
i100,
),
(
i101,
i102,
),
(
i103,
i104,
),
(
i105,
i106,
),
(
i107,
i108,
),
(
i109,
i110,
),
(
i111,
i112,
),
(
i113,
i114,
),
(
i115,
i116,
),
(
i117,
i118,
),
(
i119,
i120,
),
(
i121,
i122,
),
(
i123,
i124,
),
(
i125,
i126,
),
(
i127,
i128,
),
(
i129,
i130,
),
(
i131,
i132,
),
(
i133,
i134,
),
(
i135,
i136,
),
(
i137,
i138,
),
(
i139,
i140,
),
(
i141,
i142,
),
(
i143,
i144,
),
(
i145,
i146,
),
(
i147,
i148,
),
(
i149,
i150,
),
(
i151,
i152,
),
(
i153,
i154,
),
(
i155,
i156,
),
(
i157,
i158,
),
(
i159,
i160,
),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class CombinedModel(torch.nn.Module):
def __init__(
self,
@@ -298,15 +1132,17 @@ class CombinedModel(torch.nn.Module):
):
super().__init__()
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
# NOT using this path for 13B currently, hence using `SecondVicuna7B`.
self.second_vicuna = SecondVicuna7B(second_vicuna_model_path)
def forward(self, input_ids):
first_output = self.first_vicuna(input_ids=input_ids, use_cache=True)
logits = first_output[0]
pkv = first_output[1:]
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
second_output = self.second_vicuna(secondVicunaInput)
first_output = self.first_vicuna(input_ids=input_ids)
# generate second vicuna
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
second_output = self.second_vicuna(*secondVicunaCompileInput)
return second_output
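The rewritten CombinedModel.forward replaces the runtime chaining with fixed placeholders: a [1, 1] int64 token plus 64 zero KV tensors of shape [1, 32, 19, 128], i.e. 32 layers × 2 tensors sized for the 7B config (32 heads, head dim 128; reading the fixed length 19 as the traced prompt length is an assumption), so both halves can be traced with static shapes. For reference, a sketch of the runtime chaining the removed lines implemented (the wrapper arguments are placeholders):

```python
# Sketch of the runtime chaining the removed lines implemented: take the
# argmax of the first pass's final-position logits as the next token and feed
# it, together with the returned past_key_values, to the second pass.
import torch

def chain_step(first_vicuna, second_vicuna, input_ids):
    first_output = first_vicuna(input_ids=input_ids, use_cache=True)
    logits, pkv = first_output[0], first_output[1:]
    token = torch.argmax(torch.as_tensor(logits)[:, -1, :], dim=1)
    token = token.to(torch.int64).reshape([1, 1])
    return second_vicuna(*((token,) + tuple(pkv)))
```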


@@ -66,7 +66,7 @@ class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers) == len(model.model.layers)
# assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
@@ -132,7 +132,10 @@ class VicunaNormCompiled(torch.nn.Module):
self.model = shark_module
def forward(self, hidden_states):
hidden_states.detach()
try:
hidden_states.detach()
except:
pass
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output


@@ -3,7 +3,10 @@ from abc import ABC, abstractmethod
class SharkLLMBase(ABC):
def __init__(
self, model_name, hf_model_path=None, max_num_tokens=512
self,
model_name,
hf_model_path=None,
max_num_tokens=512,
) -> None:
self.model_name = model_name
self.hf_model_path = hf_model_path


@@ -28,7 +28,9 @@ parser = argparse.ArgumentParser(
description="runs a falcon model",
)
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
)
@@ -49,7 +51,7 @@ parser.add_argument(
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
@@ -59,32 +61,52 @@ parser.add_argument(
action=argparse.BooleanOptionalAction,
help="Run model in cli mode",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for falcon-180B model.",
)
class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_auth_token: str = None,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)
if "180b" in self.model_name and hf_auth_token == None:
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.src_model = self.get_src_model()
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, trust_remote_code=True
self.hf_model_path,
trust_remote_code=True,
token=self.hf_auth_token,
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
@@ -92,13 +114,17 @@ class Falcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
kwargs = {
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return falcon_model
def compile_falcon(self):
def compile(self):
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
@@ -120,37 +146,39 @@ class Falcon(SharkLLMBase):
if vmfb is not None:
return vmfb
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
)
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
print(
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
if args.load_mlir_from_shark_tank:
# Downloading MLIR from shark_tank
print(f"[DEBUG] Trying to download mlir from shark_tank")
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
if not mlir_generated:
print(f"[DEBUG] generating MLIR locally")
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 100)
)
@@ -189,10 +217,9 @@ class Falcon(SharkLLMBase):
bytecode = bytecode_stream.getvalue()
del module
print(f"[DEBUG] writing mlir to file")
with open(f"{self.model_name}.mlir", "wb") as f_:
with redirect_stdout(f_):
print(module.operation.get_asm())
f_ = open(self.falcon_mlir_path, "wb")
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
f_.close()
shark_module = SharkInference(
@@ -202,22 +229,17 @@ class Falcon(SharkLLMBase):
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile(self):
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
@@ -466,11 +488,16 @@ if __name__ == "__main__":
else Path(args.falcon_vmfb_path)
)
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "tiiuae/falcon-180B-chat"
else:
hf_model_path_value = (
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
)
falcon = Falcon(
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,


@@ -136,7 +136,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
def brevitasmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -145,20 +146,21 @@ def brevitasmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
raise ValueError("Input shapes not supported.")
def brevitasmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def brevitasmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics]
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
@@ -176,7 +178,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
def compile_module(
shark_module, extended_model_name, generate_vmfb, extra_args=[]
shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
):
if generate_vmfb:
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
@@ -188,7 +190,7 @@ def compile_module(
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
)
path = shark_module.save_module(
os.getcwd(), extended_model_name, extra_args
os.getcwd(), extended_model_name, extra_args, debug=debug
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -197,7 +199,7 @@ def compile_module(
def compile_int_precision(
model, inputs, precision, device, generate_vmfb, extended_model_name
model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
):
torchscript_module = import_with_fx(
model,
@@ -209,7 +211,7 @@ def compile_int_precision(
torchscript_module,
inputs,
output_type="torch",
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -217,7 +219,7 @@ def compile_int_precision(
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
mlir_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
from contextlib import redirect_stdout
@@ -249,6 +251,7 @@ def compile_int_precision(
extended_model_name=extended_model_name,
generate_vmfb=generate_vmfb,
extra_args=extra_args,
debug=debug,
),
bytecode,
)
@@ -292,6 +295,7 @@ def shark_compile_through_fx_int(
device,
generate_or_load_vmfb,
extended_model_name,
debug,
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",

View File

@@ -32,11 +32,13 @@ class SharkStableLM(SharkLLMBase):
max_num_tokens=512,
device="cuda",
precision="fp32",
debug="False",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
self.precision = precision
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
@@ -111,7 +113,7 @@ class SharkStableLM(SharkLLMBase):
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
)
print("Saved vmfb at ", str(path))

View File

@@ -8,7 +8,7 @@ from shark.shark_downloader import download_public_file
# expects a Path / str as arg
# returns None if path not found or SharkInference module
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
if not isinstance(vmfb_path, Path):
vmfb_path = Path(vmfb_path)
@@ -20,7 +20,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
print("Loading vmfb from: ", vmfb_path)
print("Device from get_vmfb_from_path - ", device)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
@@ -28,7 +28,13 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
def get_vmfb_from_config(
shark_container, model, precision, device, vmfb_path, padding=None
shark_container,
model,
precision,
device,
vmfb_path,
padding=None,
device_id=None,
):
vmfb_url = (
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
@@ -37,4 +43,6 @@ def get_vmfb_from_config(
vmfb_url = vmfb_url + f"_{padding}"
vmfb_url = vmfb_url + ".vmfb"
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
return get_vmfb_from_path(
vmfb_path, device, "tm_tensor", device_id=device_id
)
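For orientation, a minimal usage sketch of the updated helper with an explicit device index. The container, model, and path values below are placeholders, and the import of `get_vmfb_from_config` from its utils module is omitted; only the new `device_id` threading is the point.
```python
from pathlib import Path

# Placeholder values; device_id=1 is forwarded to SharkInference(device_idx=1)
# as shown in the diff above, selecting the second Vulkan device.
shark_module = get_vmfb_from_config(
    shark_container="quantized_models",   # gs://shark_tank/<container> (placeholder)
    model="vicuna",
    precision="fp16",
    device="vulkan",
    vmfb_path=Path("vicuna_fp16_vulkan.vmfb"),
    device_id=1,
)
```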

View File

@@ -7,16 +7,16 @@ Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
```

View File

@@ -34,7 +34,7 @@ from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
@@ -287,7 +287,7 @@ def lora_train(
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)

View File

@@ -15,8 +15,8 @@ pathex = [
# datafiles for pyinstaller
datas = []
datas += collect_data_files("torch")
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
@@ -30,26 +30,29 @@ datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("opencv_python")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark")
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
datas += collect_data_files("cv2")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
@@ -71,7 +74,11 @@ datas += [
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
blacklist = ["tests", "convert"]
hiddenimports += [
x for x in collect_submodules("transformers") if "tests" not in x
x
for x in collect_submodules("transformers")
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]

View File

@@ -177,9 +177,11 @@ class SharkifyStableDiffusionModel:
"unet",
"unet512",
"stencil_unet",
"stencil_unet_512",
"vae",
"vae_encode",
"stencil_adaptor",
"stencil_adaptor_512",
]
index = 0
for model in sub_model_list:
@@ -339,7 +341,7 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir
def get_controlled_unet(self):
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self,
@@ -415,6 +417,16 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
model_name = "stencil_unet"
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[:2]
+ (torch.nn.functional.pad(inputs[2], pad),)
+ inputs[3:]
)
model_name = "stencil_unet_512"
input_mask = [
True,
True,
@@ -437,19 +449,19 @@ class SharkifyStableDiffusionModel:
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["stencil_unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self):
def get_control_net(self, use_large=False):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
@@ -497,17 +509,34 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name["stencil_adaptor"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_adaptor",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
@@ -748,7 +777,7 @@ class SharkifyStableDiffusionModel:
else:
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet()
return self.get_controlled_unet(use_large=use_large)
def vae_encode(self):
try:
@@ -847,12 +876,14 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def controlnet(self):
def controlnet(self, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
)
check_compilation(compiled_stencil_adaptor, "Stencil")
if self.return_mlir:

View File

@@ -84,13 +84,35 @@ class Image2ImagePipeline(StableDiffusionPipeline):
num_inference_steps,
strength,
dtype,
resample_type,
):
# Pre process image -> get image encoded -> process latents
# TODO: process with variable HxW combos
# Pre process image
image = image.resize((width, height))
# Pre-process image
if resample_type == "Lanczos":
resample_type = Image.LANCZOS
elif resample_type == "Nearest Neighbor":
resample_type = Image.NEAREST
elif resample_type == "Bilinear":
resample_type = Image.BILINEAR
elif resample_type == "Bicubic":
resample_type = Image.BICUBIC
elif resample_type == "Adaptive":
resample_type = Image.ADAPTIVE
elif resample_type == "Antialias":
resample_type = Image.ANTIALIAS
elif resample_type == "Box":
resample_type = Image.BOX
elif resample_type == "Affine":
resample_type = Image.AFFINE
elif resample_type == "Cubic":
resample_type = Image.CUBIC
else: # Fallback to Lanczos
resample_type = Image.LANCZOS
image = image.resize((width, height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
@@ -147,6 +169,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -186,6 +209,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
# Get Image latents
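As a side note on the resample mapping above: a dictionary lookup keeps the Lanczos fallback while staying compact. This is only a sketch limited to the standard PIL filters, not the pipeline's code.
```python
from PIL import Image

# Name -> PIL resampling filter; unknown names fall back to Lanczos,
# matching the else branch in the pipeline above.
RESAMPLE_FILTERS = {
    "Lanczos": Image.LANCZOS,
    "Nearest Neighbor": Image.NEAREST,
    "Bilinear": Image.BILINEAR,
    "Bicubic": Image.BICUBIC,
    "Box": Image.BOX,
}

def resize_image(image, width, height, resample_type="Lanczos"):
    resample = RESAMPLE_FILTERS.get(resample_type, Image.LANCZOS)
    return image.resize((width, height), resample=resample)
```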

View File

@@ -58,6 +58,7 @@ class StencilPipeline(StableDiffusionPipeline):
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
def load_controlnet(self):
if self.controlnet is not None:
@@ -68,6 +69,15 @@ class StencilPipeline(StableDiffusionPipeline):
del self.controlnet
self.controlnet = None
def load_controlnet_512(self):
if self.controlnet_512 is not None:
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def prepare_latents(
self,
batch_size,
@@ -111,8 +121,12 @@ class StencilPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
self.load_controlnet()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
@@ -135,43 +149,82 @@ class StencilPipeline(StableDiffusionPipeline):
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
else:
print(self.unet_512)
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
@@ -191,7 +244,9 @@ class StencilPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -218,6 +273,7 @@ class StencilPipeline(StableDiffusionPipeline):
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
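A minimal sketch of how the duplicated short-vs-512-token dispatch above could be folded into one helper; it assumes the attributes already used in the pipeline (`model_max_length`, `unet`/`unet_512`, `controlnet`/`controlnet_512`) and is not the repository's code.
```python
def _pick_modules(self, seq_len):
    """Return the (unet, controlnet) pair matching the text embedding length."""
    if seq_len <= self.model_max_length:
        return self.unet, self.controlnet
    return self.unet_512, self.controlnet_512

# Inside the denoising loop the branches above would then collapse to:
#   unet, controlnet = self._pick_modules(text_embeddings.shape[1])
#   control = controlnet("forward", (latent_model_input_1, timestep,
#                        text_embeddings, controlnet_hint), send_to_host=False)
#   noise_pred = unet("forward", (latent_model_input, timestep,
#                     text_embeddings_numpy, guidance_scale, *control),
#                     send_to_host=False)
```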

View File

@@ -109,7 +109,7 @@ def load_lower_configs(base_model_id=None):
spec = spec.split("-")[0]
if args.annotation_model == "vae":
if not spec or spec in ["rdna3", "sm_80"]:
if not spec or spec in ["sm_80"]:
config_name = (
f"{args.annotation_model}_{args.precision}_{device}.json"
)
@@ -158,9 +158,9 @@ def load_lower_configs(base_model_id=None):
f"{spec}.json"
)
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
full_gs_url = config_bucket + config_name
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir

View File

@@ -132,6 +132,57 @@ p.add_argument(
"img2img.",
)
p.add_argument(
"--use_hiresfix",
type=bool,
default=False,
help="Use Hires Fix to do higher resolution images, while trying to "
"avoid the issues that come with it. This is accomplished by first "
"generating an image using txt2img, then running it through img2img.",
)
p.add_argument(
"--hiresfix_height",
type=int,
default=768,
choices=range(128, 769, 8),
help="The height of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_width",
type=int,
default=768,
choices=range(128, 769, 8),
help="The width of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_strength",
type=float,
default=0.6,
help="The denoising strength to apply for the Hires Fix.",
)
p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
##############################################################################
# Stable Diffusion Training Params
##############################################################################
@@ -519,6 +570,20 @@ p.add_argument(
"in shark importer. Does nothing if import_mlir is false (the default).",
)
p.add_argument(
"--compile_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to toggle debug assert/verify flags for imported IR in the"
"iree-compiler. Default to false.",
)
p.add_argument(
"--iree_constant_folding",
default=True,
action=argparse.BooleanOptionalAction,
help="Controls constant folding in iree-compile for all SD models.",
)
##############################################################################
# Web UI flags
@@ -568,6 +633,13 @@ p.add_argument(
help="Flag for enabling rest API.",
)
p.add_argument(
"--debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling debugging log in WebUI.",
)
p.add_argument(
"--output_gallery",
default=True,

View File
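The new `--compile_debug`, `--iree_constant_folding`, and `--debug` flags above rely on `argparse.BooleanOptionalAction` (Python 3.9+), which also registers a `--no-<flag>` negative form. A standalone illustration:
```python
import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--iree_constant_folding",
    default=True,
    action=argparse.BooleanOptionalAction,  # requires Python 3.9+
)

print(p.parse_args([]))
# Namespace(iree_constant_folding=True)
print(p.parse_args(["--no-iree_constant_folding"]))
# Namespace(iree_constant_folding=False)
```
By contrast, plain `type=bool` arguments such as `--use_hiresfix` treat any non-empty string value as True.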

@@ -25,7 +25,7 @@ from shark.iree_utils.vulkan_utils import (
get_iree_vulkan_runtime_flags,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
)
)
path = shark_module.save_module(
os.getcwd(), model_name, extra_args
os.getcwd(), model_name, extra_args, debug=args.compile_debug
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -470,12 +470,25 @@ def get_available_devices():
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
@@ -499,6 +512,15 @@ def get_opt_flags(model, precision="fp16"):
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
if "rocm" in args.device:
rocm_args = get_iree_rocm_args()
iree_flags.extend(rocm_args)
print(iree_flags)
if args.iree_constant_folding == False:
iree_flags.append("--iree-opt-const-expr-hoisting=False")
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
@@ -566,7 +588,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
)
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
checkpoint_path_or_dict=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
@@ -816,6 +838,8 @@ def clear_all():
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
if args.local_tank_cache != "":
shutil.rmtree(args.local_tank_cache)
def get_generated_imgs_path() -> Path:

View File
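The Vulkan listing added in the diff above maintains the `id` counter by hand; the same result with `enumerate()` is shown below purely as an illustrative alternative, assuming `get_all_vulkan_devices()` returns printable device-name strings as the diff implies.
```python
from shark.iree_utils.vulkan_utils import get_all_vulkan_devices

available_devices = []
vulkan_devices = [
    f"{name.strip()} => vulkan://{idx}"
    for idx, name in enumerate(get_all_vulkan_devices())
]
if vulkan_devices:
    print("vulkan devices are available.")
available_devices.extend(vulkan_devices)
```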

@@ -1,6 +1,7 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging
if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
@@ -37,10 +38,12 @@ def launch_app(address):
height=height,
text_select=True,
)
webview.start(private_mode=False)
webview.start(private_mode=False, storage_path=os.getcwd())
if __name__ == "__main__":
if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.ui.split(","):
@@ -115,7 +118,8 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
h2ogpt_web,
# h2ogpt_upload,
# h2ogpt_web,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
@@ -154,6 +158,7 @@ if __name__ == "__main__":
upscaler_sendto_outpaint,
lora_train_web,
model_web,
model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
@@ -211,6 +216,15 @@ if __name__ == "__main__":
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
@@ -238,14 +252,20 @@ if __name__ == "__main__":
)
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="LoRA Training (Experimental)", id=8):
with gr.TabItem(label="LoRA Training (Experimental)", id=7):
lora_train_web.render()
with gr.TabItem(label="Chat Bot (Experimental)", id=7):
with gr.TabItem(label="Chat Bot (Experimental)", id=8):
stablelm_chat.render()
with gr.TabItem(label="MultiModal (Experimental)", id=9):
with gr.TabItem(
label="Generate Sharding Config (Experimental)", id=9
):
model_config_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
with gr.TabItem(label="DocuChat(Experimental)", id=10):
h2ogpt_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
# send to buttons
register_button_click(

View File

@@ -78,7 +78,7 @@ from apps.stable_diffusion.web.ui.stablelm_ui import (
stablelm_chat,
llm_chat_api,
)
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
from apps.stable_diffusion.web.ui.generate_config import model_config_web
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_web,

View File

@@ -0,0 +1,41 @@
import gradio as gr
import torch
from transformers import AutoTokenizer
from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel
from shark.shark_generate_model_config import GenerateConfigFile
def get_model_config():
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
return c.split_into_layers()
with gr.Blocks() as model_config_web:
with gr.Row():
hf_models = gr.Dropdown(
label="Model List",
choices=["Vicuna"],
value="Vicuna",
visible=True,
)
get_model_config_btn = gr.Button(value="Get Model Config")
json_view = gr.JSON()
get_model_config_btn.click(
fn=get_model_config,
inputs=[],
outputs=[json_view],
)

View File

@@ -12,6 +12,10 @@ from apps.language_models.langchain.enums import (
LangChainAction,
)
import apps.language_models.langchain.gen as gen
from gpt_langchain import (
path_to_docs,
create_or_update_db,
)
from apps.stable_diffusion.src import args
@@ -33,8 +37,15 @@ start_message = """
def create_prompt(history):
system_message = start_message
for item in history:
print("His item: ", item)
conversation = "".join(["".join([item[0], item[1]]) for item in history])
conversation = "<|endoftext|>".join(
[
"<|endoftext|><|answer|>".join([item[0], item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
@@ -44,10 +55,12 @@ def create_prompt(history):
def chat(curr_system_message, history, device, precision):
args.run_docuchat_web = True
global h2ogpt_model
global sharkModel
global h2ogpt_tokenizer
global model_state
global langchain
global userpath_selector
from apps.language_models.langchain.h2oai_pipeline import generate_token
if h2ogpt_model == 0:
if "cuda" in device:
@@ -102,9 +115,14 @@ def chat(curr_system_message, history, device, precision):
prompt_type=None,
prompt_dict=None,
)
from apps.language_models.langchain.h2oai_pipeline import (
H2OGPTSHARKModel,
)
sharkModel = H2OGPTSHARKModel()
prompt = create_prompt(history)
output = langchain.evaluate(
output_dict = langchain.evaluate(
model_state=model_state,
my_db_state=None,
instruction=prompt,
@@ -164,14 +182,22 @@ def chat(curr_system_message, history, device, precision):
model_lock=True,
user_path=userpath_selector.value,
)
for partial_text in output:
history[-1][1] = partial_text["response"]
yield history
output = generate_token(sharkModel, **output_dict)
for partial_text in output:
history[-1][1] = partial_text
yield history
return history
with gr.Blocks(title="H2OGPT") as h2ogpt_web:
userpath_selector = gr.Textbox(
label="Document Directory",
value=str(os.path.abspath("apps/language_models/langchain/user_path/")),
interactive=True,
container=True,
)
with gr.Blocks(title="DocuChat") as h2ogpt_web:
with gr.Row():
supported_devices = available_devices
enabled = len(supported_devices) > 0
@@ -198,14 +224,6 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
],
visible=True,
)
userpath_selector = gr.Textbox(
label="Document Directory",
value=str(
os.path.abspath("apps/language_models/langchain/user_path/")
),
interactive=True,
container=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
@@ -249,3 +267,100 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)
with gr.Blocks(title="DocuChat Upload") as h2ogpt_upload:
import pathlib
upload_path = None
database = None
database_directory = os.path.abspath(
"apps/language_models/langchain/db_path/"
)
def read_path():
global upload_path
filenames = [
[f]
for f in os.listdir(upload_path)
if os.path.isfile(os.path.join(upload_path, f))
]
filenames.sort()
return filenames
def upload_file(f):
names = []
for tmpfile in f:
name = tmpfile.name.split("/")[-1]
basename = os.path.join(upload_path, name)
with open(basename, "wb") as w:
with open(tmpfile.name, "rb") as r:
w.write(r.read())
update_or_create_db()
return read_path()
def update_userpath(newpath):
global upload_path
upload_path = newpath
pathlib.Path(upload_path).mkdir(parents=True, exist_ok=True)
return read_path()
def update_or_create_db():
global database
global upload_path
sources = path_to_docs(
upload_path,
verbose=True,
fail_any_exception=False,
n_jobs=-1,
chunk=True,
chunk_size=512,
url=None,
enable_captions=False,
captions_model=None,
caption_loader=None,
enable_ocr=False,
)
pathlib.Path(database_directory).mkdir(parents=True, exist_ok=True)
database = create_or_update_db(
"chroma",
database_directory,
"UserData",
sources,
False,
True,
True,
"sentence-transformers/all-MiniLM-L6-v2",
)
def first_run():
global database
if database is None:
update_or_create_db()
update_userpath(
os.path.abspath("apps/language_models/langchain/user_path/")
)
h2ogpt_upload.load(fn=first_run)
h2ogpt_web.load(fn=first_run)
with gr.Column():
text = gr.DataFrame(
col_count=(1, "fixed"),
type="array",
label="Documents",
value=read_path(),
)
with gr.Row():
upload = gr.UploadButton(
label="Upload documents",
file_count="multiple",
)
upload.upload(fn=upload_file, inputs=upload, outputs=text)
userpath_selector.render()
userpath_selector.input(
fn=update_userpath, inputs=userpath_selector, outputs=text
).then(fn=update_or_create_db)

View File
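To make the new DocuChat prompt format above concrete, here is a tiny worked example of the join (the question/answer strings are invented and the system-message prefix is left out):
```python
history = [
    ["What is SHARK?", "A high-performance ML distribution."],
    ["Does it run on Vulkan?", ""],  # current turn, answer not generated yet
]
conversation = "<|endoftext|>".join(
    ["<|endoftext|><|answer|>".join([q, a]) for q, a in history]
)
print(conversation)
# One line, wrapped here for readability:
# What is SHARK?<|endoftext|><|answer|>A high-performance ML distribution.
# <|endoftext|>Does it run on Vulkan?<|endoftext|><|answer|>
```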

@@ -3,6 +3,7 @@ import torch
import time
import gradio as gr
import PIL
from math import ceil
from PIL import Image
import base64
from io import BytesIO
@@ -67,6 +68,7 @@ def img2img_inf(
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -245,7 +247,7 @@ def img2img_inf(
batch_size,
height,
width,
steps,
ceil(steps / strength),
strength,
guidance_scale,
seeds[current_batch],
@@ -255,6 +257,7 @@ def img2img_inf(
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
resample_type=resample_type,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
@@ -348,6 +351,7 @@ def img2img_api(
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
resample_type="Lanczos",
)
# Converts generator type to subscriptable
@@ -432,7 +436,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lines=2,
elem_id="negative_prompt_box",
)
# TODO: make this import image prompt info if it exists
img2img_init_image = gr.Image(
label="Input Image",
source="upload",
@@ -550,15 +554,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
@@ -581,11 +576,35 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
label="Resample Type",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
@@ -695,6 +714,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lora_hf_id,
ondemand,
repeatable_seeds,
resample_type,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",

View File
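On the `ceil(steps / strength)` change above: img2img-style pipelines typically execute only about `strength` of the scheduled steps, so over-requesting keeps the effective count near the slider value. A quick check of that arithmetic (the exact scaling inside the pipeline may differ slightly):
```python
from math import ceil

steps, strength = 50, 0.6
scheduled = ceil(steps / strength)     # 84 steps handed to the scheduler
effective = int(scheduled * strength)  # ~50 steps actually run at this strength
print(scheduled, effective)            # 84 50
```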

@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
gr.Markdown(description)
with gr.Row():
with gr.Column(scale=0.5):
with gr.Column():
image = gr.Image(type="pil")
upload_button = gr.Button(
value="Upload & Start Chat",

View File

@@ -7,6 +7,8 @@ from transformers import (
)
from apps.stable_diffusion.web.ui.utils import available_devices
from datetime import datetime as dt
import json
import time
def user(message, history):
@@ -22,11 +24,9 @@ past_key_values = None
model_map = {
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
}
# NOTE: Each `model_name` should have its own start message
@@ -40,6 +40,15 @@ start_message = {
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_13b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
@@ -49,54 +58,39 @@ start_message = {
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"StableLM": (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"vicuna1p3": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"codegen": "",
}
def create_prompt(model_name, history):
system_message = start_message[model_name]
if model_name in [
"StableLM",
"vicuna",
"vicuna1p3",
"llama2_7b",
"llama2_70b",
]:
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
msg = system_message + conversation
msg = msg.strip()
return msg
@@ -105,84 +99,165 @@ def set_vicuna_model(model):
vicuna_model = model
def get_default_config():
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()
model_vmfb_key = ""
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision, cli=True):
def chat(
curr_system_message,
history,
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global past_key_values
global model_vmfb_key
global vicuna_model
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
if model_name in [
"vicuna",
"vicuna1p3",
"codegen",
"llama2_7b",
"llama2_70b",
]:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{precision}"
if new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)
from apps.stable_diffusion.src import args
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
max_toks = 128 if model_name == "codegen" else 512
if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert (
device_id is not None
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use target triple : {vulkan_target_triple}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)
prompt = create_prompt(model_name, history)
for partial_text in vicuna_model.generate(prompt, cli=cli):
history[-1][1] = partial_text
yield history
return history
# else Model is StableLM
global sharkModel
from apps.language_models.src.pipelines.stablelm_pipeline import (
SharkStableLM,
)
if sharkModel == 0:
# max_new_tokens=512
shark_slm = SharkStableLM(
model_name
) # pass elements from UI as required
# Construct the input message string for the model by concatenating the
# current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
prompt = create_prompt(model_name, history)
generate_kwargs = dict(prompt=prompt)
words_list = shark_slm.generate(**generate_kwargs)
partial_text = ""
for new_text in words_list:
print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to clean up the message textbox and the updated
# conversation history
yield history
return words_list
count = 0
start_time = time.time()
for text, msg in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
count += 1
if "formatted" in msg:
history[-1][1] = text
end_time = time.time()
tokens_per_sec = count / (end_time - start_time)
yield history, str(format(tokens_per_sec, ".2f")) + " tokens/sec"
else:
partial_text += text + " "
history[-1][1] = partial_text
yield history, ""
return history, ""
def llm_chat_api(InputData: dict):
@@ -218,6 +293,7 @@ def llm_chat_api(InputData: dict):
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
@@ -226,6 +302,7 @@ def llm_chat_api(InputData: dict):
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
@@ -236,6 +313,9 @@ def llm_chat_api(InputData: dict):
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
@@ -306,7 +386,6 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
@@ -314,23 +393,33 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
# multiselect=True,
)
precision = gr.Radio(
label="Precision",
value="fp16",
value="int8",
choices=[
"int4",
"int8",
"fp16",
"fp32",
],
visible=True,
)
with gr.Row():
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(label="Upload sharding configuration")
json_view_button = gr.Button("View as JSON")
json_view = gr.JSON()
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
@@ -357,16 +446,32 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
inputs=[
system_msg,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
inputs=[
system_msg,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
queue=True,
)
stop.click(

View File
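A small sketch of the device-string handling introduced above. WebUI device entries look like `"<name> => vulkan://<N>"`, and the chat path splits out both the backend and the numeric index; the helper name below is hypothetical since the UI does this inline.
```python
def parse_device_selection(selection: str):
    """Split 'AMD Radeon ... => vulkan://1' into ('vulkan', 1).

    Hypothetical helper mirroring the inline parsing in the diff above;
    backends without an index (e.g. 'rocm') return (backend, None).
    """
    uri = selection.split("=>")[-1].strip()
    if "://" in uri:
        backend, idx = uri.split("://")
        return backend, int(idx)
    return uri, None


backend, device_id = parse_device_selection("AMD Radeon RX 7900 XTX => vulkan://0")
assert (backend, device_id) == ("vulkan", 0)
```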

@@ -4,6 +4,7 @@ import time
import sys
import gradio as gr
from PIL import Image
from math import ceil
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
@@ -26,6 +27,7 @@ from apps.stable_diffusion.src import (
utils,
save_output_img,
prompt_examples,
Image2ImagePipeline,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
@@ -62,6 +64,11 @@ def txt2img_inf(
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
use_hiresfix: bool,
hiresfix_height: int,
hiresfix_width: int,
hiresfix_strength: float,
resample_type: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -200,6 +207,81 @@ def txt2img_inf(
cpu_scheduling,
args.max_embeddings_multiples,
)
# TODO: allow user to save original image
# TODO: add option to let user keep both pipelines loaded, and unload
# either at will
# TODO: add custom step value slider
# TODO: add option to use secondary model for the img2img pass
if use_hiresfix is True:
new_config_obj = Config(
"img2img",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
1,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil="None",
ondemand=ondemand,
)
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
global_obj.set_sd_obj(
Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
1,
hiresfix_height,
hiresfix_width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(args.scheduler)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
out_imgs[0],
batch_size,
hiresfix_height,
hiresfix_width,
ceil(steps / hiresfix_strength),
hiresfix_strength,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil="None",
resample_type=resample_type,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
@@ -271,6 +353,11 @@ def txt2img_api(
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
use_hiresfix=False,
hiresfix_height=512,
hiresfix_width=512,
hiresfix_strength=0.6,
resample_type="Nearest Neighbor",
)
# Convert Generator to Subscriptable
@@ -460,6 +547,49 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Low VRAM",
interactive=True,
)
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
label="Resample Type",
)
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
@@ -495,16 +625,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
@@ -530,6 +650,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
show_label=False,
)
txt2img_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -565,6 +697,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
lora_hf_id,
ondemand,
repeatable_seeds,
use_hiresfix,
hiresfix_height,
hiresfix_width,
hiresfix_strength,
resample_type,
],
outputs=[txt2img_gallery, std_output, txt2img_status],
show_progress="minimal" if args.progress_bar else "none",

View File

@@ -25,7 +25,7 @@ class Config:
device: str
use_lora: str
use_stencil: str
ondemand: str
ondemand: str # should this be expecting a bool instead?
custom_model_filetypes = (

View File

@@ -24,13 +24,13 @@ def get_image(url, local_filename):
shutil.copyfileobj(res.raw, f)
def compare_images(new_filename, golden_filename):
def compare_images(new_filename, golden_filename, upload=False):
new = np.array(Image.open(new_filename)) / 255.0
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.1:
if os.name != "nt":
if os.name != "nt" and upload == True:
subprocess.run(
[
"gsutil",
@@ -39,7 +39,7 @@ def compare_images(new_filename, golden_filename):
"gs://shark_tank/testdata/builder/",
]
)
raise SystemExit("new and golden not close")
raise AssertionError("new and golden not close")
else:
print("SUCCESS")

View File

@@ -1,5 +1,5 @@
#!/bin/bash
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
IMPORTER=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python tank/generate_sharktank.py
python build_tools/stable_diffusion_testing.py --gen

View File

@@ -63,7 +63,14 @@ def get_inpaint_inputs():
open("./test_images/inputs/mask.png", "wb").write(mask.content)
def test_loop(device="vulkan", beta=False, extra_flags=[]):
def test_loop(
device="vulkan",
beta=False,
extra_flags=[],
upload_bool=True,
exit_on_fail=True,
do_gen=False,
):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
model_metrics = []
@@ -81,6 +88,8 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
if beta:
extra_flags.append("--beta_models=True")
extra_flags.append("--no-progress_bar")
if do_gen:
extra_flags.append("--import_debug")
to_skip = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
@@ -181,7 +190,14 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"./test_images/golden/" + model_name + "/*.png"
)
golden_file = glob(golden_path)[0]
compare_images(test_file, golden_file)
try:
compare_images(
test_file, golden_file, upload=upload_bool
)
except AssertionError as e:
print(e)
if exit_on_fail == True:
raise
else:
print(command)
print("failed to generate image for this configuration")
@@ -200,6 +216,9 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
if do_gen:
prepare_artifacts()
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)
@@ -218,15 +237,49 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
f.write(";".join(output) + "\n")
def prepare_artifacts():
gen_path = os.path.join(os.getcwd(), "gen_shark_tank")
if not os.path.isdir(gen_path):
os.mkdir(gen_path)
for dirname in os.listdir(os.getcwd()):
for modelname in ["clip", "unet", "vae"]:
if modelname in dirname and "vmfb" not in dirname:
if not os.path.isdir(os.path.join(gen_path, dirname)):
shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
print(f"Moved dir: {dirname} to {gen_path}.")
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument(
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument("-e", "--extra_args", type=str, default=None)
parser.add_argument(
"-u", "--upload", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-g", "--gen", action=argparse.BooleanOptionalAction, default=False
)
if __name__ == "__main__":
args = parser.parse_args()
print(args)
test_loop(args.device, args.beta, [])
extra_args = []
if args.extra_args:
for arg in args.extra_args.split(","):
extra_args.append(arg)
test_loop(
args.device,
args.beta,
extra_args,
args.upload,
args.exit_on_fail,
args.gen,
)
if args.gen:
prepare_artifacts()

View File

@@ -27,7 +27,7 @@ include(FetchContent)
FetchContent_Declare(
iree
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
GIT_REPOSITORY https://github.com/nod-ai/srt.git
GIT_TAG shark
GIT_SUBMODULES_RECURSE OFF
GIT_SHALLOW OFF

View File

@@ -40,7 +40,7 @@ cmake --build build/
*Prepare the model*
```bash
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux resnet50_tf.mlir -o resnet50_tf.vmfb
```
*Prepare the input*
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
```bash
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux stable_diff_tf.mlir -o stable_diff_tf.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
```
VAE and Autoencoder are also available
```bash
# VAE
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux vae.mlir -o vae.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
# CLIP Autoencoder
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux clip_autoencoder.mlir -o clip_autoencoder.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=clip_autoencoder.vmfb --function_input=1x77xi32 --function_input=1x77xi32
```

View File

@@ -55,7 +55,7 @@ The command line for compilation will start something like this, where the `-` n
The `-o output_filename.vmfb` flag can be used to specify the location to save the compiled vmfb. Note that a dump of the
dispatches, which can be compiled and run in isolation, can be generated by adding `--iree-hal-dump-executable-benchmarks-to=/some/directory`. For example, if they are in the `benchmarks` directory, the following compile/run commands would work for Vulkan on RDNA3.
```
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-benchmark-module --module=benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb --function=forward --device=vulkan
```
@@ -63,8 +63,8 @@ Where `${NUM}` is the dispatch number that you want to benchmark/profile in isol
### Enabling Tracy for Vulkan profiling
To begin profiling with Tracy, a build of the IREE runtime with tracing enabled is needed. SHARK-Runtime builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SHARK-Runtime/releases)); however, this is only available for Linux. For Windows, tracing can be enabled by setting a CMake flag.
To begin profiling with Tracy, a build of the IREE runtime with tracing enabled is needed. SHARK-Runtime (SRT) builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SRT/releases)); however, this is only available for Linux. For Windows, tracing can be enabled by setting a CMake flag.
```
$env:IREE_ENABLE_RUNTIME_TRACING="ON"
```

View File

@@ -1,192 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.17)
project(sharkbackend LANGUAGES C CXX)
#
# Options
#
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
#
# Dependencies
#
# FetchContent requires us to include the transitive closure of all
# repos that we depend on so that we can override the tags.
#
include(FetchContent)
FetchContent_Declare(
repo-common
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
GIT_TAG ${TRITON_COMMON_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-core
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
GIT_TAG ${TRITON_CORE_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-backend
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
#
# The backend must be built into a shared library. Use an ldscript to
# hide all symbols except for the TRITONBACKEND API.
#
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
add_library(
triton-dshark-backend SHARED
src/dshark.cc
#src/dshark_driver_module.c
)
add_library(
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
)
target_include_directories(
triton-dshark-backend
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src
)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
add_subdirectory(thirdparty/shark-runtime EXCLUDE_FROM_ALL)
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
iree_hal_hal
iree_hal_cuda_cuda
iree_hal_cuda_registration_registration
iree_hal_vmvx_registration_registration
iree_hal_dylib_registration_registration
iree_modules_hal_hal
iree_vm_vm
iree_vm_bytecode_module
iree_hal_local_loaders_system_library_loader
iree_hal_local_loaders_vmvx_module_loader
)
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
target_link_libraries(
triton-dshark-backend
PRIVATE
triton-core-serverapi # from repo-core
triton-core-backendapi # from repo-core
triton-core-serverstub # from repo-core
triton-backend-utils # from repo-backend
)
if(WIN32)
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
)
else()
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
)
endif()
#
# Install
#
include(GNUInstallDirs)
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
install(
TARGETS
triton-dshark-backend
EXPORT
triton-dshark-backend-targets
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
)
install(
EXPORT
triton-dshark-backend-targets
FILE
SharkBackendTargets.cmake
NAMESPACE
SharkBackend::
DESTINATION
${INSTALL_CONFIGDIR}
)
include(CMakePackageConfigHelpers)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
)
install(
FILES
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
DESTINATION ${INSTALL_CONFIGDIR}
)
#
# Export from build tree
#
export(
EXPORT triton-dshark-backend-targets
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
NAMESPACE SharkBackend::
)
export(PACKAGE SharkBackend)

View File

@@ -1,100 +0,0 @@
# SHARK Triton Backend
The triton backend for shark.
# Build
Install SHARK
```
git clone https://github.com/nod-ai/SHARK.git
# skip above step if dshark is already installed
cd SHARK/inference
```
Install dependencies
```
apt-get install patchelf rapidjson-dev python3-dev
git submodule update --init
```
Update the submodules of IREE
```
cd thirdparty/shark-runtime
git submodule update --init
```
Next, make the backend and install it
```
cd ../..
mkdir build && cd build
cmake -DTRITON_ENABLE_GPU=ON \
-DIREE_HAL_DRIVER_CUDA=ON \
-DIREE_TARGET_BACKEND_CUDA=ON \
-DMLIR_ENABLE_CUDA_RUNNER=ON \
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
-DTRITON_BACKEND_REPO_TAG=r22.02 \
-DTRITON_CORE_REPO_TAG=r22.02 \
-DTRITON_COMMON_REPO_TAG=r22.02 ..
make install
```
# Incorporating into Triton
There are much more in-depth explanations of the following steps in Triton's documentation:
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
There should be a file at `/build/install/backends/dshark/libtriton_dshark.so`. You will need to copy it into your Triton server image.
More documentation is in the link above, but to create the Docker image you need to run the `compose.py` command in the Triton server repo.
To build your image, first clone the tritonserver repo.
```
git clone https://github.com/triton-inference-server/server.git
```
Then run `compose.py` to generate a `Dockerfile.compose`:
```
cd server
python3 compose.py --repoagent checksum --dry-run
```
Because dshark is a third-party backend, you will need to manually modify the generated `Dockerfile.compose` to include the dshark backend. To do this, add the following line;
the dshark backend is located in the build folder from earlier under `/build/install/backends`:
```
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
```
Next, run:
```
docker build -t tritonserver_custom -f Dockerfile.compose .
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
where `/path/to/model_repos` is the directory containing the models you want to run.
If you're not using GPUs, omit `--gpus=1`:
```
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
# Setting up a model
To include a model in your backend, add a directory with your model name to your model repository directory. Examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
Make sure to adjust the inputs correctly in the `config.pbtxt` file, and save a vmfb file under `1/model.vmfb`.
# CUDA
If you're having issues with CUDA, make sure the correct drivers are installed, that `nvidia-smi` works, and that the `nvcc` compiler is on your PATH.

View File

@@ -1,39 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include(CMakeFindDependencyMacro)
get_filename_component(
SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
)
list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})
if(NOT TARGET SharkBackend::triton-dshark-backend)
include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
endif()
set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)

File diff suppressed because it is too large

View File

@@ -1,30 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
{
global:
TRITONBACKEND_*;
local: *;
};

View File

@@ -6,15 +6,15 @@ from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Temorary workaround for transformers/__init__.py.
path_to_tranformers_hook = Path(
# Temporary workaround for transformers/__init__.py.
path_to_transformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
if path_to_tranformers_hook.is_file():
if path_to_transformers_hook.is_file():
pass
else:
with open(path_to_tranformers_hook, "w") as f:
with open(path_to_transformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")

View File

@@ -5,7 +5,7 @@ requires = [
"packaging",
"numpy>=1.22.4",
"torch-mlir>=20221021.633",
"torch-mlir>=20230620.875",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
]

View File

@@ -3,7 +3,7 @@
numpy>1.22.4
pytorch-triton
torchvision==0.16.0.dev20230322
torchvision
tabulate
tqdm
@@ -15,8 +15,8 @@ iree-tools-tf
# TensorFlow and JAX.
gin-config
tensorflow>2.11
keras
tf-nightly
keras-nightly
#tf-models-nightly
#tensorflow-text-nightly
transformers

View File

@@ -1,3 +1,6 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
setuptools
wheel
@@ -15,16 +18,18 @@ Pillow
parameterized
# Add transformers, diffusers and scipy since it most commonly used
tokenizers==0.13.3
transformers
diffusers
#accelerate is now required for diffusers import from ckpt.
accelerate
scipy
ftfy
gradio
gradio==3.44.3
altair
omegaconf
safetensors
# 0.3.2 doesn't have binaries for arm64
safetensors==0.3.1
opencv-python
scikit-image
pytorch_lightning # for runwayml models
@@ -35,6 +40,7 @@ py-cpuinfo
tiktoken # for codegen
joblib # for langchain
timm # for MiniGPT4
langchain
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile

View File

@@ -90,8 +90,8 @@ python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
Write-Host "Build and installation completed successfully"
Write-Host "Source your venv with ./shark.venv/Scripts/activate"

View File

@@ -103,7 +103,7 @@ else
fi
if [[ -z "${USE_IREE}" ]]; then
rm .use-iree
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
RUNTIME="https://nod-ai.github.io/SRT/pip-release-links.html"
else
touch ./.use-iree
RUNTIME="https://openxla.github.io/iree/pip-release-links.html"
@@ -128,16 +128,15 @@ if [[ ! -z "${IMPORTER}" ]]; then
fi
fi
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/cpu/
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)
TORCH_VERSION=${T_VER:9:17}
T_VER_MIN=${T_VER:14:12}
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
TV_VER_MAJ=${TV_VER:9:6}
$PYTHON -m pip uninstall -y torchvision
$PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu118."
else
@@ -145,14 +144,8 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
fi
fi
if [[ ! -z "${ONNX}" ]]; then
echo "${Yellow}Installing ONNX and onnxruntime for benchmarks..."
$PYTHON -m pip install onnx onnxruntime psutil
if [ $? -eq 0 ];then
echo "Successfully installed ONNX and ONNX runtime."
else
echo "Could not install ONNX." >&2
fi
if [[ -z "${NO_BREVITAS}" ]]; then
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev
fi
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then

View File

@@ -43,9 +43,7 @@ if __name__ == "__main__":
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=False, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, mlir_dialect="linalg"
)
shark_module = SharkInference(minilm_mlir)
shark_module.compile()
token_logits = torch.tensor(shark_module.forward(inputs))
mask_id = torch.where(

View File

@@ -13,7 +13,7 @@
# limitations under the License.
## Common utilities to be shared by iree utilities.
import functools
import os
import sys
import subprocess
@@ -52,6 +52,8 @@ def iree_device_map(device):
)
if len(uri_parts) == 1:
return iree_driver
elif "rocm" in uri_parts:
return "rocm"
else:
return f"{iree_driver}://{uri_parts[1]}"
@@ -63,7 +65,6 @@ def get_supported_device_list():
_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"AMD-AIE": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
@@ -82,7 +83,6 @@ def iree_target_map(device):
_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
"AMD-AIE": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
@@ -93,6 +93,7 @@ _IREE_TARGET_MAP = {
# Finds whether the required drivers are installed for the given device.
@functools.cache
def check_device_drivers(device):
"""Checks necessary drivers present for gpu and vulkan devices"""
if "://" in device:
@@ -120,7 +121,10 @@ def check_device_drivers(device):
return False
elif device == "rocm":
try:
subprocess.check_output("rocminfo")
if sys.platform == "win32":
subprocess.check_output("hipinfo")
else:
subprocess.check_output("rocminfo")
except Exception:
return True

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
from shark.iree_utils._common import run_cmd, iree_device_map
from shark.iree_utils.cpu_utils import get_cpu_count
import numpy as np
@@ -62,16 +61,12 @@ def build_benchmark_args(
and whether it is training or not.
Outputs: string that execute benchmark-module on target model.
"""
path = benchmark_module.__path__[0]
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
if platform.system() == "Windows":
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module.exe"
)
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
time_extractor = None
else:
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module"
)
benchmarker_path = os.path.join(path, "iree-benchmark-module")
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
# TODO: The function named can be passed as one of the args.
@@ -106,15 +101,13 @@ def build_benchmark_args_non_tensor_input(
and whether it is training or not.
Outputs: string that execute benchmark-module on target model.
"""
path = benchmark_module.__path__[0]
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
if platform.system() == "Windows":
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module.exe"
)
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
time_extractor = None
else:
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module"
)
benchmarker_path = os.path.join(path, "iree-benchmark-module")
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
# TODO: The function named can be passed as one of the args.
if function_name:
@@ -139,7 +132,7 @@ def run_benchmark_module(benchmark_cl):
benchmark_path = benchmark_cl[0]
assert os.path.exists(
benchmark_path
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
), "Cannot find iree_benchmark_module, Please contact SHARK maintainer on discord."
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
try:
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
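As a sanity check on that parsing, here is a small hedged example of what the regex extracts from a typical benchmark line (the output line shown is illustrative):
```python
import re

# Same pattern as above: a number, optional spaces, then a unit suffix.
regex_split = re.compile(r"(\d+[.]*\d*)( *)([a-zA-Z]+)")

match = regex_split.search("BM_forward/process_time/real_time 12.3 ms")
value, _, unit = match.groups()
print(value, unit)  # -> 12.3 ms
```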

View File

@@ -11,18 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import functools
import numpy as np
import os
import re
import tempfile
import time
from pathlib import Path
import iree.runtime as ireert
import iree.compiler as ireec
from shark.parser import shark_args
from .trace import DetailLogger
from ._common import iree_device_map, iree_target_map
from .cpu_utils import get_iree_cpu_rt_args
from .benchmark_utils import *
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
@@ -41,7 +46,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
data_tiling_flag = ["--iree-opt-data-tiling"]
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
@@ -79,7 +84,7 @@ def get_iree_frontend_args(frontend):
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
return [
"--iree-llvmcpu-target-cpu-features=host",
"--iree-flow-demote-i64-to-i32",
"--iree-input-demote-i64-to-i32",
]
else:
# Frontend not found.
@@ -87,13 +92,27 @@ def get_iree_frontend_args(frontend):
# Common args to be used given any frontend or device.
def get_iree_common_args():
return [
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
def get_iree_common_args(debug=False):
common_args = [
"--iree-stream-resource-max-allocation-size=4294967295",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
]
if debug:
common_args.extend(
[
"--iree-opt-strip-assertions=false",
"--verify=true",
]
)
else:
common_args.extend(
[
"--iree-opt-strip-assertions=true",
"--verify=false",
]
)
return common_args
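In short, the new `debug` switch only toggles assertion stripping and verification. A condensed, hedged sketch of the two flag sets it produces (flag names copied from the hunk above):
```python
# Condensed view of get_iree_common_args(debug=...) after this change.
base = [
    "--iree-stream-resource-max-allocation-size=4294967295",
    "--iree-vm-bytecode-module-strip-source-map=true",
    "--iree-util-zero-fill-elided-attrs",
]
debug_extra = ["--iree-opt-strip-assertions=false", "--verify=true"]
release_extra = ["--iree-opt-strip-assertions=true", "--verify=false"]

def common_args(debug: bool = False) -> list:
    return base + (debug_extra if debug else release_extra)

print(common_args(debug=True))
```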
# Args that are suitable only for certain models or groups of models.
@@ -272,14 +291,16 @@ def compile_module_to_flatbuffer(
model_config_path,
extra_args,
model_name="None",
debug=False,
):
# Setup Compile arguments wrt to frontends.
input_type = ""
args = get_iree_frontend_args(frontend)
args += get_iree_device_args(device, extra_args)
args += get_iree_common_args()
args += get_iree_common_args(debug=debug)
args += get_model_specific_args()
args += extra_args
args += shark_args.additional_compile_args
if frontend in ["tensorflow", "tf"]:
input_type = "auto"
@@ -317,7 +338,6 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
@@ -337,58 +357,65 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
):
instance = ireert.VmInstance()
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=[],
)
# First get configs.
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
print(f"Loading module {flatbuffer_blob_or_path}...")
if "rocm" in device:
device = "rocm"
with DetailLogger(timeout=2.5) as dl:
# First get configs.
if device_idx is not None:
dl.log(f"Mapping device id: {device_idx}")
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
dl.log(f"ireert.get_driver()")
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(mmaped_vmfb)
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
return mmaped_vmfb, config, temp_file_to_unlink
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)
dl.log("get_iree_runtime_config")
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(
config.vm_instance, flatbuffer_blob_or_path
)
dl.log(f"mmap {flatbuffer_blob_or_path}")
ctx = ireert.SystemContext(config=config)
dl.log(f"ireert.SystemContext created")
ctx.add_vm_module(mmaped_vmfb)
dl.log(f"module initialized")
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
dl.log(f"mmap temp {vmfb_file_path}")
return mmaped_vmfb, config, temp_file_to_unlink
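A hedged usage sketch of the two loading scenarios described in the comments above (the import path, model path, and input shape are illustrative; the argument order for `get_results` is assumed from its use elsewhere in this file):
```python
import numpy as np
from pathlib import Path
from shark.iree_utils.compile_utils import load_vmfb_using_mmap, get_results

# Scenario 1: the vmfb is already saved on disk, so we pass its path and the
# module comes back already registered in a SystemContext.
module, config, tmp_to_unlink = load_vmfb_using_mmap(
    Path("resnet50_tf.vmfb"), device="vulkan"
)

# Invoke it the same way get_results() does: module[function_name](*inputs).
outputs = get_results(
    module, "forward", [np.zeros((1, 3, 224, 224), np.float32)], config
)
```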
def get_iree_compiled_module(
@@ -399,10 +426,11 @@ def get_iree_compiled_module(
extra_args: list = [],
device_idx: int = None,
mmap: bool = False,
debug: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args
module, device, frontend, model_config_path, extra_args, debug
)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
@@ -410,7 +438,6 @@ def get_iree_compiled_module(
# we're setting delete=False when creating NamedTemporaryFile. That's why
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
if mmap:
print(f"Will load the compiled module as a mmapped temporary file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx
)
@@ -434,7 +461,6 @@ def load_flatbuffer(
):
temp_file_to_unlink = None
if mmap:
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_path, device, device_idx
)
@@ -460,10 +486,11 @@ def export_iree_module_to_vmfb(
model_config_path: str = None,
module_name: str = None,
extra_args: list = [],
debug: bool = False,
):
# Compiles the module given specs and saves it as .vmfb file.
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, mlir_dialect, model_config_path, extra_args
module, device, mlir_dialect, model_config_path, extra_args, debug
)
if module_name is None:
device_name = (
@@ -471,9 +498,9 @@ def export_iree_module_to_vmfb(
)
module_name = f"{mlir_dialect}_{device_name}"
filename = os.path.join(directory, module_name + ".vmfb")
print(f"Saved vmfb in {filename}.")
with open(filename, "wb") as f:
f.write(flatbuffer_blob)
print(f"Saved vmfb in {filename}.")
return filename
@@ -498,37 +525,56 @@ def get_results(
config,
frontend="torch",
send_to_host=True,
debug_timeout: float = 5.0,
):
"""Runs a .vmfb file given inputs and config and returns output."""
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm[function_name](*device_inputs)
result_tensors = []
if isinstance(result, tuple):
if send_to_host:
for val in result:
result_tensors.append(np.asarray(val, val.dtype))
with DetailLogger(debug_timeout) as dl:
device_inputs = []
for input_array in input:
dl.log(f"Load to device: {input_array.shape}")
device_inputs.append(
ireert.asdevicearray(config.device, input_array)
)
dl.log(f"Invoke function: {function_name}")
result = compiled_vm[function_name](*device_inputs)
dl.log(f"Invoke complete")
result_tensors = []
if isinstance(result, tuple):
if send_to_host:
for val in result:
dl.log(f"Result to host: {val.shape}")
result_tensors.append(np.asarray(val, val.dtype))
else:
for val in result:
result_tensors.append(val)
return result_tensors
elif isinstance(result, dict):
data = list(result.items())
if send_to_host:
res = np.array(data, dtype=object)
return np.copy(res)
return data
else:
for val in result:
result_tensors.append(val)
return result_tensors
elif isinstance(result, dict):
data = list(result.items())
if send_to_host:
res = np.array(data, dtype=object)
return np.copy(res)
return data
else:
if send_to_host and result is not None:
return result.to_host()
return result
if send_to_host and result is not None:
dl.log("Result to host")
return result.to_host()
return result
dl.log("Execution complete")
@functools.cache
def get_iree_runtime_config(device):
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
if device == "metal" and shark_args.device_allocator == "caching":
print(
"[WARNING] metal devices can not have a `caching` allocator."
"\nUsing default allocator `None`"
)
haldevice = haldriver.create_device_by_uri(
device,
allocators=shark_args.device_allocator,
# metal devices have a failure with caching allocators atm. blocking this until it gets fixed upstream.
allocators=shark_args.device_allocator if device != "metal" else None,
)
config = ireert.Config(device=haldevice)
return config

View File

@@ -14,6 +14,7 @@
# All the iree_cpu related functionalities go here.
import functools
import subprocess
import platform
from shark.parser import shark_args
@@ -30,6 +31,7 @@ def get_cpu_count():
# Get the default cpu args.
@functools.cache
def get_iree_cpu_args():
uname = platform.uname()
os_name, proc_name = uname.system, uname.machine
@@ -51,6 +53,7 @@ def get_iree_cpu_args():
# Get iree runtime flags for cpu
@functools.cache
def get_iree_cpu_rt_args():
default = get_cpu_count()
default = default if default <= 8 else default - 2

View File

@@ -14,12 +14,15 @@
# All the iree_gpu related functionalities go here.
import functools
import iree.runtime as ireert
import ctypes
import sys
from shark.parser import shark_args
# Get the default gpu args given the architecture.
@functools.cache
def get_iree_gpu_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
ireert.flags.parse_flags("--cuda_allow_inline_execution")
@@ -37,23 +40,54 @@ def get_iree_gpu_args():
# Get the default gpu args given the architecture.
@functools.cache
def get_iree_rocm_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
# get arch from rocminfo.
# get arch from hipinfo.
import os
import re
import subprocess
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"rocminfo | grep -i 'gfx'", shell=True, text=True
),
).group(1)
print(f"Found rocm arch {rocm_arch}...")
if sys.platform == "win32":
if "HIP_PATH" in os.environ:
rocm_path = os.environ["HIP_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
rocm_path = "C:\\AMD\\ROCM\\5.5"
else:
if "ROCM_PATH" in os.environ:
rocm_path = os.environ["ROCM_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
rocm_path = "/opt/rocm/"
try:
if sys.platform == "win32":
rocm_arch = re.search(
r"gfx\d{3,}",
subprocess.check_output("hipinfo", shell=True, text=True),
).group(0)
else:
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"rocminfo | grep -i 'gfx'", shell=True, text=True
),
).group(1)
print(f"Found rocm arch {rocm_arch}...")
except:
print(
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
)
rocm_arch = "gfx1100"
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
return [
f"--iree-rocm-target-chip={rocm_arch}",
"--iree-rocm-link-bc=true",
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
f"--iree-rocm-bc-dir={bc_path}",
]
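A hedged sketch of how these ROCm flags are consumed (the module path and the trivial MLIR source are illustrative; `compile_str` is the same `iree.compiler` entry point used elsewhere in this PR):
```python
from iree.compiler import compile_str
from shark.iree_utils.gpu_utils import get_iree_rocm_args  # assumed module path

# Compile a (trivial) MLIR module for ROCm using the detected arch and bitcode dir.
flatbuffer = compile_str(
    "module {}",
    target_backends=["rocm"],
    extra_args=get_iree_rocm_args(),  # e.g. --iree-rocm-target-chip=gfx1100, --iree-rocm-bc-dir=...
)
```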
@@ -65,6 +99,7 @@ CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
@functools.cache
def get_cuda_sm_cc():
libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
for libname in libnames:

View File

@@ -14,12 +14,15 @@
# All the iree_vulkan related functionalities go here.
import functools
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
@functools.cache
def get_metal_device_name(device_num=0):
iree_device_dump = run_cmd("iree-run-module --dump_devices")
iree_device_dump = iree_device_dump[0].split("\n\n")

shark/iree_utils/trace.py (new file, 76 lines)
View File

@@ -0,0 +1,76 @@
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import os
import threading
import time
def _enable_detail_trace() -> bool:
return os.getenv("SHARK_DETAIL_TRACE", "0") == "1"
class DetailLogger:
"""Context manager which can accumulate detailed log messages.
Detailed log is only emitted if the operation takes a long time
or errors.
"""
def __init__(self, timeout: float):
self._timeout = timeout
self._messages: List[Tuple[float, str]] = []
self._start_time = time.time()
self._active = not _enable_detail_trace()
self._lock = threading.RLock()
self._cond = threading.Condition(self._lock)
self._thread = None
def __enter__(self):
self._thread = threading.Thread(target=self._run)
self._thread.start()
return self
def __exit__(self, type, value, traceback):
with self._lock:
self._active = False
self._cond.notify()
if traceback:
self.dump_on_error(f"exception")
def _run(self):
with self._lock:
timed_out = not self._cond.wait(self._timeout)
if timed_out:
self.dump_on_error(f"took longer than {self._timeout}s")
def log(self, msg):
with self._lock:
timestamp = time.time()
if self._active:
self._messages.append((timestamp, msg))
else:
print(f" +{(timestamp - self._start_time) * 1000}ms: {msg}")
def dump_on_error(self, summary: str):
with self._lock:
if self._active:
print(f"::: Detailed report ({summary}):")
for timestamp, msg in self._messages:
print(
f" +{(timestamp - self._start_time) * 1000}ms: {msg}"
)
self._active = False
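A hedged usage sketch of the new `DetailLogger` (the timeout value is illustrative): messages are buffered and only dumped if the block exceeds the timeout or raises, unless `SHARK_DETAIL_TRACE=1` makes them print immediately.
```python
import time
from shark.iree_utils.trace import DetailLogger

with DetailLogger(timeout=2.5) as dl:
    dl.log("loading module")
    time.sleep(0.1)          # fast path: nothing is printed
    dl.log("module loaded")  # slow or failing blocks dump the buffered log
```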

View File

@@ -13,8 +13,10 @@
# limitations under the License.
from collections import OrderedDict
import functools
@functools.cache
def get_vulkan_target_env(vulkan_target_triple):
arch, product, os = vulkan_target_triple.split("=")[1].split("-")
triple = (arch, product, os)
@@ -52,13 +54,11 @@ def get_version(triple):
return "v1.3"
@functools.cache
def get_extensions(triple):
def make_ext_list(ext_list):
res = ""
for e in ext_list:
res += e + ", "
res = f"[{res[:-2]}]"
return res
res = ", ".join(ext_list)
return f"[{res}]"
arch, product, os = triple
if arch == "m1":
@@ -116,12 +116,13 @@ def get_extensions(triple):
]
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
ext.append("VK_NV_cooperative_matrix")
ext.append("VK_KHR_cooperative_matrix")
if get_vendor(triple) in ["NVIDIA", "AMD", "Intel"]:
ext.append("VK_KHR_shader_integer_dot_product")
return make_ext_list(ext_list=ext)
@functools.cache
def get_vendor(triple):
arch, product, os = triple
if arch == "unknown":
@@ -146,6 +147,7 @@ def get_vendor(triple):
return "Unknown"
@functools.cache
def get_device_type(triple):
arch, product, _ = triple
if arch == "unknown":
@@ -166,6 +168,7 @@ def get_device_type(triple):
# get all the capabilities for the device
# TODO: make a dataclass for capabilities and init using vulkaninfo
@functools.cache
def get_vulkan_target_capabilities(triple):
def get_subgroup_val(l):
return int(sum([subgroup_feature[sgf] for sgf in l]))
@@ -241,7 +244,7 @@ def get_vulkan_target_capabilities(triple):
if arch == "rdna3":
# TODO: Get scope value
cap["coopmatCases"] = [
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
]
if product == "rx5700xt":
@@ -462,9 +465,9 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointersStorageBuffer"] = True
cap["coopmatCases"] = [
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
]
elif arch == "adreno":
@@ -525,7 +528,7 @@ def get_vulkan_target_capabilities(triple):
cmc = ""
for case in v:
cmc += f"#vk.coop_matrix_props<{case}>, "
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
else:
res += f"{k} = {get_comma_sep_str(v)}, "
else:

View File

@@ -14,6 +14,7 @@
# All the iree_vulkan related functionalities go here.
import functools
from os import linesep
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
@@ -22,10 +23,19 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
from shark.parser import shark_args
@functools.cache
def get_all_vulkan_devices():
from iree.runtime import get_driver
driver = get_driver("vulkan")
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return [d["name"] for d in device_list_src]
@functools.cache
def get_vulkan_device_name(device_num=0):
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
@@ -48,6 +58,7 @@ def get_os_name():
return "linux"
@functools.cache
def get_vulkan_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.
@@ -172,11 +183,10 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
return res_vulkan_flag
@functools.cache
def get_iree_vulkan_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={shark_args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
f"--vulkan_vma_allocator={'true' if shark_args.vulkan_vma_allocator else 'false'}",
]
return vulkan_runtime_flags
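A hedged sketch of how these runtime flags are applied (the module path is assumed; `parse_flags` is used the same way for the CPU flags elsewhere in this PR):
```python
import iree.runtime as ireert
from shark.iree_utils.vulkan_utils import get_iree_vulkan_runtime_flags  # assumed path

# Register the Vulkan runtime flags before any device/driver is created.
for flag in get_iree_vulkan_runtime_flags():
    ireert.flags.parse_flags(flag)
```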

View File

@@ -14,8 +14,21 @@
import argparse
import os
import shlex
import subprocess
class SplitStrToListAction(argparse.Action):
def __init__(self, option_strings, dest, *args, **kwargs):
super(SplitStrToListAction, self).__init__(
option_strings=option_strings, dest=dest, *args, **kwargs
)
def __call__(self, parser, namespace, values, option_string=None):
del parser, option_string
setattr(namespace, self.dest, shlex.split(values[0]))
parser = argparse.ArgumentParser(description="SHARK runner.")
parser.add_argument(
@@ -24,6 +37,13 @@ parser.add_argument(
default="cpu",
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
)
parser.add_argument(
"--additional_compile_args",
default=list(),
nargs=1,
action=SplitStrToListAction,
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
)
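A hedged illustration of what `SplitStrToListAction` does with the single quoted value: `shlex.split` turns it into individual compiler flags that are appended to the iree-compile invocation (the flags shown are just examples taken from commands earlier in this diff).
```python
import shlex

raw = "--mlir-print-debuginfo --mlir-print-op-on-diagnostic=false"
print(shlex.split(raw))
# ['--mlir-print-debuginfo', '--mlir-print-op-on-diagnostic=false']
```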
parser.add_argument(
"--enable_tf32",
type=bool,
@@ -114,7 +134,7 @@ parser.add_argument(
"--device_allocator",
type=str,
nargs="*",
default=[],
default=["caching"],
help="Specifies one or more HAL device allocator specs "
"to augment the base device allocator",
choices=["debug", "caching"],
@@ -133,13 +153,6 @@ parser.add_argument(
help="Profiles vulkan device and collects the .rdc info.",
)
parser.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="Flag for setting VMA preferredLargeHeapBlockSize for "
"vulkan device, default is 4G.",
)
parser.add_argument(
"--vulkan_validation_layers",
default=False,
@@ -147,11 +160,4 @@ parser.add_argument(
help="Flag for disabling vulkan validation layers when benchmarking.",
)
parser.add_argument(
"--vulkan_vma_allocator",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for enabling / disabling Vulkan VMA Allocator.",
)
shark_args, unknown = parser.parse_known_args()

View File

@@ -13,7 +13,11 @@
# limitations under the License.
from shark.shark_runner import SharkRunner
from shark.iree_utils.compile_utils import export_iree_module_to_vmfb
from shark.iree_utils.compile_utils import (
export_iree_module_to_vmfb,
load_flatbuffer,
get_iree_runtime_config,
)
from shark.iree_utils.benchmark_utils import (
build_benchmark_args,
run_benchmark_module,
@@ -79,22 +83,31 @@ class SharkBenchmarkRunner(SharkRunner):
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.import_args = {}
self.temp_file_to_unlink = None
SharkRunner.__init__(
self,
mlir_module,
device,
self.mlir_dialect,
self.extra_args,
compile_vmfb=True,
compile_vmfb=False,
)
if self.vmfb_file == None:
self.vmfb_file = export_iree_module_to_vmfb(
mlir_module,
device,
".",
self.mlir_dialect,
extra_args=self.extra_args,
)
self.vmfb_file = export_iree_module_to_vmfb(
mlir_module,
device,
".",
self.mlir_dialect,
extra_args=self.extra_args,
)
params = load_flatbuffer(
self.vmfb_file,
device,
mmap=True,
)
self.iree_compilation_module = params["vmfb"]
self.iree_config = params["config"]
self.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
def setup_cl(self, input_tensors):
self.benchmark_cl = build_benchmark_args(
@@ -111,42 +124,41 @@ class SharkBenchmarkRunner(SharkRunner):
elif self.mlir_dialect in ["mhlo", "tf"]:
return self.benchmark_tf(modelname)
def benchmark_torch(self, modelname):
def benchmark_torch(self, modelname, device="cpu"):
import torch
from tank.model_utils import get_torch_model
if self.device == "cuda":
torch.set_default_tensor_type(torch.cuda.FloatTensor)
if self.enable_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# TODO: Pass this as an arg. currently the best way is to setup with BENCHMARK=1 if we want to use torch+cuda, else use cpu.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
torch.set_default_device("cuda:0")
# if self.enable_tf32:
# torch.backends.cuda.matmul.allow_tf32 = True
else:
torch.set_default_tensor_type(torch.FloatTensor)
torch_device = torch.device(
"cuda:0" if self.device == "cuda" else "cpu"
)
torch.set_default_dtype(torch.float32)
torch.set_default_device("cpu")
torch_device = torch.device("cuda:0" if device == "cuda" else "cpu")
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
frontend_model = HFmodel.model
frontend_model.to(torch_device)
input.to(torch_device)
# TODO: re-enable as soon as pytorch CUDA context issues are resolved
try:
frontend_model = torch.compile(
frontend_model, mode="max-autotune", backend="inductor"
)
except RuntimeError:
frontend_model = HFmodel.model
if device == "cuda":
frontend_model.cuda()
input.to(torch.device("cuda:0"))
print(input)
else:
frontend_model.cpu()
input.cpu()
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(input)
if self.device == "cuda":
if device == "cuda":
torch.cuda.reset_peak_memory_stats()
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(input)
end = time.time()
if self.device == "cuda":
if device == "cuda":
stats = torch.cuda.memory_stats()
device_peak_b = stats["allocated_bytes.all.peak"]
frontend_model.to(torch.device("cpu"))
@@ -158,7 +170,7 @@ class SharkBenchmarkRunner(SharkRunner):
print(
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
if self.device == "cuda":
if device == "cuda":
# Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
torch_device = torch.device("cpu")
return [

View File

@@ -11,14 +11,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
def brevitasmatmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -27,30 +21,21 @@ def brevitasmatmul_rhs_group_quant〡shape(
raise ValueError("Input shapes not supported.")
def brevitasmatmul_rhs_group_quant〡dtype(
lhs_rank_dtype: Tuple[int, int],
rhs_rank_dtype: Tuple[int, int],
rhs_scale_rank_dtype: Tuple[int, int],
rhs_zero_point_rank_dtype: Tuple[int, int],
rhs_bit_width: int,
rhs_group_size: int,
) -> int:
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def brevitasmatmul_rhs_group_quant〡has_value_semantics(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics,
]
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
@@ -122,7 +107,7 @@ def compile_int_precision(
torchscript_module,
inputs,
output_type="torch",
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -130,7 +115,7 @@ def compile_int_precision(
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
mlir_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
from contextlib import redirect_stdout

View File

@@ -111,22 +111,20 @@ os.makedirs(WORKDIR, exist_ok=True)
def check_dir_exists(model_name, frontend="torch", dynamic=""):
model_dir = os.path.join(WORKDIR, model_name)
# Remove the _tf keyword from end.
if frontend in ["tf", "tensorflow"]:
model_name = model_name[:-3]
elif frontend in ["tflite"]:
model_name = model_name[:-7]
elif frontend in ["torch", "pytorch"]:
model_name = model_name[:-6]
# Remove the _tf keyword from end only for non-SD models.
if not any(model in model_name for model in ["clip", "unet", "vae"]):
if frontend in ["tf", "tensorflow"]:
model_name = model_name[:-3]
elif frontend in ["tflite"]:
model_name = model_name[:-7]
elif frontend in ["torch", "pytorch"]:
model_name = model_name[:-6]
model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir"
if os.path.isdir(model_dir):
if (
os.path.isfile(
os.path.join(
model_dir,
model_name + dynamic + "_" + str(frontend) + ".mlir",
)
)
os.path.isfile(os.path.join(model_dir, model_mlir_file_name))
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))

View File

@@ -1,5 +1,7 @@
import re
import json
import numpy as np
import torch_mlir
from iree.compiler import compile_str
from shark.shark_importer import import_with_fx, get_f16_inputs
@@ -11,6 +13,7 @@ class GenerateConfigFile:
model,
num_sharding_stages: int,
sharding_stages_id: list[str],
units_in_each_stage: list[int],
model_input=None,
config_file_path="model_config.json",
):
@@ -22,13 +25,16 @@ class GenerateConfigFile:
), "Number of sharding stages should be equal to the list of their ID"
self.model_input = model_input
self.config_file_path = config_file_path
# (Nithin) this is a quick fix - revisit and rewrite
self.units_in_each_stage = np.array(units_in_each_stage)
self.track_loop = np.zeros(len(self.sharding_stages_id)).astype(int)
def split_into_dispatches(
self,
backend,
fx_tracing_required=True,
fx_tracing_required=False,
f16_model=False,
torch_mlir_tracing=False,
torch_mlir_tracing=True,
):
graph_for_compilation = self.model
if fx_tracing_required:
@@ -95,7 +101,17 @@ class GenerateConfigFile:
if substring_before_final_period in model_dictionary:
del model_dictionary[substring_before_final_period]
layer_dict = {n: "None" for n in self.sharding_stages_id}
# layer_dict = {n: "None" for n in self.sharding_stages_id}
# By default embed increasing device id's for each layer
increasing_wraparound_idx_list = (
self.track_loop % self.units_in_each_stage
)
layer_dict = {
n: int(increasing_wraparound_idx_list[idx][0][0])
for idx, n in enumerate(self.sharding_stages_id)
}
self.track_loop += 1
model_dictionary[name] = layer_dict
self.generate_json(model_dictionary)
@@ -103,3 +119,29 @@ class GenerateConfigFile:
def generate_json(self, artifacts):
with open(self.config_file_path, "w") as outfile:
json.dump(artifacts, outfile)
if __name__ == "__main__":
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna7B,
CombinedModel,
)
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()

View File

@@ -509,22 +509,6 @@ def import_with_fx(
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
golden_values = None
if debug:
@@ -596,8 +580,30 @@ def import_with_fx(
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten.index_add,
torch.ops.aten.index_add_,
]
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import (
replace_call_fn_target,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
export_context_manager = brevitas_layer_export_mode
export_class = block_quant_layer_level_manager(
export_handlers=[LinearWeightBlockQuantHandlerFwd]
@@ -612,7 +618,7 @@ def import_with_fx(
replace_call_fn_target(
fx_g,
src=matmul_rhs_group_quant_placeholder,
target=torch.ops.brevitas.matmul_rhs_group_quant,
target=torch.ops.quant.matmul_rhs_group_quant,
)
fx_g.recompile()
@@ -677,5 +683,5 @@ def import_with_fx(
)
return mlir_module, func_name
mlir_module, func_name = mlir_importer.import_mlir()
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
return mlir_module, func_name

View File

@@ -141,6 +141,10 @@ class SharkInference:
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
return self.shark_runner.run(function_name, inputs, send_to_host)
# forward function.
def forward(self, inputs: tuple, send_to_host=True):
return self.shark_runner.run("forward", inputs, send_to_host)
# Get all function names defined within the compiled module.
def get_functions_in_module(self):
return self.shark_runner.get_functions_in_module()
@@ -188,7 +192,9 @@ class SharkInference:
# TODO: Instead of passing directory and having names decided by the module
# , user may want to save the module with manual names.
def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
def save_module(
self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False
):
return export_iree_module_to_vmfb(
self.mlir_module,
self.device,
@@ -196,6 +202,7 @@ class SharkInference:
self.mlir_dialect,
module_name=module_name,
extra_args=extra_args,
debug=debug,
)
# load and return the module.
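A hedged usage sketch of the extended `save_module` signature (the constructor arguments and the MLIR source here are illustrative, following the `opt_perf_comparison.py` usage later in this compare view; only `module_name`, `extra_args`, and the new `debug` keyword come from the change above):
```
from shark.shark_inference import SharkInference

# mlir_str is assumed to hold an MLIR module string produced elsewhere,
# e.g. by import_with_fx(..., return_str=True).
with open("my_model_torch.mlir", "r") as f:
    mlir_str = f.read()

shark_module = SharkInference(mlir_module=mlir_str, device="cpu-task")
shark_module.compile()
# debug is forwarded to export_iree_module_to_vmfb by the change above.
shark_module.save_module(module_name="my_model_cpu", debug=False)
```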

View File

@@ -69,7 +69,7 @@ class SharkTrainer:
self.frontend = frontend
# Training function is needed in the case of torch_fn.
def compile(self, training_fn=None, extra_args=[]):
def compile(self, training_fn=None, mlir_type="linalg", extra_args=[]):
if self.frontend in ["torch", "pytorch"]:
packed_inputs = (
dict(self.model.named_parameters()),
@@ -77,7 +77,12 @@ class SharkTrainer:
tuple(self.input),
)
mlir_module, func_name = import_with_fx(
training_fn, packed_inputs, False, [], training=True
training_fn,
packed_inputs,
False,
[],
training=True,
mlir_type=mlir_type,
)
self.shark_runner = SharkRunner(
mlir_module,

View File

@@ -13,7 +13,6 @@ google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,
microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir",""
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
@@ -30,7 +29,7 @@ nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"",""
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
@@ -44,4 +43,3 @@ t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq m
t5-base,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
t5-large,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"","macos"

View File

@@ -85,8 +85,6 @@ if __name__ == "__main__":
args = [
"--iree-llvmcpu-target-cpu-features=host",
"--iree-mhlo-demote-i64-to-i32=false",
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
]
backend_config = "dylib"
# backend = "cuda"

View File

@@ -1,3 +1,26 @@
# Running Different OPT Variants
# Run OPT for sentence completion through SHARK
To run different sizes of OPT, change the `OPT_MODEL` string in `opt_torch_test.py`. The default is the 350m-parameter model; test cases for the 66b variant also exist in the file and can simply be uncommented.
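For example, a sketch of that edit (the exact variable layout in `opt_torch_test.py` and the value strings shown here are assumptions):
```
# opt_torch_test.py -- pick the OPT variant to run (illustrative values)
OPT_MODEL = "opt-350m"    # default
# OPT_MODEL = "opt-1.3b"
# OPT_MODEL = "opt-66b"   # also uncomment the matching 66b test cases
```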
From the base SHARK directory, follow the instructions to set up a virtual environment with SHARK (`./setup_venv.sh` or `./setup_venv.ps1`).
Then run `opt_causallm.py` to get a simple sentence-completion application running through SHARK:
```
python opt_causallm.py
```
# Run OPT performance comparison on SHARK vs. PyTorch
```
python opt_perf_comparison.py --max-seq-len=512 --model-name=facebook/opt-1.3b \
--platform=shark
```
Any OPT model from Hugging Face should work with this script, and you can choose between `--platform=shark` and `--platform=huggingface` to benchmark OPT inference on SHARK or PyTorch.
# Run a small suite of OPT models through the benchmark script
```
python opt_perf_comparison_batch.py
```
This script will run benchmarks from a suite of OPT configurations:
- Sequence Lengths: 32, 128, 256, 512
- Parameter Counts: 125m, 350m, 1.3b
Note: most of these scripts are written for use on CPU, since performance comparisons against PyTorch can otherwise be problematic across platforms.
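Each run prints a single `# Summary: {...}` line containing the JSON report assembled in `opt_perf_comparison.py`. A minimal sketch for collecting those reports from a captured batch log (the log file name is an assumption):
```
import json

# Gather the "# Summary: {...}" lines printed by opt_perf_comparison.py
# from a log captured while running the batch script.
reports = []
with open("batch_run.log", "r") as f:
    for line in f:
        if line.startswith("# Summary: "):
            reports.append(json.loads(line[len("# Summary: "):]))

# The keys below mirror the REPORT_* constants defined in opt_perf_comparison.py.
for r in reports:
    print(r["platform"], r["model"], r["max_seq_len"], r["run_time_sec"])
```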

View File

@@ -59,7 +59,7 @@ def create_module(model_name, tokenizer, device):
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
shark_module.save_module(module_name=vmfb_name)
shark_module.save_module(module_name=vmfb_name, debug=False)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path

View File

@@ -1,18 +1,46 @@
"""
Script for comparing OPT model performance between SHARK and Huggingface
PyTorch.
Usage Example:
python opt_perf_comparison.py --max-seq-len=32 --model-name=facebook/opt-125m \
--platform=shark
python opt_perf_comparison.py --max-seq-len=512 --model-name=facebook/opt-1.3b \
--platform=shark
See parse_args() below for command line argument usage.
"""
import argparse
import collections
import json
import time
import os
import psutil
import resource
import time
from typing import Tuple
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
from shark_opt_wrapper import OPTForCausalLMModel
MODEL_NAME = "facebook/opt-1.3b"
OPT_MODELNAME = "opt-1.3b"
OPT_FS_NAME = "opt_1-3b"
MAX_SEQUENCE_LENGTH = 512
DEVICE = "cpu"
PLATFORM_SHARK = "shark"
PLATFORM_HUGGINGFACE = "huggingface"
# Dict keys for reports.
REPORT_PLATFORM = "platform"
REPORT_MODEL_NAME = "model"
REPORT_MAX_SEQ_LEN = "max_seq_len"
REPORT_LOAD_TIME = "load_time_sec"
REPORT_RUN_TIME = "run_time_sec"
REPORT_LOAD_PHYSICAL_MEMORY_MB = "load_physical_MB"
REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"
PROMPTS = [
"What is the meaning of life?",
@@ -30,15 +58,27 @@ PROMPTS = [
ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])
def create_vmfb_module(model_name, tokenizer, device):
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
def get_memory_info():
pid = os.getpid()
process = psutil.Process(pid)
return process.memory_info()
def create_vmfb_module(
model_name: str,
tokenizer,
device: str,
max_seq_len: int,
recompile_shark: bool,
):
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
"What is the meaning of life?",
PROMPTS[0],
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_seq_len,
return_tensors="pt",
)
inputs = (
@@ -48,8 +88,16 @@ def create_vmfb_module(model_name, tokenizer, device):
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
if os.path.isfile(mlir_path):
opt_fs_name = get_opt_fs_name(model_name)
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
# If MLIR has already been loaded and recompilation is not requested, use
# the loaded MLIR file.
has_mlir = os.path.isfile(mlir_path)
# The purpose of recompile_shark is to measure compilation time; that time
# can be measured correctly only when the .mlir file already exists on disk.
assert not recompile_shark or has_mlir
if has_mlir:
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
@@ -58,7 +106,7 @@ def create_vmfb_module(model_name, tokenizer, device):
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=OPT_FS_NAME,
model_name=opt_fs_name,
return_str=True,
)
with open(mlir_path, "w") as f:
@@ -72,18 +120,25 @@ def create_vmfb_module(model_name, tokenizer, device):
is_benchmark=False,
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{DEVICE}_tiled_ukernels"
vmfb_name = (
f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels"
)
shark_module.save_module(module_name=vmfb_name)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path
def load_shark_model() -> ModelWrapper:
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{DEVICE}_tiled_ukernels.vmfb"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if not os.path.isfile(vmfb_name):
def load_shark_model(
model_name: str, max_seq_len: int, recompile_shark: bool
) -> ModelWrapper:
opt_fs_name = get_opt_fs_name(model_name)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
if recompile_shark or not os.path.isfile(vmfb_name):
print(f"vmfb not found. compiling and saving to {vmfb_name}")
create_vmfb_module(OPT_MODELNAME, tokenizer, DEVICE)
create_vmfb_module(
model_name, tokenizer, DEVICE, max_seq_len, recompile_shark
)
shark_module = SharkInference(mlir_module=None, device="cpu-task")
shark_module.load_module(vmfb_name)
return ModelWrapper(model=shark_module, tokenizer=tokenizer)
@@ -94,20 +149,10 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
return model_wrapper.model("forward", tokens)
def run_shark():
model_wrapper = load_shark_model()
prompt = "What is the meaning of life?"
logits = run_shark_model(model_wrapper, prompt)
# Print output logits to validate vs. pytorch + base transformers
print(logits[0])
def load_huggingface_model() -> ModelWrapper:
def load_huggingface_model(model_name: str) -> ModelWrapper:
return ModelWrapper(
model=OPTForCausalLM.from_pretrained(MODEL_NAME),
tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
model=OPTForCausalLM.from_pretrained(model_name),
tokenizer=AutoTokenizer.from_pretrained(model_name),
)
@@ -117,47 +162,71 @@ def run_huggingface_model(model_wrapper: ModelWrapper, tokens):
)
def run_huggingface():
model_wrapper = load_huggingface_model()
prompt = "What is the meaning of life?"
logits = run_huggingface_model(model_wrapper, prompt)
print(logits[0])
def save_json(data, filename):
with open(filename, "w") as file:
json.dump(data, file)
def collect_huggingface_logits():
def collect_huggingface_logits(
model_name: str, max_seq_len: int, to_save_json: bool
) -> Tuple[float, float]:
# Load
t0 = time.time()
model_wrapper = load_huggingface_model()
print("--- Took {} seconds to load Huggingface.".format(time.time() - t0))
model_wrapper = load_huggingface_model(model_name)
load_time = time.time() - t0
print("--- Took {} seconds to load Huggingface.".format(load_time))
load_memory_info = get_memory_info()
results = []
tokenized_prompts = []
for prompt in PROMPTS:
tokens = model_wrapper.tokenizer(
prompt,
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_seq_len,
truncation=True,
return_tensors="pt",
)
tokenized_prompts.append(tokens)
# Run
t0 = time.time()
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_huggingface_model(model_wrapper, tokens)
results.append([PROMPTS[idx], logits[0].tolist()])
print("--- Took {} seconds to run Huggingface.".format(time.time() - t0))
save_json(results, "/tmp/huggingface.json")
if to_save_json:
results.append([PROMPTS[idx], logits[0].tolist()])
run_time = time.time() - t0
print("--- Took {} seconds to run Huggingface.".format(run_time))
if to_save_json:
save_json(results, "/tmp/huggingface.json")
run_memory_info = get_memory_info()
return {
REPORT_PLATFORM: PLATFORM_HUGGINGFACE,
REPORT_MODEL_NAME: model_name,
REPORT_MAX_SEQ_LEN: max_seq_len,
REPORT_LOAD_TIME: load_time,
REPORT_RUN_TIME: run_time / len(PROMPTS),
REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
}
def collect_shark_logits():
def collect_shark_logits(
model_name: str,
max_seq_len: int,
recompile_shark: bool,
to_save_json: bool,
) -> Tuple[float, float]:
# Load
t0 = time.time()
model_wrapper = load_shark_model()
print("--- Took {} seconds to load Shark.".format(time.time() - t0))
model_wrapper = load_shark_model(model_name, max_seq_len, recompile_shark)
load_time = time.time() - t0
print("--- Took {} seconds to load Shark.".format(load_time))
load_memory_info = get_memory_info()
results = []
tokenized_prompts = []
for prompt in PROMPTS:
@@ -165,7 +234,7 @@ def collect_shark_logits():
prompt,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_seq_len,
return_tensors="pt",
)
inputs = (
@@ -173,16 +242,100 @@ def collect_shark_logits():
tokens["attention_mask"],
)
tokenized_prompts.append(inputs)
# Run
t0 = time.time()
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_shark_model(model_wrapper, tokens)
lst = [e.tolist() for e in logits]
results.append([PROMPTS[idx], lst])
print("--- Took {} seconds to run Shark.".format(time.time() - t0))
save_json(results, "/tmp/shark.json")
if to_save_json:
results.append([PROMPTS[idx], lst])
run_time = time.time() - t0
print("--- Took {} seconds to run Shark.".format(run_time))
if to_save_json:
save_json(results, "/tmp/shark.json")
platform_postfix = "-compile" if recompile_shark else "-precompiled"
run_memory_info = get_memory_info()
return {
REPORT_PLATFORM: PLATFORM_SHARK + platform_postfix,
REPORT_MODEL_NAME: model_name,
REPORT_MAX_SEQ_LEN: max_seq_len,
REPORT_LOAD_TIME: load_time,
REPORT_RUN_TIME: run_time / len(PROMPTS),
REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
}
def get_opt_fs_name(model_name: str) -> str:
"""Cleanses the model name ino a file system-friendly name.
Example: get_opt_fs_name('facebook/opt-1.3b') == 'opt_1-3b'
"""
slash_split = model_name.split("/")
assert 1 <= len(slash_split) <= 2, "There should be at most one slash."
model_name = slash_split[-1]
for src_pattern, dest_pattern in (("-", "_"), (".", "-")):
model_name = model_name.replace(src_pattern, dest_pattern)
return model_name
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--save-json",
help="If set, saves output JSON.",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--max-seq-len", help="Max sequence length", type=int, default=32
)
parser.add_argument(
"--model-name",
help="Model name",
type=str,
choices=[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
],
default="facebook/opt-1.3b",
)
parser.add_argument(
"--recompile-shark",
help="If set, recompiles MLIR",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--platform",
help="Either shark or huggingface",
type=str,
choices=[PLATFORM_SHARK, PLATFORM_HUGGINGFACE],
default=PLATFORM_SHARK,
)
args = parser.parse_args()
print("args={}".format(args))
return args
if __name__ == "__main__":
collect_shark_logits()
collect_huggingface_logits()
args = parse_args()
if args.platform == PLATFORM_SHARK:
shark_report = collect_shark_logits(
args.model_name,
args.max_seq_len,
args.recompile_shark,
args.save_json,
)
print("# Summary: {}".format(json.dumps(shark_report)))
else:
huggingface_report = collect_huggingface_logits(
args.model_name, args.max_seq_len, args.save_json
)
print("# Summary: {}".format(json.dumps(huggingface_report)))

View File

@@ -0,0 +1,30 @@
"""
Script for running opt_perf_comparison.py in batch with a series of arguments.
Usage: python opt_perf_comparison_batch.py
"""
from typing import Iterable, List
import shlex
import subprocess
def make_commands() -> Iterable[List[str]]:
command = shlex.split("python opt_perf_comparison.py --no-save-json")
max_seq_lens = [32, 128, 256, 512]
model_names = ["facebook/opt-" + e for e in ["125m", "350m", "1.3b"]]
for max_seq_len in max_seq_lens:
for model_name in model_names:
yield command + [
f"--max-seq-len={max_seq_len}",
f"--model-name={model_name}",
]
def main():
for command in make_commands():
result = subprocess.run(command, check=True)
if __name__ == "__main__":
main()

View File

@@ -16,12 +16,6 @@ import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
from apps.stable_diffusion.src.models import (
model_wrappers as mw,
)
from apps.stable_diffusion.src.utils.stable_args import (
args,
)
def create_hash(file_name):
@@ -60,31 +54,6 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
print("generating artifacts for: " + torch_model_name)
model = None
input = None
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.local_tank_cache = local_tank_cache
precision_values = ["fp16"]
seq_lengths = [64, 77]
for precision_value in precision_values:
args.precision = precision_value
for length in seq_lengths:
model = mw.SharkifyStableDiffusionModel(
model_id=torch_model_name,
custom_weights="",
precision=precision_value,
max_len=length,
width=512,
height=512,
use_base_vae=False,
custom_vae="",
debug=True,
sharktank_dir=local_tank_cache,
generate_vmfb=False,
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(
torch_model_name, import_args
@@ -103,10 +72,11 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
model, input, _ = get_hf_img_cls_model(
torch_model_name, import_args
)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name, import_args)
torch_model_name = torch_model_name.replace("/", "_")
if import_args["batch_size"] != 1:
if import_args["batch_size"] > 1:
print(
f"Batch size for this model set to {import_args['batch_size']}"
)
torch_model_dir = os.path.join(
local_tank_cache,
str(torch_model_name)
@@ -391,7 +361,7 @@ if __name__ == "__main__":
# old_import_args = parser.parse_import_args()
import_args = {
"batch_size": "1",
"batch_size": 1,
}
print(import_args)
home = str(Path.home())
@@ -404,11 +374,6 @@ if __name__ == "__main__":
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
)
save_torch_model(
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
WORKDIR,
import_args,
)
save_torch_model(torch_model_csv, WORKDIR, import_args)
save_tf_model(tf_model_csv, WORKDIR, import_args)
save_tflite_model(tflite_model_csv, WORKDIR, import_args)
# save_tf_model(tf_model_csv, WORKDIR, import_args)
# save_tflite_model(tflite_model_csv, WORKDIR, import_args)

View File

@@ -278,7 +278,7 @@ def get_vision_model(torch_model, import_args):
int(import_args["batch_size"]), 3, *input_image_size
)
actual_out = model(test_input)
if fp16_model is not None:
if fp16_model == True:
test_input_fp16 = test_input.to(
device=torch.device("cuda"), dtype=torch.half
)

View File

@@ -145,6 +145,7 @@ class SharkModuleTester:
shark_args.shark_prefix = self.shark_tank_prefix
shark_args.local_tank_cache = self.local_tank_cache
shark_args.dispatch_benchmarks = self.benchmark_dispatches
shark_args.enable_tf32 = self.tf32
if self.benchmark_dispatches is not None:
_m = self.config["model_name"].split("/")
@@ -216,10 +217,12 @@ class SharkModuleTester:
result = shark_module(func_name, inputs)
golden_out, result = self.postprocess_outputs(golden_out, result)
if self.tf32 == "true":
print("Validating with relaxed tolerances.")
atol = 1e-02
rtol = 1e-03
if self.tf32 == True:
print(
"Validating with relaxed tolerances for TensorFloat32 calculations."
)
self.config["atol"] = 1e-01
self.config["rtol"] = 1e-02
try:
np.testing.assert_allclose(
golden_out,
@@ -254,9 +257,6 @@ class SharkModuleTester:
model_config = {
"batch_size": self.batch_size,
}
shark_args.enable_tf32 = self.tf32
if shark_args.enable_tf32 == True:
shark_module.compile()
shark_args.onnx_bench = self.onnx_bench
shark_module.shark_runner.benchmark_all_csv(
@@ -287,6 +287,9 @@ class SharkModuleTester:
repro_path = os.path.join("reproducers", self.tmp_prefix, "*")
bashCommand = f"gsutil cp -r {repro_path} gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
print(
f"Uploading reproducer {repro_path} to gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
)
process = subprocess.run(bashCommand.split())
def postprocess_outputs(self, golden_out, result):

View File

@@ -5,7 +5,6 @@ microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,False,66M,"nlp;bert-varian
bert-base-uncased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,linalg,False,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,linalg,False,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,linalg,False,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,linalg,False,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
@@ -18,11 +17,9 @@ facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,False,22M
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,linalg,False,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,linalg,False,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"