Compare commits

...

96 Commits

Author SHA1 Message Date
yzhang93
9644e78545 Fix CUDA tuned model annotation (#880) 2023-01-27 11:35:18 -08:00
dymil
c911189ef0 Add note about latest RDNA3 driver support (#881)
Also tweak other wording
2023-01-27 09:39:19 -08:00
Abhishek Varma
1118b4b651 [SD-CLI] Clean up vmfbs if a retry method fails
-- This commit cleans up vmfb files generated as a result of retry method.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-27 21:55:36 +05:30
PhaneeshB
4be75d4418 fix seed values in SD json and filename 2023-01-27 18:40:26 +05:30
Ean Garvey
fb6beae27c Adds pytest-forked dependency to fix pytest memory accumulation issues. (#876)
* Minor improvements to test-models workflow

- cleaned up pytest command line args in Validate Models job scripts.
- Removed -s flag to provide more readable logs
- Changed shark_cache location to within github workspace and removed --update_tank flag from Linux workflows.

* Use pytest-forked for managing pytest memory usage.
2023-01-26 18:20:15 -06:00
yzhang93
fee73b0b63 Add SD model annotation on fly (#869)
* Add SD model annotation on fly

* Move tuned_compile_through_fx to utils

* Fix SD compilation flags
2023-01-26 11:46:36 -08:00
powderluv
9bbffa519e Add an option to respect LLPC env var (#875)
Also add OSX paths
2023-01-25 13:56:55 -08:00
jinchen62
c3a641f0ab Address TODOs for dataset annotator (#872)
- add args usage, pass gs_url by CL flag
- add support for no existing prompts
2023-01-25 09:28:23 -08:00
yzhang93
aafe7c4701 Add more cuda devices to use tuned model (#868) 2023-01-25 06:36:17 -08:00
Abhishek Varma
9a0b082cf8 [SD-CLI] Add batch_size command-line arg + prompt processing
-- This commit adds `batch_size` command-line arg.
-- It also involves replicating the prompt `batch_size` no. of times.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-25 19:21:25 +05:30
powderluv
8265e34a29 Add SHARK SD CLI tool (#870) 2023-01-24 23:14:32 -08:00
powderluv
8ef8ae097f Update to build 469 2023-01-24 22:16:13 -08:00
powderluv
c3d14293c0 Update sample results 2023-01-24 22:14:06 -08:00
powderluv
d55d8be504 Add signing of release builds 2023-01-24 21:32:21 -08:00
powderluv
03543030d3 use pefile 2023-01-24 18:35:51 -08:00
powderluv
fc6b474b92 Add ordlookup to requirements.txt 2023-01-24 18:30:16 -08:00
powderluv
a5db785dd7 checkoutv2 on windows 2023-01-24 18:23:22 -08:00
powderluv
1c1c5cd611 Build Windows nightly on 7950x 2023-01-24 16:21:56 -08:00
Abhishek Varma
6ed02f70ec [SD-CLI] Make using ckpt_loc and hf_model_id easier
-- Currently we require users to specify the base model on which the custom
   model (.ckpt) is tuned on. Even for running a HuggingFace repo-id, we
   require the users to go a tedious way of adding things to variants.json.

-- This commit aims to address the above issues and will be treated as a
   starting point for a series of design changes which makes using SHARK's SD
   easier.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-24 23:03:46 +05:30
Prashant Kumar
cb78cd8ac0 Add the support for the batch size parameter. 2023-01-24 22:33:13 +05:30
Ean Garvey
0c4590b45a Update generate_sharktank.py 2023-01-24 10:18:03 +05:30
jinchen62
d2e2ee6efa Add multiple prompts support for dataset annotator (#862) 2023-01-23 18:40:36 -08:00
powderluv
6a380a0b48 Add more nvidia cards 2023-01-23 17:07:45 -08:00
powderluv
e5d5acbf1f Remove torchvision requirements from web (#860) 2023-01-23 13:48:53 -08:00
powderluv
00e38abbf0 Add 4080 support 2023-01-23 09:56:34 -08:00
Abhishek Varma
e3e4ea5443 Update README.md
-- Make usage of `hf_model_id` clearer.
2023-01-23 23:25:23 +05:30
Prashant Kumar
a3e4ea3228 Remove the dependency of the torchvision. (#858)
Remove the dependency of torchvision library for the conversion
of tensor layout format to what PIL library expects.
2023-01-23 08:49:57 -08:00
powderluv
56f16d6baf Update SD readme 2023-01-23 06:51:54 -08:00
Abhishek Varma
7a55ab900e [SD-CLI] Fix CKPT script + add more variants + update README.md
-- This commit fixes CKPT script to rely on the previous CKPT to Diffusers
   script.
   TODO: Let go of the script once the CKPT is included in next release
         of diffusers.
-- It also adds many variants as part of `variants.json` and updates
   `README.md` to reflect change in default `hf_model_id`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-23 18:34:24 +05:30
Abhishek Varma
137643fe72 [SD-CLI] Update README.md of custom models to include hf_model_id 2023-01-23 11:37:13 +05:30
Anush Elangovan
d6e59c6241 black format comments 2023-01-22 16:34:40 -08:00
powderluv
458eb5d34c detect RX 7900 better 2023-01-22 16:32:27 -08:00
Erkin Alp Güney
8259f08864 Collapsibles for Win10 and Linux users (#851)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-22 09:50:33 -08:00
Prashant Kumar
b3ab0a1843 Add width and height support for the scheduler. 2023-01-22 23:16:50 +05:30
dependabot[bot]
f09f217478 Bump tensorflow from 2.10 to 2.10.1 (#853)
Bumps [tensorflow](https://github.com/tensorflow/tensorflow) from 2.10 to 2.10.1.
- [Release notes](https://github.com/tensorflow/tensorflow/releases)
- [Changelog](https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md)
- [Commits](https://github.com/tensorflow/tensorflow/compare/v2.10.0...v2.10.1)

---
updated-dependencies:
- dependency-name: tensorflow
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-01-22 06:40:17 -08:00
Daniel Garvey
e842c8c19b add main.py testing for sdiff (#836)
Co-authored-by: dan <dan@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-22 01:16:17 -08:00
powderluv
f6c3112d44 Revert "potential fix to pre-load DLL dir for torch-mlir (#848)" (#852)
This reverts commit 6c470d8131.
2023-01-22 00:09:35 -08:00
yzhang93
7059610632 Modify the default for --hf_model_id flag 2023-01-21 11:21:47 +05:30
powderluv
2d272930d9 Update to signed build 455 2023-01-20 16:50:42 -08:00
powderluv
6c470d8131 potential fix to pre-load DLL dir for torch-mlir (#848)
Doesn't regress the main.py script but system already pre-loaded
the DLL so needs more testing.
2023-01-20 14:48:45 -08:00
jinchen62
30b29ce8cd Add readme for dataset annotator (#847) 2023-01-20 01:03:33 -08:00
jinchen62
1a9933002f Add dataset annotation tool (#835) 2023-01-19 16:56:08 -08:00
stanley
c4a9365aa1 [Shark][Training] Refresh SharkTrainer to latest APIs. 2023-01-19 20:30:15 +00:00
Prashant Kumar
9d3af37104 bugfix related to the height width params. 2023-01-20 00:21:44 +05:30
Prashant Kumar
7b3d57cff7 Add height and width as args. 2023-01-19 23:43:29 +05:30
Abhishek Varma
a802270da9 [SD-CLI] Update README.md about variants.json 2023-01-19 22:46:54 +05:30
Abhishek Varma
dd194a8758 [SD-CLI] Reorder loading of opt_params when needed
-- This commit reorders loading of opt_params when `import_mlir` is not used.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 22:02:51 +05:30
Abhishek Varma
6de02de221 [SD-CLI] Make using custom models easier
-- This commit makes using custom models easier using a combination of
   `import_mlir`, `ckpt_loc` and `hf_model_id`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 22:02:36 +05:30
Abhishek Varma
85259750bf [SD-CLI] Fix variants.json mapping
-- This commit fixes variants.json's mapping.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 22:02:36 +05:30
Prashant Kumar
1249f0007d Remove args.variant and args.version with args.custom_model. 2023-01-19 19:55:12 +05:30
Abhishek Varma
db0514d3fa [SD-CLI] Fix get_model_configuration to use max_length
-- This commit fixes `get_model_configuration` to use `max_length`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 19:10:04 +05:30
Abhishek Varma
dce42a7fad [SD-CLI] Fix args.max_length range check
This commit fixes args.max_length range check.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-19 18:26:23 +05:30
Prashant Kumar
ec0b380194 Refactor shark_tank models and custom models.
The custom models shouldn't depend on shark_tank in anyway.
2023-01-19 13:56:11 +05:30
Ean Garvey
7f27b61c98 Update setup_venv.sh to install triton if BENCHMARK=1 2023-01-19 00:26:46 -06:00
Guy Nachshon
f0b3557b02 fix: replace malicious and deleted package (#833) 2023-01-18 13:41:05 -08:00
xzuyn
2a1d1c1001 make jpeg optimized and progressive (#820)
* GUI make jpeg optimized and progressive

* CLI make jpeg optimized and progressive
2023-01-17 16:35:36 -08:00
Abhishek Varma
df7eb80e5b [SD-CLI] Make custom_model take highest priority for generating models if present
-- This commit makes `custom_model` take highest priority for generating models if present.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-17 22:50:58 +05:30
Fraser Humphries
b9d947ce6f style: 🎨 Restore whitespace 2023-01-17 17:45:32 +05:30
Fraser Humphries
e6589d2454 fix: 🏗️ Add demo.css to spec file datas 2023-01-17 17:45:32 +05:30
Fraser Humphries
0f5ac6afcf fix: 🐛 resolve css file path relative to __file__
issues-816
2023-01-17 17:45:32 +05:30
Abhishek Varma
bc1bb1d188 [SD-CLI] Fix vmfb naming + update README.md for custom_model
-- This commit introduces a fix for .vmfb naming to strip away any
   non-alphanumeric characters from `custom_model` path.
-- It also updates the README.md to include the `custom_model` arg.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-17 16:27:54 +05:30
Abhishek Varma
3af2dd10ce [SD-CLI] Add CKPT support to update models irrespective of import_mlir flag
-- This commit adds CKPT support to update models irrespective of `import_mlir` flag.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-01-17 13:24:27 +05:30
yzhang93
dd22c65855 Add CUDA tuned models for SD variants (#814) 2023-01-16 09:38:27 -08:00
PhaneeshB
48137ced19 add png as default format 2023-01-16 18:37:36 +05:30
Phaneesh Barwaria
6eb47c12d1 add multi-run in single execution (#812) 2023-01-13 11:12:43 -08:00
Prashant Kumar
5a1fc6675a This PR adds --import-mlir for f16 tensors without cuda. 2023-01-13 22:19:53 +05:30
Prashant Kumar
6f80825814 Modify import_with_fx to import with dtype=f16. 2023-01-13 22:19:53 +05:30
PhaneeshB
f0dd48ed2a remaining disk space warning 2023-01-13 19:34:05 +05:30
Gaurav Shukla
15e2df0db0 [SD][web] Add a UI textbox to show the output location
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-13 19:33:04 +05:30
Fraser Humphries
4ad0109769 fix: 🐛 Extract demo css string to css file
fix: 🐛 Extract demo css string to css file

issues/807

fix: 🐛 Revert background colors
2023-01-13 16:42:05 +05:30
PhaneeshB
ee0009d4b8 pythonize uname for cpu target triple in windows 2023-01-12 22:39:49 +05:30
PhaneeshB
9d851c3346 small fixes 2023-01-12 22:32:24 +05:30
xzuyn
5d117af8ae Increase JPEG output quality & disable subsampling (#801)
* Increase JPEG output quality & disable subsampling

Increased to JPEG95 from the default JPEG75 which is way too compressed. Output image size is now ~100kb. Previously was ~20kb.

* Increase JPEG output quality & disable subsampling

Add jpeg quality increase on cli

* line length changes

* line length changes
2023-01-11 23:06:11 -08:00
yzhang93
bb41c2d15e Add VAE cuda tuned model (#796) 2023-01-11 14:15:03 -08:00
powderluv
eba138ee4a Revert "Change address for connection test (#785)" (#797)
This reverts commit 187f0fa70c.
2023-01-11 12:01:37 -08:00
Gaurav Shukla
3b2bbb74f8 [SD][web] Add support for saving generated images
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-11 22:47:32 +05:30
fokin33
dbc0f81211 Add simple telegram bot (#787) 2023-01-11 09:20:23 -06:00
mariecwhite
d0b613d22e Enable Torch-Inductor Part 2 2023-01-10 20:15:29 -08:00
Ean Garvey
72f29b67d5 Add Resnet50 fp16 variant to pytests. (#760) 2023-01-10 16:31:11 -08:00
Quinn Dawkins
9570045cc3 Fix tuned model selection for non-vulkan devices (#792) 2023-01-10 19:04:21 -05:00
Phaneesh Barwaria
e4efdb5cbb add json data for each image (#790)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-01-10 13:13:07 -08:00
calcifer11
187f0fa70c Change address for connection test (#785)
Some ISP's (like mine) reserves 1.1.1.1 for internal testing, meaning _internet_connected(); needlessly retries for a minute until it fails even though my connection is fine.
Propose 8.8.8.8 instead as this is also publically available and not normally blocked by ISPs.
2023-01-10 10:51:30 -08:00
Gaurav Shukla
472185c3e4 [SD][web] Fix device key error
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-10 20:51:01 +05:30
Gaurav Shukla
f94a571773 [SD] Update spec file to include model_config.json
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-10 20:38:10 +05:30
mariecwhite
183e447d35 Enable Torch Inductor (#784) 2023-01-10 20:57:58 +11:00
xzuyn
12f844d93a Git pull through argument in setup_venv (#623) 2023-01-09 15:42:13 -08:00
yzhang93
47a119a37f [SD] Add CUDA A100 tuned model (#773) 2023-01-09 15:22:27 -08:00
Gaurav Shukla
ee56559b9a [SD][web] Add a json file for model configuration
This cleans model_wrappers.py file.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-10 00:05:46 +05:30
Gaurav Shukla
00e594deea [SD][web] Add version number in performance details
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-09 21:32:34 +05:30
George Petterson
6ad9b213b9 Add GCN4
(cherry picked from commit 3be072b3c09c9b38bc2d79ad6e6900eefee49a1c)
2023-01-09 21:09:50 +05:30
PhaneeshB
e4375e8195 Add support for vulkan target env 2023-01-09 21:09:50 +05:30
mariecwhite
487bf8e29b Enable TF32 in Torch if specified (#768) 2023-01-09 06:48:57 -08:00
Prashant Kumar
fea1694e74 Delete the cached objects explicitly. 2023-01-06 23:04:52 +05:30
Prashant Kumar
4102c124a9 Add the shark upscaler model. (#759) 2023-01-05 14:07:20 -08:00
yzhang93
135bad3280 [SD] Update v1.4 tuned model (#758) 2023-01-05 11:04:30 -08:00
Gaurav Shukla
b604f36881 [SD][web] Add flags for global URL and server port
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-01-05 15:30:30 +05:30
66 changed files with 4052 additions and 873 deletions

View File

@@ -10,14 +10,14 @@ on:
jobs:
windows-build:
runs-on: windows-latest
runs-on: 7950X
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
@@ -52,6 +52,10 @@ jobs:
./setup_venv.ps1
pyinstaller web/shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\shark\examples\shark_inference\stable_diffusion\shark_sd_cli.spec
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
# GHA windows VM OOMs so disable for now

View File

@@ -100,9 +100,9 @@ jobs:
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cpu --update_tank
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="${GITHUB_WORKSPACE}/shark_tmp/shark_cache" -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -112,9 +112,10 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cuda --update_tank
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="${GITHUB_WORKSPACE}/shark_tmp/shark_cache" -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
sh build_tools/stable_diff_main_test.sh
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
@@ -125,7 +126,7 @@ jobs:
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest -s --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan --update_tank
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os != 'MacStudio'
@@ -133,4 +134,4 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="${GITHUB_WORKSPACE}/shark_tmp/shark_cache" -k vulkan

View File

@@ -83,19 +83,15 @@ python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=f
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
The output on a 6900XT would like:
The output on a 7900XTX would like:
```shell
44it [00:08, 5.14it/s]i = 44 t = 120 (191ms)
45it [00:08, 5.15it/s]i = 45 t = 100 (191ms)
46it [00:08, 5.16it/s]i = 46 t = 80 (191ms)
47it [00:09, 5.16it/s]i = 47 t = 60 (193ms)
48it [00:09, 5.15it/s]i = 48 t = 40 (195ms)
49it [00:09, 5.12it/s]i = 49 t = 20 (196ms)
50it [00:09, 5.14it/s]
Average step time: 192.8154182434082ms/it
Total image generation runtime (s): 10.390909433364868
(shark.venv) PS C:\g\shark>
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
```
Here are some samples generated:

View File

@@ -0,0 +1,41 @@
import argparse
import torchvision
import numpy as np
import requests
import shutil
import os
import subprocess
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--newfile")
parser.add_argument(
"-g",
"--golden_url",
default="https://storage.googleapis.com/shark_tank/testdata/cyberpunk_fores_42_0_230119_021148.png",
)
def get_image(url, local_filename):
res = requests.get(url, stream=True)
if res.status_code == 200:
with open(local_filename, "wb") as f:
shutil.copyfileobj(res.raw, f)
return torchvision.io.read_image(local_filename).numpy()
if __name__ == "__main__":
args = parser.parse_args()
new = torchvision.io.read_image(args.newfile).numpy() / 255.0
tempfile_name = os.path.join(os.getcwd(), "golden.png")
golden = get_image(args.golden_url, tempfile_name) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if not mean < 0.2:
subprocess.run(
["gsutil", "cp", args.newfile, "gs://shark_tank/testdata/builder/"]
)
raise SystemExit("new and golden not close")
else:
print("SUCCESS")

View File

@@ -1,5 +1,5 @@
#!/bin/bash
IMPORTER=1 ./setup_venv.sh
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python generate_sharktank.py --upload=False --ci_tank_dir=True

View File

@@ -0,0 +1,6 @@
rm -rf ./test_images
mkdir test_images
python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned
python build_tools/image_comparison.py -n ./test_images/*.png
exit $?

27
dataset/README.md Normal file
View File

@@ -0,0 +1,27 @@
# Dataset annotation tool
SHARK annotator for adding or modifying prompts of dataset images
## Set up
Activate SHARK Python virtual environment and install additional packages
```shell
source ../shark.venv/bin/activate
pip install -r requirements.txt
```
## Run annotator
```shell
python annotation_tool.py
```
<img width="1280" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214521137-7ef6ae10-7cd8-46e6-b270-b6c0445157f1.png">
* Select a dataset from `Dataset` dropdown list
* Select an image from `Image` dropdown list
* Image and the existing prompt will be loaded
* Select a prompt from `Prompt` dropdown list to modify or "Add new" to add a prompt
* Click `Save` to save changes, click `Delete` to delete prompt
* Click `Back` or `Next` to switch image, you could also select other images from `Image`
* Click `Finish` when finishing annotation or before switching dataset

248
dataset/annotation_tool.py Normal file
View File

@@ -0,0 +1,248 @@
import gradio as gr
import json
import jsonlines
import os
from args import args
from pathlib import Path
from PIL import Image
from utils import get_datasets
shark_root = Path(__file__).parent.parent
demo_css = shark_root.joinpath("web/demo.css").resolve()
nodlogo_loc = shark_root.joinpath(
"web/models/stable_diffusion/logos/nod-logo.png"
)
with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
prompt_data = dict()
with gr.Row(elem_id="ui_body"):
# TODO: add multiselect dataset, there is a gradio version conflict
dataset = gr.Dropdown(label="Dataset", choices=datasets)
image_name = gr.Dropdown(label="Image", choices=[])
with gr.Row(elem_id="ui_body"):
# TODO: add ability to search image by typing
with gr.Column(scale=1, min_width=600):
image = gr.Image(type="filepath").style(height=512)
with gr.Column(scale=1, min_width=600):
prompts = gr.Dropdown(
label="Prompts",
choices=[],
)
prompt = gr.Textbox(
label="Editor",
lines=3,
)
with gr.Row():
save = gr.Button("Save")
delete = gr.Button("Delete")
with gr.Row():
back_image = gr.Button("Back")
next_image = gr.Button("Next")
finish = gr.Button("Finish")
def filter_datasets(dataset):
if dataset is None:
return gr.Dropdown.update(value=None, choices=[])
# create the dataset dir if doesn't exist and download prompt file
dataset_path = str(shark_root) + "/dataset/" + dataset
if not os.path.exists(dataset_path):
os.mkdir(dataset_path)
# read prompt jsonlines file
prompt_data.clear()
if dataset in ds_w_prompts:
prompt_gs_path = args.gs_url + "/" + dataset + "/metadata.jsonl"
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
for line in reader.iter(type=dict, skip_invalid=True):
prompt_data[line["file_name"]] = (
[line["text"]]
if type(line["text"]) is str
else line["text"]
)
return gr.Dropdown.update(choices=images[dataset])
dataset.change(fn=filter_datasets, inputs=dataset, outputs=image_name)
def display_image(dataset, image_name):
if dataset is None or image_name is None:
return gr.Image.update(value=None), gr.Dropdown.update(value=None)
# download and load the image
img_gs_path = args.gs_url + "/" + dataset + "/" + image_name
img_sub_path = "/".join(image_name.split("/")[:-1])
img_dst_path = (
str(shark_root) + "/dataset/" + dataset + "/" + img_sub_path + "/"
)
if not os.path.exists(img_dst_path):
os.mkdir(img_dst_path)
os.system(f'gsutil cp "{img_gs_path}" "{img_dst_path}"')
img = Image.open(img_dst_path + image_name.split("/")[-1])
if image_name not in prompt_data.keys():
prompt_data[image_name] = []
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Image.update(value=img), gr.Dropdown.update(
choices=prompt_choices
)
image_name.change(
fn=display_image,
inputs=[dataset, image_name],
outputs=[image, prompts],
)
def edit_prompt(prompts):
if prompts == "Add new":
return gr.Textbox.update(value=None)
return gr.Textbox.update(value=prompts)
prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt)
def save_prompt(dataset, image_name, prompts, prompt):
if (
dataset is None
or image_name is None
or prompts is None
or prompt is None
):
return
if prompts == "Add new":
prompt_data[image_name].append(prompt)
else:
idx = prompt_data[image_name].index(prompts)
prompt_data[image_name][idx] = prompt
prompt_path = (
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
)
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
if not value:
continue
v = value if len(value) > 1 else value[0]
f.write(json.dumps({"file_name": key, "text": v}))
f.write("\n")
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Dropdown.update(choices=prompt_choices, value=None)
save.click(
fn=save_prompt,
inputs=[dataset, image_name, prompts, prompt],
outputs=prompts,
)
def delete_prompt(dataset, image_name, prompts):
if dataset is None or image_name is None or prompts is None:
return
if prompts == "Add new":
return
prompt_data[image_name].remove(prompts)
prompt_path = (
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
)
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
if not value:
continue
v = value if len(value) > 1 else value[0]
f.write(json.dumps({"file_name": key, "text": v}))
f.write("\n")
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Dropdown.update(choices=prompt_choices, value=None)
delete.click(
fn=delete_prompt,
inputs=[dataset, image_name, prompts],
outputs=prompts,
)
def get_back_image(dataset, image_name):
if dataset is None or image_name is None:
return
# remove local image
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
os.system(f'rm "{img_path}"')
# get the index for the back image
idx = images[dataset].index(image_name)
if idx == 0:
return gr.Dropdown.update(value=None)
return gr.Dropdown.update(value=images[dataset][idx - 1])
back_image.click(
fn=get_back_image, inputs=[dataset, image_name], outputs=image_name
)
def get_next_image(dataset, image_name):
if dataset is None or image_name is None:
return
# remove local image
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
os.system(f'rm "{img_path}"')
# get the index for the next image
idx = images[dataset].index(image_name)
if idx == len(images[dataset]) - 1:
return gr.Dropdown.update(value=None)
return gr.Dropdown.update(value=images[dataset][idx + 1])
next_image.click(
fn=get_next_image, inputs=[dataset, image_name], outputs=image_name
)
def finish_annotation(dataset):
if dataset is None:
return
# upload prompt and remove local data
dataset_path = str(shark_root) + "/dataset/" + dataset
dataset_gs_path = args.gs_url + "/" + dataset + "/"
os.system(
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
)
os.system(f'rm -rf "{dataset_path}"')
return gr.Dropdown.update(value=None)
finish.click(fn=finish_annotation, inputs=dataset, outputs=dataset)
if __name__ == "__main__":
shark_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

34
dataset/args.py Normal file
View File

@@ -0,0 +1,34 @@
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Dataset Annotator flags
##############################################################################
p.add_argument(
"--gs_url",
type=str,
required=True,
help="URL to datasets in GS bucket",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
##############################################################################
args = p.parse_args()

3
dataset/requirements.txt Normal file
View File

@@ -0,0 +1,3 @@
# SHARK Annotator
gradio==3.15.0
jsonlines

29
dataset/utils.py Normal file
View File

@@ -0,0 +1,29 @@
from google.cloud import storage
def get_datasets(gs_url):
datasets = set()
images = dict()
ds_w_prompts = []
storage_client = storage.Client()
bucket_name = gs_url.split("/")[2]
source_blob_name = "/".join(gs_url.split("/")[3:])
blobs = storage_client.list_blobs(bucket_name, prefix=source_blob_name)
for blob in blobs:
dataset_name = blob.name.split("/")[1]
if dataset_name == "":
continue
datasets.add(dataset_name)
if dataset_name not in images.keys():
images[dataset_name] = []
# check if image or jsonl
file_sub_path = "/".join(blob.name.split("/")[2:])
if "/" in file_sub_path:
images[dataset_name] += [file_sub_path]
elif "metadata.jsonl" in file_sub_path:
ds_w_prompts.append(dataset_name)
return list(datasets), images, ds_w_prompts

View File

@@ -14,22 +14,11 @@ import csv
import argparse
from shark.shark_importer import SharkImporter
from shark.parser import shark_args
import tensorflow as tf
import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
def create_hash(file_name):
with open(file_name, "rb") as f:
@@ -41,9 +30,12 @@ def create_hash(file_name):
def save_torch_model(torch_model_list):
from tank.model_utils import get_hf_model
from tank.model_utils import get_vision_model
from tank.model_utils import get_hf_img_cls_model
from tank.model_utils import (
get_hf_model,
get_vision_model,
get_hf_img_cls_model,
get_fp16_model,
)
with open(torch_model_list) as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")
@@ -65,7 +57,8 @@ def save_torch_model(torch_model_list):
model, input, _ = get_hf_model(torch_model_name)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
@@ -106,6 +99,17 @@ def save_tf_model(tf_model_list):
get_keras_model,
get_TFhf_model,
)
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
with open(tf_model_list) as csvfile:
tf_reader = csv.reader(csvfile, delimiter=",")

View File

@@ -3,6 +3,8 @@
numpy==1.22.4
torchvision
pytorch-triton
tabulate
tqdm
@@ -13,7 +15,7 @@ iree-tools-tf
# TensorFlow and JAX.
gin-config
tensorflow==2.10
tensorflow==2.10.1
keras==2.10
#tf-models-nightly
#tensorflow-text-nightly

View File

@@ -10,6 +10,7 @@ google-cloud-storage
# Testing
pytest
pytest-xdist
pytest-forked
Pillow
parameterized
@@ -22,4 +23,5 @@ gradio
altair
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller

View File

@@ -1,3 +1,9 @@
param([string]$arguments)
if ($arguments -eq "--update-src"){
git pull
}
#Write-Host "Installing python"
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow

View File

@@ -123,8 +123,13 @@ fi
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)
TORCH_VERSION=${T_VER:9:17}
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu117
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu117."
else

View File

@@ -4,6 +4,60 @@
Follow setup instructions in the main [README.md](https://github.com/nod-ai/SHARK#readme) for regular usage.
## Using other supported Stable Diffusion variants with SHARK:
Currently we support fine-tuned versions of Stable Diffusion such as:
- [AnythingV3](https://huggingface.co/Linaqruf/anything-v3.0)
- [Analog Diffusion](https://huggingface.co/wavymulder/Analog-Diffusion)
use the flag `--hf_model_id=` to specify the repo-id of the model to be used.
```shell
python .\shark\examples\shark_inference\stable_diffusion\main.py --hf_model_id="Linaqruf/anything-v3.0" --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden"
```
## Run a custom model using a HuggingFace `.ckpt` file:
* Install the following by running :-
```shell
pip install omegaconf safetensors pytorch_lightning
```
* Download a [.ckpt](https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned-fp32.ckpt) file in case you don't have a locally generated `.ckpt` file for StableDiffusion.
* Now pass the above `.ckpt` file to `ckpt_loc` command-line argument using the following :-
```shell
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --ckpt_loc="/path/to/.ckpt/file"
```
* We use a combination of 2 flags to make this feature work : `import_mlir` and `ckpt_loc`.
* In case `ckpt_loc` is NOT specified then a [default](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) HuggingFace repo-id is run via `hf_model_id`. So, you can use `import_mlir` and `hf_model_id` to run HuggingFace's StableDiffusion variants.
* Use custom model `.ckpt` files from [HuggingFace-StableDiffusion](https://huggingface.co/models?other=stable-diffusion) to generate images.
## Running the model for a `batch_size` and for a set of `runs`:
We currently support batch size in the range `[1, 3]`.
You can specify batch size using `batch_size` flag (defaults to `1`) and the number of times you want to run the model using `runs` flag (defaults to `1`).
In total, you'll be able to generate `batch_size * runs` number of images.
- Usage 1: Using the same prompt -
```shell
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --hf_model_id="runwayml/stable-diffusion-v1-5" --batch_size=3
```
The example above generates `3` different images in total with the same prompt `tajmahal, oil on canvas, sunflowers, 4k, uhd`.
- Usage 2: Using different prompts -
```shell
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --hf_model_id="runwayml/stable-diffusion-v1-5" --batch_size=3 -p="batman riding a horse, oil on canvas, 4k, uhd" -p="superman riding a horse, oil on canvas, 4k, uhd"
```
The example above generates `1` image for each different prompt, thus generating `3` images in total.
- Usage 3: Using `runs` -
```shell
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --hf_model_id="runwayml/stable-diffusion-v1-5" --batch_size=2 --runs=3
```
The example above generates `6` different images in total, `2` images for each `runs`.
</details>
<details>
<summary>Debug Commands</summary>
## Debug commands and other advanced usage follows.
```shell
@@ -43,14 +97,4 @@ unzip ~/.local/shark_tank/<your unet>/inputs.npz
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --function_input=@arr_0.npy --function_input=1xf16 --function_input=@arr_2.npy --function_input=@arr_3.npy --function_input=@arr_4.npy
```
## Using other supported Stable Diffusion variants with SHARK:
Currently we support the following fine-tuned versions of Stable Diffusion:
- [AnythingV3](https://huggingface.co/Linaqruf/anything-v3.0)
- [Analog Diffusion](https://huggingface.co/wavymulder/Analog-Diffusion)
use the flag `--variant=` to specify the model to be used.
```shell
python .\shark\examples\shark_inference\stable_diffusion\main.py --variant=anythingv3 --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden"
```
</details>

View File

@@ -1,11 +1,15 @@
import os
import sys
os.environ["AMD_ENABLE_LLPC"] = "1"
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
from transformers import CLIPTextModel, CLIPTokenizer
import torch
from PIL import Image
import torchvision.transforms as T
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
@@ -17,6 +21,11 @@ from tqdm.auto import tqdm
import numpy as np
from random import randint
from stable_args import args
from datetime import datetime as dt
import json
import re
from pathlib import Path
from model_wrappers import SharkifyStableDiffusionModel
# This has to come before importing cache objects
if args.clear_all:
@@ -38,14 +47,12 @@ if args.clear_all:
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
from utils import set_init_device_flags
from utils import set_init_device_flags, disk_space_check, preprocessCKPT
from opt_params import get_unet, get_vae, get_clip
from schedulers import (
SharkEulerDiscreteScheduler,
)
import time
import sys
from shark.iree_utils.compile_utils import dump_isas
# Helper function to profile the vulkan device.
@@ -69,40 +76,59 @@ if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
# Make it as default prompt
if len(args.prompts) == 0:
args.prompts = ["cyberpunk forest by Salvador Dali"]
prompt = args.prompts
neg_prompt = args.negative_prompts
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
if args.version == "v2_1":
height = 768
width = 768
height = args.height
width = args.width
num_inference_steps = args.steps # Number of denoising steps
# Scale for classifier-free guidance
guidance_scale = torch.tensor(args.guidance_scale).to(torch.float32)
# Handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
seed = args.seed
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(
seed
) # Seed generator to create the inital latent noise
batch_size = args.batch_size
prompt = prompt * batch_size if len(prompt) == 1 else prompt
len_of_prompt = len(prompt)
assert (
len_of_prompt == batch_size
), f"no. of prompts ({len_of_prompt}) is not equal to batch_size ({batch_size})"
print("Running StableDiffusion with the following config :-")
print(f"Batch size : {batch_size}")
print(f"Prompts : {prompt}")
print(f"Runs : {args.runs}")
# TODO: Add support for batch_size > 1.
batch_size = len(prompt)
if batch_size != 1:
sys.exit("More than one prompt is not supported yet.")
if batch_size != len(neg_prompt):
sys.exit("prompts and negative prompts must be of same length")
# Try to make neg_prompt equal to batch_size by appending blank strings.
for i in range(batch_size - len(neg_prompt)):
neg_prompt.append("")
set_init_device_flags()
clip = get_clip()
unet = get_unet()
vae = get_vae()
disk_space_check(Path.cwd())
if not args.import_mlir:
from opt_params import get_unet, get_vae, get_clip
clip = get_clip()
unet = get_unet()
vae = get_vae()
else:
if ".ckpt" in args.ckpt_loc:
preprocessCKPT()
mlir_import = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
args.precision,
max_len=args.max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
)
clip, unet, vae = mlir_import()
if args.dump_isa:
dump_isas(args.dispatch_benchmarks_dir)
@@ -112,7 +138,7 @@ if __name__ == "__main__":
subfolder="scheduler",
)
cpu_scheduling = True
if args.version == "v2_1":
if args.hf_model_id == "stabilityai/stable-diffusion-2-1":
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
)
@@ -122,7 +148,7 @@ if __name__ == "__main__":
subfolder="scheduler",
)
if args.version == "v2_1base" and args.variant == "stablediffusion":
if args.hf_model_id == "stabilityai/stable-diffusion-2-1-base":
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
)
@@ -139,116 +165,166 @@ if __name__ == "__main__":
"stabilityai/stable-diffusion-2-1-base",
subfolder="scheduler",
)
for run in range(args.runs):
# Handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
seed = args.seed
if run >= 1 or seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(
seed
) # Seed generator to create the inital latent noise
# create a random initial latent.
latents = torch.randn(
(batch_size, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
# Warmup phase to improve performance.
if args.warmup_count >= 1:
vae_warmup_input = torch.clone(latents).detach().numpy()
clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
for i in range(args.warmup_count):
vae("forward", (vae_warmup_input,))
clip("forward", (clip_warmup_input,))
# create a random initial latent.
latents = torch.randn(
(batch_size, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
if run == 0:
# Warmup phase to improve performance.
if args.warmup_count >= 1:
vae_warmup_input = torch.clone(latents).detach().numpy()
clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
for i in range(args.warmup_count):
vae("forward", (vae_warmup_input,))
clip("forward", (clip_warmup_input,))
start = time.time()
start = time.time()
if run == 0:
text_input = tokenizer(
prompt,
padding="max_length",
max_length=args.max_length,
truncation=True,
return_tensors="pt",
)
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
neg_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input = torch.cat(
[uncond_input.input_ids, text_input.input_ids]
)
text_input = tokenizer(
prompt,
padding="max_length",
max_length=args.max_length,
truncation=True,
return_tensors="pt",
)
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
neg_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
clip_inf_start = time.time()
text_embeddings = clip("forward", (text_input,))
clip_inf_end = time.time()
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
clip_inf_start = time.time()
text_embeddings = clip("forward", (text_input,))
clip_inf_end = time.time()
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
scheduler.set_timesteps(num_inference_steps)
scheduler.is_scale_input_called = True
scheduler.set_timesteps(num_inference_steps)
scheduler.is_scale_input_called = True
latents = latents * scheduler.init_noise_sigma
latents = latents * scheduler.init_noise_sigma
avg_ms = 0
for i, t in tqdm(
enumerate(scheduler.timesteps), disable=args.hide_steps
):
step_start = time.time()
if not args.hide_steps:
print(f"i = {i} t = {t}", end="")
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = scheduler.scale_model_input(latents, t)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
avg_ms = 0
for i, t in tqdm(enumerate(scheduler.timesteps), disable=args.hide_steps):
step_start = time.time()
if not args.hide_steps:
print(f"i = {i} t = {t}", end="")
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = scheduler.scale_model_input(latents, t)
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = scheduler.step(noise_pred, t, latents).prev_sample
else:
latents = scheduler.step(noise_pred, t, latents)
step_time = time.time() - step_start
avg_ms += step_time
step_ms = int((step_time) * 1000)
if not args.hide_steps:
print(f" ({step_ms}ms)")
# scale and decode the image latents with vae
if args.use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = vae("forward", (latents_numpy,))
vae_end = time.time()
end_profiling(profile_device)
if args.use_base_vae:
image = torch.from_numpy(images)
image = (image.detach().cpu() * 255.0).numpy()
images = image.round()
end_time = time.time()
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = scheduler.step(noise_pred, t, latents).prev_sample
avg_ms = 1000 * avg_ms / args.steps
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
vae_inf_time = (vae_end - vae_start) * 1000
total_time = end_time - start
print(f"\nStats for run {run}:")
print(f"Average step time: {avg_ms}ms/it")
print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
print(f"\nTotal image generation time: {total_time}sec")
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
if args.output_dir is not None:
output_path = Path(args.output_dir)
output_path.mkdir(parents=True, exist_ok=True)
else:
latents = scheduler.step(noise_pred, t, latents)
step_time = time.time() - step_start
avg_ms += step_time
step_ms = int((step_time) * 1000)
if not args.hide_steps:
print(f" ({step_ms}ms)")
# scale and decode the image latents with vae
if args.use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = vae("forward", (latents_numpy,))
vae_end = time.time()
end_profiling(profile_device)
if args.use_base_vae:
image = torch.from_numpy(images)
image = (image.detach().cpu() * 255.0).numpy()
images = image.round()
end_time = time.time()
avg_ms = 1000 * avg_ms / args.steps
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
vae_inf_time = (vae_end - vae_start) * 1000
total_time = end_time - start
print(f"\nAverage step time: {avg_ms}ms/it")
print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
print(f"\nTotal image generation time: {total_time}sec")
transform = T.ToPILImage()
pil_images = [
transform(image) for image in torch.from_numpy(images).to(torch.uint8)
]
for i in range(batch_size):
pil_images[i].save(f"{args.prompts[i]}_{i}.jpg")
output_path = Path.cwd()
disk_space_check(output_path, lim=5)
for i in range(batch_size):
json_store = {
"prompt": prompt[i],
"negative prompt": args.negative_prompts[i],
"seed": seed,
"hf_model_id": args.hf_model_id,
"precision": args.precision,
"steps": args.steps,
"guidance_scale": args.guidance_scale,
"scheduler": args.scheduler,
}
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", prompt[i][:15])
img_name = f"{prompt_slice}_{seed}_{run}_{i}_{dt.now().strftime('%y%m%d_%H%M%S')}"
if args.output_img_format == "jpg":
pil_images[i].save(
output_path / f"{img_name}.jpg",
quality=95,
subsampling=0,
optimize=True,
progressive=True,
)
else:
pil_images[i].save(output_path / f"{img_name}.png", "PNG")
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"saving image as png. Supported formats png / jpg"
)
with open(output_path / f"{img_name}.json", "w") as f:
f.write(json.dumps(json_store, indent=4))

View File

@@ -1,285 +1,251 @@
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from utils import compile_through_fx
from stable_args import args
from utils import compile_through_fx, get_opt_flags
from resources import base_models
from collections import defaultdict
import torch
model_config = {
"v2_1": "stabilityai/stable-diffusion-2-1",
"v2_1base": "stabilityai/stable-diffusion-2-1-base",
"v1_4": "CompVis/stable-diffusion-v1-4",
}
# clip has 2 variants of max length 77 or 64.
model_clip_max_length = 64 if args.max_length == 64 else 77
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
model_clip_max_length = 77
elif args.variant == "openjourney":
model_clip_max_length = 64
model_variant = {
"stablediffusion": "SD",
"anythingv3": "Linaqruf/anything-v3.0",
"dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
"openjourney": "prompthero/openjourney",
"analogdiffusion": "wavymulder/Analog-Diffusion",
}
model_input = {
"v2_1": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 96, 96),),
"unet": (
torch.randn(1, 4, 96, 96), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, model_clip_max_length, 1024), # embedding
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v2_1base": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
torch.randn(1, 4, 64, 64), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, model_clip_max_length, 1024), # embedding
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v1_4": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
torch.randn(1, 4, 64, 64),
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, model_clip_max_length, 768),
torch.tensor(1).to(torch.float32),
),
},
}
# revision param for from_pretrained defaults to "main" => fp32
model_revision = {
"stablediffusion": "fp16" if args.precision == "fp16" else "main",
"anythingv3": "diffusers",
"analogdiffusion": "main",
"openjourney": "main",
"dreamlike": "main",
}
import sys
def get_clip_mlir(model_name="clip_text", extra_args=[]):
# These shapes are parameter dependent.
def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape = []
for i in range(len(shape)):
if shape[i] == "max_len":
new_shape.append(max_len)
elif shape[i] == "height":
new_shape.append(height)
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
if "batch_size" in shape[i]:
mul_val = int(shape[i].split("*")[0])
new_shape.append(batch_size * mul_val)
else:
new_shape.append(shape[i])
return new_shape
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
if args.variant == "stablediffusion":
if args.version != "v1_4":
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
)
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
text_encoder = CLIPTextModel.from_pretrained(
model_variant[args.variant],
subfolder="text_encoder",
revision=model_revision[args.variant],
# Get the input info for various models i.e. "unet", "clip", "vae".
def get_input_info(model_info, max_len, width, height, batch_size):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
for inp in model_info[k]:
shape = model_info[k][inp]["shape"]
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, max_len, width, height, batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.model_id = model_id if custom_weights == "" else custom_weights
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
else:
raise ValueError(f"{args.variant} not yet added")
self.use_tuned = use_tuned
# We need a better naming convention for the .vmfbs because despite
# using the custom model variant the .vmfb names remain the same and
# it'll always pick up the compiled .vmfb instead of compiling the
# custom model.
# So, currently, we add `self.model_id` in the `self.model_name` of
# .vmfb file.
# TODO: Have a better way of naming the vmfbs using self.model_name.
import re
class CLIPText(torch.nn.Module):
def __init__(self):
super().__init__()
self.text_encoder = text_encoder
model_name = re.sub(r"\W+", "_", self.model_id)
if model_name[0] == "_":
model_name = model_name[1:]
self.model_name = self.model_name + "_" + model_name
def forward(self, input):
return self.text_encoder(input)[0]
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 384):
sys.exit("width should be greater than 384 and multiple of 8")
if not (height % 8 == 0 and height >= 384):
sys.exit("height should be greater than 384 and multiple of 8")
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
model_input[args.version]["clip"],
model_name=model_name,
extra_args=extra_args,
)
return shark_clip
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
)
self.base_vae = base_vae
def forward(self, input):
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0
return x.round()
def get_base_vae_mlir(model_name="vae", extra_args=[]):
class BaseVaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
revision=model_revision[args.variant],
vae = VaeModel()
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
vae_name = "base_vae" if self.base_vae else "vae"
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
model_name=vae_name + self.model_name,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("vae", precision=self.precision),
)
return shark_vae
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(
self, latent, timestep, text_embedding, guidance_scale
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel()
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet = compile_through_fx(
unet,
inputs,
model_name="unet" + self.model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_unet
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
)
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name="clip" + self.model_name,
extra_args=get_opt_flags("clip", precision="fp32"),
)
return shark_clip
def __call__(self):
from utils import get_vmfb_path_name
from stable_args import args
import traceback, functools, operator, os
model_name = ["clip", "base_vae" if self.base_vae else "vae", "unet"]
vmfb_path = [
get_vmfb_path_name(model + self.model_name)[0]
for model in model_name
]
for model_id in base_models:
self.inputs = get_input_info(
base_models[model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
def forward(self, input):
x = self.vae.decode(input, return_dict=False)[0]
return (x / 2 + 0.5).clamp(0, 1)
vae = BaseVaeModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
raise ValueError(f"{args.variant} not yet added")
shark_vae = compile_through_fx(
vae,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_vae
def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
revision=model_revision[args.variant],
)
def forward(self, input):
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
x = x * 255.0
return x.round()
vae = VaeModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
raise ValueError(f"{args.variant} not yet added")
shark_vae = compile_through_fx(
vae,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_vae
def get_unet_mlir(model_name="unet", extra_args=[]):
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="unet",
revision=model_revision[args.variant],
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, latent, timestep, text_embedding, guidance_scale):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
else:
inputs = model_input[args.version]["unet"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input["v1_4"]["unet"]
]
)
else:
inputs = model_input["v1_4"]["unet"]
else:
raise ValueError(f"{args.variant} is not yet added")
shark_unet = compile_through_fx(
unet,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_unet
try:
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = functools.reduce(
operator.__and__, vmfb_present
)
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(len(vmfb_path)):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
print("Retrying with a different base model configuration")
continue
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please use `enable_stack_trace` and create an issue at https://github.com/nod-ai/SHARK/issues"
)

View File

@@ -1,10 +1,4 @@
import sys
from model_wrappers import (
get_base_vae_mlir,
get_vae_mlir,
get_unet_mlir,
get_clip_mlir,
)
from resources import models_db
from stable_args import args
from utils import get_shark_model
@@ -13,6 +7,18 @@ BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
sys.exit("Only batch size 1 is supported.")
hf_model_variant_map = {
"Linaqruf/anything-v3.0": ["anythingv3", "v2_1base"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"],
"prompthero/openjourney": ["openjourney", "v2_1base"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
}
variant, version = hf_model_variant_map[args.hf_model_id]
def get_params(bucket_key, model_key, model, is_tuned, precision):
iree_flags = []
@@ -33,7 +39,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
]
except KeyError:
raise Exception(
f"{bucket}/{model_key} is not present in the models database"
f"{bucket_key}/{model_key} is not present in the models database"
)
if (
@@ -62,13 +68,16 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
def get_unet():
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
bucket_key = f"{args.variant}/{is_tuned}"
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "unet", is_tuned, args.precision
)
if not args.use_tuned and args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
@@ -76,24 +85,25 @@ def get_vae():
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
is_base = "/base" if args.use_base_vae else ""
bucket_key = f"{args.variant}/{is_tuned}"
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
if not args.use_tuned and args.import_mlir:
if args.use_base_vae:
return get_base_vae_mlir(model_name, iree_flags)
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
bucket_key = f"{args.variant}/untuned"
model_key = f"{args.variant}/{args.version}/clip/fp32/length_{args.max_length}/untuned"
bucket_key = f"{variant}/untuned"
model_key = (
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
)
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "clip", "untuned", "fp32"
)
if args.import_mlir:
return get_clip_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)

View File

@@ -11,21 +11,27 @@ def resource_path(relative_path):
return os.path.join(base_path, relative_path)
prompt_examples = []
prompts_loc = resource_path("resources/prompts.json")
if os.path.exists(prompts_loc):
with open(prompts_loc, encoding="utf-8") as fopen:
prompt_examples = json.load(fopen)
def get_json_file(path):
json_var = []
loc_json = resource_path(path)
if os.path.exists(loc_json):
with open(loc_json, encoding="utf-8") as fopen:
json_var = json.load(fopen)
if not prompt_examples:
print("Unable to fetch prompt examples.")
if not json_var:
print(f"Unable to fetch {path}")
return json_var
models_db = []
models_loc = resource_path("resources/model_db.json")
if os.path.exists(models_loc):
with open(models_loc, encoding="utf-8") as fopen:
models_db = json.load(fopen)
# TODO: This shouldn't be called from here, every time the file imports
# it will run all the global vars.
prompts_examples = get_json_file("resources/prompts.json")
models_db = get_json_file("resources/model_db.json")
if len(models_db) != 3:
sys.exit("Error: Unable to load models database.")
# The base_model contains the input configuration for the different
# models and also helps in providing information for the variants.
base_models = get_json_file("resources/base_model.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")

View File

@@ -0,0 +1,98 @@
{
"stabilityai/stable-diffusion-2-1": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"CompVis/stable-diffusion-v1-4": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
}
}

View File

@@ -0,0 +1,21 @@
[
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]

View File

@@ -2,30 +2,40 @@
{
"stablediffusion/untuned":"gs://shark_tank/stable_diffusion",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_1dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16",
@@ -34,18 +44,22 @@
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32",
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
@@ -101,16 +115,14 @@
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
"--iree-flow-enable-conv-winograd-transform"
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
"--iree-flow-enable-conv-winograd-transform"
"--iree-flow-enable-conv-img2col-transform"
]
}
},

View File

@@ -0,0 +1,101 @@
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32"
],
"specified_compilation_flags": {
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
}
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": ["--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"]
}
},
"fp32": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
}
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
}
}
}

View File

@@ -15,10 +15,19 @@ import torch
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = len(args.prompts)
if len(args.prompts) == 0:
BATCH_SIZE = 1
model_input = {
"euler": {
"latent": torch.randn(1, 4, 64, 64),
"output": torch.randn(1, 4, 64, 64),
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
@@ -84,7 +93,8 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.scaling_model = compile_through_fx(
scaling_model,
(example_latent, example_sigma),
model_name="euler_scale_model_input_" + args.precision,
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
@@ -92,7 +102,8 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.step_model = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
model_name="euler_step_" + args.precision,
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
else:

View File

@@ -1,6 +1,6 @@
import os
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import run_cmd, iree_target_map
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
download_model,
download_public_file,
@@ -8,79 +8,116 @@ from shark.shark_downloader import (
)
from shark.parser import shark_args
from stable_args import args
from opt_params import get_params
from utils import set_init_device_flags
# Downloads the model (Unet or VAE fp16) from shark_tank
set_init_device_flags()
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{args.variant}/untuned"
use_winograd = False
if args.annotation_model == "unet":
if args.version == "v2_1base":
use_winograd = True
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
use_winograd = True
is_base = "/base" if args.use_base_vae else ""
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/untuned{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
device = (
args.device if "://" not in args.device else args.device.split("://")[0]
)
# Downloads the tuned config files from shark_tank
config_bucket = "gs://shark_tank/sd_tuned/configs/"
if use_winograd:
config_name = f"{args.annotation_model}_winograd.json"
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
from opt_params import get_params, version, variant
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{variant}/untuned"
if args.annotation_model == "unet":
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
is_base = "/base" if args.use_base_vae else ""
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
)
return mlir_model, model_name
# Download the tuned config files from shark_tank
def load_winograd_configs():
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
winograd_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
if args.annotation_model == "unet":
if args.variant in ["anythingv3", "analogdiffusion"]:
def load_lower_configs():
from opt_params import version, variant
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_version = version
if variant in ["anythingv3", "analogdiffusion"]:
args.max_length = 77
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}.json"
config_version = "v1_4"
if args.annotation_model == "vae":
args.max_length = 77
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
full_gs_url = config_bucket + config_name
lowering_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading lowering config file from ", lowering_config_dir)
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
# Annotate the model with Winograd attribute on selected conv ops
if use_winograd:
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with create_context() as ctx:
winograd_model = model_annotation(
ctx,
input_contents=mlir_model,
input_contents=input_mlir,
config_path=winograd_config_dir,
search_op="conv",
winograd=use_winograd,
winograd=True,
)
with open(
f"{args.annotation_output}/{model_name}_tuned_torch.mlir", "w"
) as f:
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return winograd_model, out_file_path
# For Unet annotate the model with tuned lowering configs
if args.annotation_model == "unet":
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
if use_winograd:
input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
else:
input_mlir = f"{WORKDIR}/{model_name}_torch/{model_name}_torch.mlir"
dump_after = "iree-flow-pad-linalg-ops"
# Dump IR after padding/img2col/winograd passes
device_spec_args = ""
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
gpu_flags = get_iree_gpu_args()
for flag in gpu_flags:
device_spec_args += flag + " "
elif device == "vulkan":
device_spec_args = (
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
print("Applying tuned configs on", model_name)
run_cmd(
f"iree-compile {input_mlir} "
"--iree-input-type=tm_tensor "
f"--iree-hal-target-backends={iree_target_map(args.device)} "
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
f"--iree-hal-target-backends={iree_target_map(device)} "
f"{device_spec_args}"
"--iree-stream-resource-index-bits=64 "
"--iree-vm-target-index-bits=64 "
"--iree-flow-enable-padding-linalg-ops "
@@ -102,7 +139,53 @@ if args.annotation_model == "unet":
# Remove the intermediate mlir and save the final annotated model
os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
output_path = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
with open(output_path, "w") as f:
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return tuned_model, out_file_path
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model, model_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
model_path, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model, output_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
use_winograd = False
if model_from_tank:
mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
else:
# Just use this function to convert bytecode to string
orig_model, model_path = annotate_with_winograd(
mlir_model, "", model_name
)
mlir_model = model_path
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
print(f"Saved the annotated mlir in {output_path}.")
return tuned_model, output_path
if __name__ == "__main__":
mlir_model, model_name = load_model_from_tank()
sd_model_annotation(mlir_model, model_name, model_from_tank=True)

View File

@@ -0,0 +1,74 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'resources/prompts.json', 'resources'),
( 'resources/model_db.json', 'resources'),
( 'resources/base_model.json', 'resources'),
( 'resources/opt_flags.json', 'resources'),
]
binaries = []
block_cipher = None
a = Analysis(
['main.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core' ],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -15,9 +15,10 @@ p = argparse.ArgumentParser(
##############################################################################
p.add_argument(
"-p",
"--prompts",
nargs="+",
default=["cyberpunk forest by Salvador Dali"],
action="append",
default=[],
help="text of which images to be generated.",
)
@@ -42,6 +43,28 @@ p.add_argument(
help="the seed to use.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `run`.",
)
p.add_argument(
"--height",
type=int,
default=512,
help="the height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
help="the width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
@@ -64,13 +87,6 @@ p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--version",
type=str,
default="v2_1base",
help="Specify version of stable diffusion model",
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
@@ -110,12 +126,6 @@ p.add_argument(
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--variant",
default="stablediffusion",
help="We now support multiple vairants of SD finetuned for different dataset. you can use the following anythingv3, ...", # TODO add more once supported
)
p.add_argument(
"--scheduler",
type=str,
@@ -123,6 +133,48 @@ p.add_argument(
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
)
p.add_argument(
"--runs",
type=int,
default=1,
help="number of images to be generated with random seeds in single execution",
)
p.add_argument(
"--ckpt_loc",
type=str,
default="",
help="Path to SD's .ckpt file.",
)
p.add_argument(
"--hf_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--enable_stack_trace",
default=False,
action=argparse.BooleanOptionalAction,
help="Enable showing the stack trace when retrying the base model configuration",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################

View File

@@ -12,22 +12,23 @@ If it works well for you, please "star" the following GitHub projects... this is
*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*
First, download this special driver in a folder of your choice. We recommend you keep that driver around since you may need to re-install it later, if Windows Update decides to overwrite it:
First, for RDNA2 users, download this special driver in a folder of your choice. We recommend you keep the installation files around, since you may need to re-install it later, if Windows Update decides to overwrite it:
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree
For RDNA3, the latest driver 23.1.2 supports MLIR/IREE as well: https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-1-2-kb
KNOWN ISSUES with this special AMD driver:
* `Windows Update` may (depending how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work, then a few days later, it slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix this problem, please check the installed driver's version, and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
* Some people using this special driver experience mouse pointer accuracy issues, if you use a larger-than-default mouse pointer. The clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".
* `Windows Update` may (depending how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work, then a few days later, it slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix this problem, please check the installed driver version, and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
* Some people using this special driver experience mouse pointer accuracy issues, especially if using a larger-than-default mouse pointer. The clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".
## Installation
Download the latest Windows SHARK SD binary [423 here](https://github.com/nod-ai/SHARK/releases/download/20230101.423/shark_sd_20230101_423.exe) in a folder of your choice. If you want nighly builds you can look for them in the github releases page. Please read carefully the following notes:
Download the latest Windows SHARK SD binary [469 here](https://github.com/nod-ai/SHARK/releases/download/20230124.469/shark_sd_20230124_469.exe) in a folder of your choice. If you want nighly builds, you can look for them on the GitHub releases page.
Notes:
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR, that can get outdated if you run multiple EXE from the same folder. You can use `--clean_all` flag once to clean all the old files.
* Your browser may warn you about downloading an .exe file
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR which can be outdated if you run a new EXE from the same folder. You can use `--clean_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you:
* clear all the local artifacts with `--clean_all` OR
* clear all the local artifacts with `--clear_all` OR
* clear the Vulkan shader cache: For Windows users this can be done by clearing the contents of `C:\Users\%username%\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
* clear the `huggingface` cache. In Windows, this is `C:\Users\%username%\.cache\huggingface`.
@@ -59,9 +60,9 @@ Here are some samples generated:
<summary>Advanced Installation </summary>
## Setup your Python VirtualEnvironment and Dependencies
### Windows 10/11 Users
## Setup your Python Virtual Environment and Dependencies
<details>
<summary> Windows 10/11 Users </summary>
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
@@ -78,8 +79,10 @@ git clone https://github.com/nod-ai/SHARK.git
cd SHARK
./setup_venv.ps1 #You can re-run this script to get the latest version
```
</details>
### Linux
<details>
<summary>Linux</summary>
```shell
git clone https://github.com/nod-ai/SHARK.git
@@ -87,53 +90,65 @@ cd SHARK
./setup_venv.sh
source shark.venv/bin/activate
```
</details>
### Run Stable Diffusion on your device - WebUI
#### Windows 10/11 Users
<details>
<summary>Windows 10/11 Users</summary>
```powershell
(shark.venv) PS C:\Users\nod\SHARK> cd web
(shark.venv) PS C:\Users\nod\SHARK\web> python index.py
```
#### Linux Users
</details>
<details>
<summary>Linux Users</summary>
```shell
(shark.venv) > cd web
(shark.venv) > python index.py
```
</details>
### Run Stable Diffusion on your device - Commandline
#### Windows 10/11 Users
<details>
<summary>Windows 10/11 Users</summary>
```powershell
(shark.venv) PS C:\g\shark> python .\shark\examples\shark_inference\stable_diffusion\main.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
</details>
#### Linux
<details>
<summary>Linux</summary>
```shell
python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
</details>
The output on a 6900XT would like:
The output on a 7900XTX would like:
```shell
44it [00:08, 5.14it/s]i = 44 t = 120 (191ms)
45it [00:08, 5.15it/s]i = 45 t = 100 (191ms)
46it [00:08, 5.16it/s]i = 46 t = 80 (191ms)
47it [00:09, 5.16it/s]i = 47 t = 60 (193ms)
48it [00:09, 5.15it/s]i = 48 t = 40 (195ms)
49it [00:09, 5.12it/s]i = 49 t = 20 (196ms)
50it [00:09, 5.14it/s]
Average step time: 192.8154182434082ms/it
Total image generation runtime (s): 10.390909433364868
(shark.venv) PS C:\g\shark>
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
```
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
</details>
<details>
<details>
<summary>Discord link</summary>
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
</details>

View File

@@ -0,0 +1,15 @@
You need to pre-create your bot (https://core.telegram.org/bots#how-do-i-create-a-bot)
Then create in the directory web file .env
In it the record:
TG_TOKEN="your_token"
specifying your bot's token from previous step.
Then run telegram_bot.py with the same parameters that you use when running index.py, for example:
python telegram_bot.py --max_length=77 --vulkan_large_heap_block_size=0 --use_base_vae --local_tank_cache h:\shark\TEMP
Bot commands:
/select_model
/select_scheduler
/set_steps "integer number of steps"
/set_guidance_scale "integer number"
/set_negative_prompt "negative text"
Any other text triggers the creation of an image based on it.

View File

@@ -1,4 +1,5 @@
import os
import gc
import torch
from shark.shark_inference import SharkInference
from stable_args import args
@@ -7,17 +8,26 @@ from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from resources import opt_flags
from sd_annotation import sd_model_annotation
import sys
def get_vmfb_path_name(model_name):
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
return [vmfb_path, extended_name]
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
[vmfb_path, extended_name] = get_vmfb_path_name(model_name)
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
@@ -46,6 +56,8 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
@@ -59,9 +71,41 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
# Converts the torch-module into a shark_module.
def compile_through_fx(model, inputs, model_name, extra_args=[]):
def compile_through_fx(
model,
inputs,
model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
extra_args=[],
):
mlir_module, func_name = import_with_fx(model, inputs)
from shark.parser import shark_args
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
)
if use_tuned:
model_name = model_name + "_tuned"
tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
if not os.path.exists(tuned_model_path):
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
tuned_model, tuned_model_path = sd_model_annotation(
mlir_module, model_name
)
del mlir_module, tuned_model
gc.collect()
with open(tuned_model_path, "rb") as f:
mlir_module = f.read()
f.close()
shark_module = SharkInference(
mlir_module,
@@ -180,27 +224,51 @@ def set_init_device_flags():
args.device = "cpu"
# set max_length based on availability.
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
if args.hf_model_id in [
"Linaqruf/anything-v3.0",
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]:
args.max_length = 77
elif args.variant == "openjourney":
elif args.hf_model_id == "prompthero/openjourney":
args.max_length = 64
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
if (
args.variant in ["openjourney", "dreamlike"]
args.hf_model_id
in ["prompthero/openjourney", "dreamlike-art/dreamlike-diffusion-1.0"]
or args.precision != "fp16"
or "vulkan" not in args.device
or "rdna3" not in args.iree_vulkan_target_triple
or args.height != 512
or args.width != 512
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
print("Tuned models are currently not supported for this setting.")
elif args.use_base_vae and args.variant != "stablediffusion":
elif (
"vulkan" in args.device
and "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in [
"sm_80",
"sm_84",
"sm_86",
"sm_89",
]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.use_tuned = False
print("Tuned models are currently not supported for this setting.")
if args.use_tuned:
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
print(f"Using {args.device} tuned models for stablediffusion/fp16.")
else:
print("Tuned models are currently not supported for this setting.")
# Utility to get list of devices available.
@@ -229,3 +297,87 @@ def get_available_devices():
available_devices.extend(cuda_devices)
available_devices.append("cpu")
return available_devices
def disk_space_check(path, lim=20):
from shutil import disk_usage
du = disk_usage(path)
free = du.free / (1024 * 1024 * 1024)
if free <= lim:
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
def get_opt_flags(model, precision="fp16"):
iree_flags = []
is_tuned = "tuned" if args.use_tuned else "untuned"
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
"default_compilation_flags"
]
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
if (
device
not in opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
]
):
device = "default_device"
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
return iree_flags
def preprocessCKPT():
from pathlib import Path
path = Path(args.ckpt_loc)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
print(
"Created directory : ",
diffusers_directory_name,
" at -> ",
diffusers_path,
)
path_to_diffusers = complete_path_to_diffusers.as_posix()
# TODO: Use the SD to Diffusers CKPT pipeline once it's included in the release.
sd_to_diffusers = os.path.join(os.getcwd(), "sd_to_diffusers.py")
if not os.path.isfile(sd_to_diffusers):
url = "https://raw.githubusercontent.com/huggingface/diffusers/8a3f0c1f7178f4a3d5a5b21ae8c2906f473e240d/scripts/convert_original_stable_diffusion_to_diffusers.py"
import requests
req = requests.get(url)
open(sd_to_diffusers, "wb").write(req.content)
print("Downloaded SD to Diffusers converter")
else:
print("SD to Diffusers converter already exists")
os.system(
"python "
+ sd_to_diffusers
+ " --checkpoint_path="
+ args.ckpt_loc
+ " --dump_path="
+ path_to_diffusers
)
args.ckpt_loc = path_to_diffusers
print("Custom model path is : ", args.ckpt_loc)

View File

@@ -0,0 +1,21 @@
import requests
from PIL import Image
from io import BytesIO
from pipeline_shark_stable_diffusion_upscale import (
SharkStableDiffusionUpscalePipeline,
)
import torch
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipeline = SharkStableDiffusionUpscalePipeline(model_id)
# let's download an image
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
response = requests.get(url)
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
low_res_img = low_res_img.resize((128, 128))
prompt = "a white cat"
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image.save("upsampled_cat.png")

View File

@@ -0,0 +1,99 @@
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from utils import compile_through_fx
import torch
model_id = "stabilityai/stable-diffusion-x4-upscaler"
model_input = {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 128, 128),),
"unet": (
torch.randn(2, 7, 128, 128), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 1024), # embedding
torch.randn(2).to(torch.int64), # noise_level
),
}
def get_clip_mlir(model_name="clip_text", extra_args=[]):
text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
)
class CLIPText(torch.nn.Module):
def __init__(self):
super().__init__()
self.text_encoder = text_encoder
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
model_input["clip"],
model_name=model_name,
extra_args=extra_args,
)
return shark_clip
def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
)
def forward(self, input):
x = self.vae.decode(input, return_dict=False)[0]
return x
vae = VaeModel()
shark_vae = compile_through_fx(
vae,
model_input["vae"],
model_name=model_name,
extra_args=extra_args,
)
return shark_vae
def get_unet_mlir(model_name="unet", extra_args=[]):
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, latent, timestep, text_embedding, noise_level):
unet_out = self.unet.forward(
latent,
timestep,
text_embedding,
noise_level,
return_dict=False,
)[0]
return unet_out
unet = UnetModel()
f16_input_mask = (True, True, True, False)
shark_unet = compile_through_fx(
unet,
model_input["unet"],
model_name=model_name,
is_f16=True,
f16_input_mask=f16_input_mask,
extra_args=extra_args,
)
return shark_unet

View File

@@ -0,0 +1,53 @@
import sys
from model_wrappers import (
get_vae_mlir,
get_unet_mlir,
get_clip_mlir,
)
from upscaler_args import args
from utils import get_shark_model
BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
sys.exit("Only batch size 1 is supported.")
unet_flag = [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
vae_flag = [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
clip_flag = [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops",
]
bucket = "gs://shark_tank/stable_diffusion/"
def get_unet():
model_name = "upscaler_unet"
if args.import_mlir:
return get_unet_mlir(model_name, unet_flag)
return get_shark_model(bucket, model_name, unet_flag)
def get_vae():
model_name = "upscaler_vae"
if args.import_mlir:
return get_vae_mlir(model_name, vae_flag)
return get_shark_model(bucket, model_name, vae_flag)
def get_clip():
model_name = "upscaler_clip"
if args.import_mlir:
return get_clip_mlir(model_name, clip_flag)
return get_shark_model(bucket, model_name, clip_flag)

View File

@@ -0,0 +1,490 @@
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import PIL
from PIL import Image
from diffusers.utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import (
DDIMScheduler,
DDPMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers import logging
from diffusers.pipeline_utils import ImagePipelineOutput
from opt_params import get_unet, get_vae, get_clip
from tqdm.auto import tqdm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
def shark_run_wrapper(model, *args):
np_inputs = tuple([x.detach().numpy() for x in args])
outputs = model("forward", np_inputs)
return torch.from_numpy(outputs)
class SharkStableDiffusionUpscalePipeline:
def __init__(
self,
model_id,
):
self.tokenizer = CLIPTokenizer.from_pretrained(
model_id, subfolder="tokenizer"
)
self.low_res_scheduler = DDPMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
self.scheduler = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
self.vae = get_vae()
self.unet = get_unet()
self.text_encoder = get_clip()
self.max_noise_level = (350,)
self._execution_device = "cpu"
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[
-1
] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
# if (
# hasattr(self.text_encoder.config, "use_attention_mask")
# and self.text_encoder.config.use_attention_mask
# ):
# attention_mask = text_inputs.attention_mask.to(device)
# else:
# attention_mask = None
text_embeddings = shark_run_wrapper(
self.text_encoder, text_input_ids.to(device)
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# if (
# hasattr(self.text_encoder.config, "use_attention_mask")
# and self.text_encoder.config.use_attention_mask
# ):
# attention_mask = uncond_input.attention_mask.to(device)
# else:
# attention_mask = None
uncond_embeddings = shark_run_wrapper(
self.text_encoder,
uncond_input.input_ids.to(device),
)
uncond_embeddings = uncond_embeddings
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(
1, num_images_per_prompt, 1
)
uncond_embeddings = uncond_embeddings.view(
batch_size * num_images_per_prompt, seq_len, -1
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
def decode_latents(self, latents):
latents = 1 / 0.08333 * latents
image = shark_run_wrapper(self.vae, latents)
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def check_inputs(self, prompt, image, noise_level, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
)
# verify batch size of prompt and image are same if image is a list or tensor
if isinstance(image, list) or isinstance(image, torch.Tensor):
if isinstance(prompt, str):
batch_size = 1
else:
batch_size = len(prompt)
if isinstance(image, list):
image_batch_size = len(image)
else:
image_batch_size = image.shape[0]
if batch_size != image_batch_size:
raise ValueError(
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
" Please make sure that passed `prompt` matches the batch size of `image`."
)
@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [
Image.fromarray(image.squeeze(), mode="L") for image in images
]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
if device == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(
shape, generator=generator, device="cpu", dtype=dtype
).to(device)
else:
latents = torch.randn(
shape, generator=generator, device=device, dtype=dtype
)
else:
if latents.shape != shape:
raise ValueError(
f"Unexpected latents shape, got {latents.shape}, expected {shape}"
)
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[
torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]
],
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[
Union[torch.Generator, List[torch.Generator]]
] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[
Callable[[int, int, torch.FloatTensor], None]
] = None,
callback_steps: Optional[int] = 1,
):
# 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
# 4. Preprocess image
image = preprocess(image)
image = image.to(dtype=text_embeddings.dtype, device=device)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Add noise to image
noise_level = torch.tensor(
[noise_level], dtype=torch.long, device=device
)
if device == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(
image.shape,
generator=generator,
device="cpu",
dtype=text_embeddings.dtype,
).to(device)
else:
noise = torch.randn(
image.shape,
generator=generator,
device=device,
dtype=text_embeddings.dtype,
)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
noise_level = torch.cat([noise_level] * image.shape[0])
# 6. Prepare latent variables
height, width = image.shape[2:]
# num_channels_latents = self.vae.config.latent_channels
num_channels_latents = 4
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 7. Check that sizes of image and latents match
num_channels_image = image.shape[1]
# if (
# num_channels_latents + num_channels_image
# != self.unet.config.in_channels
# ):
# raise ValueError(
# f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
# f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
# f" `num_channels_image`: {num_channels_image} "
# f" = {num_channels_latents+num_channels_image}. Please verify the config of"
# " `pipeline.unet` or your `image` input."
# )
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 9. Denoising loop
num_warmup_steps = (
len(timesteps) - num_inference_steps * self.scheduler.order
)
for i, t in tqdm(enumerate(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2)
if do_classifier_free_guidance
else latents
)
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = torch.cat([latent_model_input, image], dim=1)
timestep = torch.tensor([t]).to(torch.float32)
# predict the noise residual
noise_pred = shark_run_wrapper(
self.unet,
latent_model_input.half(),
timestep,
text_embeddings.half(),
noise_level,
)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
# # call the callback, if provided
# if i == len(timesteps) - 1 or (
# (i + 1) > num_warmup_steps
# and (i + 1) % self.scheduler.order == 0
# ):
# progress_bar.update()
# if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
# 10. Post-processing
# make sure the VAE is in float32 mode, as it overflows in float16
# self.vae.to(dtype=torch.float32)
image = self.decode_latents(latents.float())
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@@ -0,0 +1,111 @@
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Stable Diffusion Params
##############################################################################
p.add_argument(
"--prompts",
nargs="+",
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
)
p.add_argument(
"--negative-prompts",
nargs="+",
default=[""],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
default=42,
help="the seed to use.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
)
##############################################################################
### Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree-vulkan-target-triple",
type=str,
default="",
help="Specify target triple for vulkan",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="4147483648",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
)
args = p.parse_args()

View File

@@ -0,0 +1,234 @@
import os
import torch
from shark.shark_inference import SharkInference
from upscaler_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), extended_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_model
from shark.parser import shark_args
# Set local shark_tank cache directory.
# shark_args.local_tank_cache = args.local_tank_cache
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
)
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into a shark_module.
def compile_through_fx(
model, inputs, model_name, is_f16=False, f16_input_mask=None, extra_args=[]
):
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def set_init_device_flags():
if "vulkan" in args.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "cpu" in args.device:
args.device = "cpu"
# set max_length based on availability.
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
args.max_length = 77
elif args.variant == "openjourney":
args.max_length = 64
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
if (
args.variant in ["openjourney", "dreamlike"]
or args.precision != "fp16"
or "vulkan" not in args.device
or "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
print("Tuned models are currently not supported for this setting.")
elif args.use_base_vae and args.variant != "stablediffusion":
args.use_tuned = False
print("Tuned models are currently not supported for this setting.")
if args.use_tuned:
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
# Utility to get list of devices available.
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{driver_name}://{i} => {device['name']}")
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
return available_devices

View File

@@ -1,7 +1,7 @@
import torch
from torch.nn.utils import _stateless
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_runner import SharkTrainer
from shark.shark_trainer import SharkTrainer
class MiniLMSequenceClassification(torch.nn.Module):
@@ -42,6 +42,7 @@ def forward(params, buffers, args):
return params, buffers
shark_module = SharkTrainer(mod, inp, custom_inference_fn=forward)
shark_module = SharkTrainer(mod, inp)
shark_module.compile(forward)
print(shark_module.forward())
print(shark_module.train())

View File

@@ -15,6 +15,7 @@
# All the iree_cpu related functionalities go here.
import subprocess
import platform
def get_cpu_count():
@@ -29,25 +30,16 @@ def get_cpu_count():
# Get the default cpu args.
def get_iree_cpu_args():
find_triple_cmd = "uname -s -m"
os_name, proc_name = (
subprocess.run(
find_triple_cmd, shell=True, stdout=subprocess.PIPE, check=True
)
.stdout.decode("utf-8")
.split()
)
uname = platform.uname()
os_name, proc_name = uname.system, uname.machine
if os_name == "Darwin":
find_kernel_version_cmd = "uname -r"
kernel_version = subprocess.run(
find_kernel_version_cmd,
shell=True,
stdout=subprocess.PIPE,
check=True,
).stdout.decode("utf-8")
kernel_version = uname.release
target_triple = f"{proc_name}-apple-darwin{kernel_version}"
elif os_name == "Linux":
target_triple = f"{proc_name}-linux-gnu"
elif os_name == "Windows":
target_triple = "x86_64-pc-windows-msvc"
else:
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)

View File

@@ -0,0 +1,470 @@
# Copyright 2020 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
def get_vulkan_target_env(vulkan_target_triple):
arch, product, os = vulkan_target_triple.split("=")[1].split("-")
triple = (arch, product, os)
# get version
version = get_version(triple=triple)
# TODO get revision
revision = 120
# extensions
extensions = get_extensions(triple)
# get vendor
vendor = get_vendor(triple)
# get device type
device_type = get_device_type(triple)
# get capabilities
capabilities = get_vulkan_target_capabilities(triple)
target_env = f"#vk.target_env<{version}, r({revision}), {extensions}, {vendor}:{device_type}, #vk.caps< {capabilities} >>"
return target_env
def get_vulkan_target_env_flag(vulkan_target_triple):
target_env = get_vulkan_target_env(vulkan_target_triple)
target_env_flag = f"--iree-vulkan-target-env={target_env}"
return target_env_flag
def get_version(triple):
arch, product, os = triple
if os in ["android30", "android31"]:
return "v1.1"
if product in ["android30", "android31"]:
return "v1.1"
if arch in ["unknown"]:
return "v1.1"
return "v1.3"
def get_extensions(triple):
def make_ext_list(ext_list):
res = ""
for e in ext_list:
res += e + ", "
res = f"[{res[:-2]}]"
return res
arch, product, os = triple
if arch == "m1":
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_8bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
]
return make_ext_list(ext_list=ext)
if arch == "valhall":
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_8bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_spirv_1_4",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
]
return make_ext_list(ext_list=ext)
if arch == "adreno":
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_spirv_1_4",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
]
if os == "android31":
ext.append("VK_KHR_8bit_storage")
return make_ext_list(ext_list=ext)
if get_vendor(triple) == "SwiftShader":
ext = ["VK_KHR_storage_buffer_storage_class"]
return make_ext_list(ext_list=ext)
if arch == "unknown":
ext = [
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
]
return make_ext_list(ext_list=ext)
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_8bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_spirv_1_4",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
"VK_EXT_subgroup_size_control",
]
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
ext.append("VK_NV_cooperative_matrix")
return make_ext_list(ext_list=ext)
def get_vendor(triple):
arch, product, os = triple
if arch == "unknown":
return "Unknown"
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn4", "rgcn5"]:
return "AMD"
if arch == "valhall":
return "ARM"
if arch == "m1":
return "Apple"
if arch in ["turing", "ampere"]:
return "NVIDIA"
if arch == "ardeno":
return "Qualcomm"
if arch == "cpu":
if product == "swiftshader":
return "SwiftShader"
return "Unknown"
print(f"Vendor for target triple - {triple} not found. Using unknown")
return "Unknown"
def get_device_type(triple):
arch, product, _ = triple
if arch == "unknown":
return "Unknown"
if arch == "cpu":
return "CPU"
if arch in ["turing", "ampere"]:
return "DiscreteGPU"
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]:
if product == "ivega10":
return "IntegratedGPU"
return "DiscreteGPU"
if arch in ["m1", "valhall", "adreno"]:
return "IntegratedGPU"
print(f"Device type for target triple - {triple} not found. Using unknown")
return "Unknown"
# get all the capabilities for the device
# TODO: make a dataclass for capabilites and init using vulkaninfo
def get_vulkan_target_capabilities(triple):
def get_subgroup_val(l):
return int(sum([subgroup_feature[sgf] for sgf in l]))
cap = OrderedDict()
arch, product, os = triple
subgroup_feature = {
"Basic": 1,
"Vote": 2,
"Arithmetic": 4,
"Ballot": 8,
"Shuffle": 16,
"ShuffleRelative": 32,
"Clustered": 64,
"Quad": 128,
"PartitionedNV": 256,
}
cap["maxComputeSharedMemorySize"] = 16384
cap["maxComputeWorkGroupInvocations"] = 128
cap["maxComputeWorkGroupSize"] = [128, 128, 64]
cap["subgroupSize"] = 32
cap["subgroupFeatures"] = ["Basic"]
cap["minSubgroupSize"] = None
cap["maxSubgroupSize"] = None
cap["shaderFloat16"] = False
cap["shaderFloat64"] = False
cap["shaderInt8"] = False
cap["shaderInt16"] = False
cap["shaderInt64"] = False
cap["storageBuffer16BitAccess"] = False
cap["storagePushConstant16"] = False
cap["uniformAndStorageBuffer16BitAccess"] = False
cap["storageBuffer8BitAccess"] = False
cap["storagePushConstant8"] = False
cap["uniformAndStorageBuffer8BitAccess"] = False
cap["variablePointers"] = False
cap["variablePointersStorageBuffer"] = False
cap["coopmatCases"] = None
if arch in ["rdna1", "rdna2", "rdna3"]:
cap["maxComputeSharedMemorySize"] = 65536
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["subgroupSize"] = 64
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 64
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
if arch == "rdna3":
# TODO: Get scope value
cap["coopmatCases"] = [
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
]
if product == "rx5700xt":
cap["storagePushConstant16"] = False
cap["storagePushConstant8"] = False
elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
cap["maxComputeSharedMemorySize"] = 65536
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["subgroupSize"] = 64
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["minSubgroupSize"] = 64
cap["maxSubgroupSize"] = 64
if arch == "rgcn5":
cap["shaderFloat16"] = True
cap["shaderFloat64"] = True
cap["storageBuffer16BitAccess"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["storagePushConstant16"] = False
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = False
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "m1":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["subgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "valhall":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 512
cap["maxComputeWorkGroupSize"] = [512, 512, 512]
cap["subgroupSize"] = 16
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Clustered",
"Quad",
]
if os == "android31":
cap["subgroupFeatures"].append("Shuffle")
cap["subgroupFeatures"].append("ShuffleRelative")
cap["shaderFloat16"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "cpu":
if product == "swiftshader":
cap["maxComputeSharedMemorySize"] = 16384
cap["subgroupSize"] = 4
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
]
elif arch in ["ampere", "turing"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["subgroupSize"] = 32
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
cap["coopmatCases"] = [
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
]
elif arch == "adreno":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
cap["subgroupSize"] = 64
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["storageBuffer16BitAccess"] = True
if os == "andorid31":
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "unknown":
cap["subgroupSize"] = 64
cap["variablePointers"] = False
cap["variablePointersStorageBuffer"] = False
else:
print(
f"Architecture {arch} not matched. Using default vulkan target device capability"
)
def get_comma_sep_str(ele_list):
l = ""
for ele in ele_list:
l += f"{ele}, "
l = f"[{l[:-2]}]"
return l
res = ""
for k, v in cap.items():
if v is None or v == False:
continue
if isinstance(v, bool):
res += f"{k} = {'unit' if v == True else None}, "
elif isinstance(v, list):
if k == "subgroupFeatures":
res += f"subgroupFeatures = {get_subgroup_val(v)}: i32, "
elif k == "maxComputeWorkGroupSize":
res += f"maxComputeWorkGroupSize = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
elif k == "coopmatCases":
cmc = ""
for case in v:
cmc += f"#vk.coop_matrix_props<{case}>, "
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
else:
res += f"{k} = {get_comma_sep_str(v)}, "
else:
res += f"{k} = {v}, "
res = res[:-2]
return res

View File

@@ -18,6 +18,7 @@ from os import linesep
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_vulkan_device_name():
@@ -65,11 +66,24 @@ def get_vulkan_target_triple(device_name):
elif all(x in device_name for x in ("RTX", "2080")):
triple = f"turing-rtx2080-{system_os}"
elif all(x in device_name for x in ("A100", "SXM4")):
triple = f"ampere-rtx3080-{system_os}"
triple = f"ampere-a100-{system_os}"
elif all(x in device_name for x in ("RTX", "3090")):
triple = f"ampere-rtx3090-{system_os}"
elif all(x in device_name for x in ("RTX", "3080")):
triple = f"ampere-rtx3080-{system_os}"
elif all(x in device_name for x in ("RTX", "3070")):
triple = f"ampere-rtx3070-{system_os}"
elif all(x in device_name for x in ("RTX", "3060")):
triple = f"ampere-rtx3060-{system_os}"
elif all(x in device_name for x in ("RTX", "3050")):
triple = f"ampere-rtx3050-{system_os}"
# We use ampere until lovelace target triples are plumbed in.
elif all(x in device_name for x in ("RTX", "4090")):
triple = f"ampere-rtx3090-{system_os}"
triple = f"ampere-rtx4090-{system_os}"
elif all(x in device_name for x in ("RTX", "4080")):
triple = f"ampere-rtx4080-{system_os}"
elif all(x in device_name for x in ("RTX", "4070")):
triple = f"ampere-rtx4070-{system_os}"
elif all(x in device_name for x in ("RTX", "4000")):
triple = f"turing-rtx4000-{system_os}"
elif all(x in device_name for x in ("RTX", "5000")):
@@ -88,7 +102,9 @@ def get_vulkan_target_triple(device_name):
triple = f"pascal-gtx1080-{system_os}"
# Amd Targets
elif all(x in device_name for x in ("AMD", "7900")):
# Linux: Radeon RX 7900 XTX
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna2-unknown-{system_os}"
@@ -97,15 +113,16 @@ def get_vulkan_target_triple(device_name):
return triple
def get_vulkan_triple_flag(device_name=None, extra_args=[]):
def get_vulkan_triple_flag(device_name="", extra_args=[]):
for flag in extra_args:
if "-iree-vulkan-target-triple=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
vulkan_device = (
device_name if device_name is not None else get_vulkan_device_name()
)
if device_name == "" or device_name == [] or device_name is None:
vulkan_device = get_vulkan_device_name()
else:
vulkan_device = device_name
triple = get_vulkan_target_triple(vulkan_device)
if triple is not None:
print(
@@ -122,11 +139,23 @@ def get_vulkan_triple_flag(device_name=None, extra_args=[]):
def get_iree_vulkan_args(extra_args=[]):
vulkan_flag = []
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
vulkan_triple_flag = None
for arg in extra_args:
if "-iree-vulkan-target-triple=" in arg:
print(f"Using target triple {arg} from command line args")
vulkan_triple_flag = arg
break
if vulkan_triple_flag is None:
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
if vulkan_triple_flag is not None:
vulkan_flag.append(vulkan_triple_flag)
return vulkan_flag
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
res_vulkan_flag.append(vulkan_target_env)
return res_vulkan_flag
def set_iree_vulkan_runtime_flags(flags):

View File

@@ -47,6 +47,9 @@ def model_annotation(
input_contents = f.read()
module = ir.Module.parse(input_contents)
if config_path == "":
return module
if winograd:
with open(config_path, "r") as f:
data = json.load(f)
@@ -162,7 +165,6 @@ def walk_children(
add_attributes(
child_op, configs[child_op_shape]["options"][0]
)
print(f"Updated op {child_op}", file=sys.stderr)
walk_children(child_op, configs, search_op, winograd)
@@ -394,7 +396,6 @@ def add_winograd_attribute(op: ir.Operation, config: List):
op.attributes["iree_winograd_conv"] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(64), 1
)
print("Apply Winograd on selected conv op: ", op)
def add_attribute_by_name(op: ir.Operation, name: str, val: int):

View File

@@ -23,6 +23,8 @@ from datetime import datetime
import time
import csv
import os
import torch
import torch._dynamo as dynamo
class OnnxFusionOptions(object):
@@ -65,6 +67,7 @@ class SharkBenchmarkRunner(SharkRunner):
extra_args: list = [],
):
self.device = shark_args.device if device == "none" else device
self.enable_tf32 = shark_args.enable_tf32
self.frontend_model = None
self.vmfb_file = None
self.mlir_dialect = mlir_dialect
@@ -107,6 +110,8 @@ class SharkBenchmarkRunner(SharkRunner):
if self.device == "cuda":
torch.set_default_tensor_type(torch.cuda.FloatTensor)
if self.enable_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
else:
torch.set_default_tensor_type(torch.FloatTensor)
torch_device = torch.device(
@@ -114,6 +119,7 @@ class SharkBenchmarkRunner(SharkRunner):
)
HFmodel, input = get_torch_model(modelname)[:2]
frontend_model = HFmodel.model
frontend_model = dynamo.optimize("inductor")(frontend_model)
frontend_model.to(torch_device)
input.to(torch_device)
@@ -333,7 +339,10 @@ for currently supported models. Exiting benchmark ONNX."
else:
bench_result["shape_type"] = "static"
bench_result["device"] = device_str
bench_result["data_type"] = inputs[0].dtype
if "fp16" in modelname:
bench_result["data_type"] = "float16"
else:
bench_result["data_type"] = inputs[0].dtype
for e in engines:
(
bench_result["param_count"],

View File

@@ -169,9 +169,12 @@ def download_model(
os.path.join(model_dir, "upstream_hash.npy"),
single_file=True,
)
upstream_hash = str(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
try:
upstream_hash = str(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
except FileNotFoundError:
upstream_hash = None
if local_hash != upstream_hash:
print(
"Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."

View File

@@ -55,6 +55,7 @@ class SharkImporter:
inputs: tuple = (),
frontend: str = "torch",
raw_model_file: str = "",
return_str: bool = False,
):
self.module = module
self.inputs = None if len(inputs) == 0 else inputs
@@ -65,6 +66,7 @@ class SharkImporter:
)
sys.exit(1)
self.raw_model_file = raw_model_file
self.return_str = return_str
# NOTE: The default function for torch is "forward" and tf-lite is "main".
@@ -72,7 +74,11 @@ class SharkImporter:
from shark.torch_mlir_utils import get_torch_mlir_module
return get_torch_mlir_module(
self.module, self.inputs, is_dynamic, tracing_required
self.module,
self.inputs,
is_dynamic,
tracing_required,
self.return_str,
)
def _tf_mlir(self, func_name, save_dir="./shark_tmp/"):
@@ -245,8 +251,120 @@ class SharkImporter:
)
def get_f16_inputs(inputs, is_f16, f16_input_mask):
if is_f16 == False:
return inputs
if f16_input_mask == None:
return tuple([x.half() for x in inputs])
f16_masked_inputs = []
for i in range(len(inputs)):
if f16_input_mask[i]:
f16_masked_inputs.append(inputs[i].half())
else:
f16_masked_inputs.append(inputs[i])
return tuple(f16_masked_inputs)
def transform_fx(fx_g):
import torch
kwargs_dict = {
"dtype": torch.float16,
"device": torch.device(type="cpu"),
"pin_memory": False,
}
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
]:
node.kwargs = kwargs_dict
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
if node.target in [torch.ops.aten.var_mean]:
with fx_g.graph.inserting_before(node):
new_node = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
node.args = (new_node, node.args[1])
if node.name.startswith("getitem"):
with fx_g.graph.inserting_before(node):
if node.args[0].target in [torch.ops.aten.var_mean]:
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node,),
kwargs={"dtype": torch.float16},
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
new_node.kwargs = {"dtype": torch.float16}
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
# Doesn't replace the None type.
def change_fx_graph_return_to_tuple(fx_g):
for node in fx_g.graph.nodes:
if node.op == "output":
# output nodes always have one argument
node_arg = node.args[0]
out_nodes = []
if isinstance(node_arg, list):
# Don't return NoneType elements.
for out_node in node_arg:
if not isinstance(out_node, type(None)):
out_nodes.append(out_node)
# If there is a single tensor/element to be returned don't
# a tuple for it.
if len(out_nodes) == 1:
node.args = out_nodes
else:
node.args = (tuple(out_nodes),)
fx_g.graph.lint()
fx_g.recompile()
return fx_g
def flatten_training_input(inputs):
flattened_input = []
for i in inputs:
if isinstance(i, dict):
for value in i.values():
flattened_input.append(value.detach())
elif isinstance(i, tuple):
for value in i:
flattened_input.append(value)
else:
flattened_input.append(i)
return tuple(flattened_input)
# Applies fx conversion to the model and imports the mlir.
def import_with_fx(model, inputs, debug=False):
def import_with_fx(
model,
inputs,
is_f16=False,
f16_input_mask=None,
debug=False,
training=False,
return_str=False,
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
@@ -286,16 +404,27 @@ def import_with_fx(model, inputs, debug=False):
strip_overloads(fx_g)
if is_f16:
fx_g = fx_g.half()
transform_fx(fx_g)
fx_g.recompile()
if training:
change_fx_graph_return_to_tuple(fx_g)
inputs = flatten_training_input(inputs)
ts_graph = torch.jit.script(fx_g)
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
mlir_importer = SharkImporter(
fx_g,
ts_graph,
inputs,
frontend="torch",
return_str=return_str,
)
if debug:
if debug and not is_f16:
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
return mlir_module, func_name
mlir_module, func_name = mlir_importer.import_mlir()
return mlir_module, func_name

View File

@@ -15,6 +15,7 @@
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
from shark.backward_makefx import MakeFxModule
from shark.shark_importer import import_with_fx
import numpy as np
from tqdm import tqdm
import sys
@@ -67,23 +68,21 @@ class SharkTrainer:
self.frontend = frontend
# Training function is needed in the case of torch_fn.
def compile(self, training_fn=None):
def compile(self, training_fn=None, extra_args=[]):
if self.frontend in ["torch", "pytorch"]:
aot_module = MakeFxModule(
self.model, tuple(self.input), custom_inference_fn=training_fn
packed_inputs = (
dict(self.model.named_parameters()),
dict(self.model.named_buffers()),
tuple(self.input),
)
mlir_module, func_name = import_with_fx(
training_fn, packed_inputs, False, [], training=True
)
aot_module.generate_graph()
# Returns the backward graph.
training_graph = aot_module.training_graph
weights = self.get_torch_params()
self.shark_runner = SharkRunner(
training_graph,
weights + self.input,
self.dynamic,
mlir_module,
self.device,
self.jit_trace,
self.from_aot,
self.frontend,
"tm_tensor",
extra_args=extra_args,
)
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
self.shark_runner = SharkRunner(
@@ -112,8 +111,8 @@ class SharkTrainer:
params = [x.numpy() for x in params]
print(f"Training started for {num_iters} iterations:")
for i in tqdm(range(num_iters)):
params = self.shark_runner.forward(
params + self.input, self.frontend
params = self.shark_runner.run(
"forward", params + self.input, self.frontend
)
return params

View File

@@ -56,6 +56,7 @@ def get_torch_mlir_module(
input: tuple,
dynamic: bool,
jit_trace: bool,
return_str: bool = False,
):
"""Get the MLIR's linalg-on-tensors module from the torchscipt module."""
ignore_traced_shapes = False
@@ -73,6 +74,8 @@ def get_torch_mlir_module(
use_tracing=jit_trace,
ignore_traced_shapes=ignore_traced_shapes,
)
if return_str:
return mlir_module.operation.get_asm()
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()

View File

@@ -14,9 +14,10 @@ microsoft/MiniLM-L12-H384-uncased,mhlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,
microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,""
microsoft/mpnet-base,mhlo,tf,1e-2,1e-2,default,None,False,False,False,""
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir"
alexnet,linalg,torch,1e-2,1e-3,default,None,False,False,True,"Assertion Error: Zeros Output"
alexnet,linalg,torch,1e-2,1e-3,default,None,True,False,True,"https://github.com/nod-ai/SHARK/issues/879"
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,False,False,""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,False,True,""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile."
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311"
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390"
@@ -28,6 +29,7 @@ nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"Vulkan Numerical Error (mostly conv)"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,True,""
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"Vulkan Numerical Error (mostly conv)"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc,True,False,True,""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"https://github.com/nod-ai/SHARK/issues/388"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"Vulkan Numerical Error (mostly conv)"
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,True,"https://github.com/nod-ai/SHARK/issues/575"
1 resnet50 mhlo tf 1e-2 1e-3 default nhcw-nhwc False False True Vulkan Numerical Error: mostly conv
14 microsoft/layoutlm-base-uncased mhlo tf 1e-2 1e-3 default None False False False
15 microsoft/mpnet-base mhlo tf 1e-2 1e-2 default None False False False
16 albert-base-v2 linalg torch 1e-2 1e-3 default None True True True issue with aten.tanh in torch-mlir
17 alexnet linalg torch 1e-2 1e-3 default None False True False True Assertion Error: Zeros Output https://github.com/nod-ai/SHARK/issues/879
18 bert-base-cased linalg torch 1e-2 1e-3 default None False False False
19 bert-base-uncased linalg torch 1e-2 1e-3 default None False False False
20 bert-base-uncased_fp16 linalg torch 1e-1 1e-1 default None True False True
21 facebook/deit-small-distilled-patch16-224 linalg torch 1e-2 1e-3 default nhcw-nhwc False True False Fails during iree-compile.
22 google/vit-base-patch16-224 linalg torch 1e-2 1e-3 default nhcw-nhwc False True False https://github.com/nod-ai/SHARK/issues/311
23 microsoft/beit-base-patch16-224-pt22k-ft22k linalg torch 1e-2 1e-3 default nhcw-nhwc False True False https://github.com/nod-ai/SHARK/issues/390
29 resnet101 linalg torch 1e-2 1e-3 default nhcw-nhwc False False True Vulkan Numerical Error (mostly conv)
30 resnet18 linalg torch 1e-2 1e-3 default None True True True
31 resnet50 linalg torch 1e-2 1e-3 default nhcw-nhwc False False True Vulkan Numerical Error (mostly conv)
32 resnet50_fp16 linalg torch 1e-2 1e-2 default nhcw-nhwc True False True
33 squeezenet1_0 linalg torch 1e-2 1e-3 default nhcw-nhwc False False True https://github.com/nod-ai/SHARK/issues/388
34 wide_resnet50_2 linalg torch 1e-2 1e-3 default nhcw-nhwc False False True Vulkan Numerical Error (mostly conv)
35 efficientnet-v2-s mhlo tf 1e-02 1e-3 default nhcw-nhwc False False True https://github.com/nod-ai/SHARK/issues/575

View File

@@ -2,12 +2,14 @@ model_name, use_tracing, dynamic, param_count, tags, notes
microsoft/MiniLM-L12-H384-uncased,True,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
albert-base-v2,True,True,11M,"nlp;bert-variant;transformer-encoder","12 layers; 128 embedding dim; 768 hidden dim; 12 attention heads; Smaller than BERTbase (11M params vs 109M params); Uses weight sharing to reduce # params but computational cost is similar to BERT."
bert-base-uncased,True,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-uncased_fp16,True,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
distilbert-base-uncased,True,True,66M,"nlp;bert-variant;transformer-encoder","Smaller and faster than BERT with 97percent retained accuracy."
google/mobilebert-uncased,True,True,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,True,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,True,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet50_fp16,False,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,True,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
squeezenet1_0,False,True,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
wide_resnet50_2,False,True,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
1 model_name use_tracing dynamic param_count tags notes
2 microsoft/MiniLM-L12-H384-uncased True True 66M nlp;bert-variant;transformer-encoder Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)
3 albert-base-v2 True True 11M nlp;bert-variant;transformer-encoder 12 layers; 128 embedding dim; 768 hidden dim; 12 attention heads; Smaller than BERTbase (11M params vs 109M params); Uses weight sharing to reduce # params but computational cost is similar to BERT.
4 bert-base-uncased True True 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
5 bert-base-uncased_fp16 True True 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
6 bert-base-cased True True 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
7 distilbert-base-uncased True True 66M nlp;bert-variant;transformer-encoder Smaller and faster than BERT with 97percent retained accuracy.
8 google/mobilebert-uncased True True 25M nlp,bert-variant,transformer-encoder,mobile 24 layers, 512 hidden size, 128 embedding
9 alexnet False True 61M cnn,parallel-layers The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod.
10 resnet18 False True 11M cnn,image-classification,residuals,resnet-variant 1 7x7 conv2d and the rest are 3x3 conv2d
11 resnet50 False True 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
12 resnet50_fp16 False True 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
13 resnet101 False True 29M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
14 squeezenet1_0 False True 1.25M cnn,image-classification,mobile,parallel-layers Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)
15 wide_resnet50_2 False True 69M cnn,image-classification,residuals,resnet-variant Resnet variant where model depth is decreased and width is increased.

View File

@@ -12,6 +12,7 @@ vision_models = [
"resnet101",
"resnet18",
"resnet50",
"resnet50_fp16",
"squeezenet1_0",
"wide_resnet50_2",
"mobilenet_v3_small",
@@ -31,6 +32,8 @@ def get_torch_model(modelname):
return get_vision_model(modelname)
elif modelname in hf_img_cls_models:
return get_hf_img_cls_model(modelname)
elif "fp16" in modelname:
return get_fp16_model(modelname)
else:
return get_hf_model(modelname)
@@ -114,7 +117,6 @@ class HuggingFaceLanguage(torch.nn.Module):
def get_hf_model(name):
from transformers import (
BertTokenizer,
TFBertModel,
)
model = HuggingFaceLanguage(name)
@@ -146,6 +148,7 @@ def get_vision_model(torch_model):
"alexnet": models.alexnet(weights="DEFAULT"),
"resnet18": models.resnet18(weights="DEFAULT"),
"resnet50": models.resnet50(weights="DEFAULT"),
"resnet50_fp16": models.resnet50(weights="DEFAULT"),
"resnet101": models.resnet101(weights="DEFAULT"),
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
@@ -153,10 +156,26 @@ def get_vision_model(torch_model):
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
}
if isinstance(torch_model, str):
fp16_model = None
if "fp16" in torch_model:
fp16_model = True
torch_model = vision_models_dict[torch_model]
model = VisionModule(torch_model)
test_input = torch.randn(1, 3, 224, 224)
actual_out = model(test_input)
if fp16_model is not None:
test_input_fp16 = test_input.to(
device=torch.device("cuda"), dtype=torch.half
)
model_fp16 = model.half()
model_fp16.eval()
model_fp16.to("cuda")
actual_out_fp16 = model_fp16(test_input_fp16)
model, test_input, actual_out = (
model_fp16,
test_input_fp16,
actual_out_fp16,
)
return model, test_input, actual_out
@@ -164,6 +183,49 @@ def get_vision_model(torch_model):
####################### Other PyTorch HF Models ###############################
class BertHalfPrecisionModel(torch.nn.Module):
def __init__(self, hf_model_name):
super().__init__()
from transformers import AutoModelForMaskedLM
self.model = AutoModelForMaskedLM.from_pretrained(
hf_model_name, # The pretrained model.
num_labels=2, # The number of output labels--2 for binary classification.
output_attentions=False, # Whether the model returns attentions weights.
output_hidden_states=False, # Whether the model returns all hidden-states.
torchscript=True,
torch_dtype=torch.float16,
).to("cuda")
def forward(self, tokens):
return self.model.forward(tokens)[0]
def get_fp16_model(torch_model):
from transformers import AutoTokenizer
modelname = torch_model.replace("_fp16", "")
model = BertHalfPrecisionModel(modelname)
tokenizer = AutoTokenizer.from_pretrained(modelname)
text = "Replace me by any text you like."
test_input_fp16 = tokenizer(
text,
truncation=True,
max_length=128,
return_tensors="pt",
).input_ids.to("cuda")
# test_input = torch.randint(2, (1, 128))
# test_input_fp16 = test_input.to(
# device=torch.device("cuda")
# )
model_fp16 = model.half()
model_fp16.eval()
with torch.no_grad():
actual_out_fp16 = model_fp16(test_input_fp16)
return model_fp16, test_input_fp16, actual_out_fp16
# Utility function for comparing two tensors (torch).
def compare_tensors(torch_tensor, numpy_tensor, rtol=1e-02, atol=1e-03):
# torch_to_numpy = torch_tensor.detach().numpy()

View File

@@ -177,26 +177,11 @@ class SharkModuleTester:
if self.ci == True:
self.upload_repro()
if self.benchmark == True:
# p = multiprocessing.Process(
# target=self.benchmark_module,
# args=(shark_module, inputs, dynamic, device),
# )
# p.start()
# p.join()
self.benchmark_module(shark_module, inputs, dynamic, device)
print(msg)
pytest.xfail(reason="Numerics Issue")
pytest.xfail(reason="Numerics Issue, awaiting triage.")
if self.benchmark == True:
# We must create a new process each time we benchmark a model to allow
# for Tensorflow to release GPU resources. Using the same process to
# benchmark multiple models leads to OOM.
# p = multiprocessing.Process(
# target=self.benchmark_module,
# args=(shark_module, inputs, dynamic, device),
# )
# p.start()
# p.join()
self.benchmark_module(shark_module, inputs, dynamic, device)
if self.save_repro == True:

View File

@@ -16,3 +16,5 @@ facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-class
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
1 model_name use_tracing model_type dynamic param_count tags notes
16 microsoft/beit-base-patch16-224-pt22k-ft22k True hf_img_cls False 86M image-classification,transformer-encoder,bert-variant,vision-transformer N/A
17 nvidia/mit-b0 True hf_img_cls False 3.7M image-classification,transformer-encoder SegFormer
18 mnasnet1_0 False vision True - cnn, torchvision, mobile, architecture-search Outperforms other mobile CNNs on Accuracy vs. Latency
19 resnet50_fp16 False vision True 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
20 bert-base-uncased_fp16 True fp16 False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads

67
web/demo.css Normal file
View File

@@ -0,0 +1,67 @@
.gradio-container {
background-color: black
}
.container {
background-color: black !important;
padding-top: 20px !important;
}
#ui_title {
padding: 10px !important;
}
#top_logo {
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#demo_title {
background-color: black;
border-radius: 0 !important;
border: 0;
padding-top: 50px;
padding-bottom: 0px;
width: 460px !important;
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea {
background-color: #1d1d1d !important
}
#prompt_examples {
margin: 0 !important
}
#prompt_examples svg {
display: none !important;
}
.gr-sample-textbox {
border-radius: 1rem !important;
border-color: rgb(31, 41, 55) !important;
border-width: 2px !important;
}
#ui_body {
background-color: #111111 !important;
padding: 10px !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}

View File

@@ -1,6 +1,13 @@
import os
import sys
from pathlib import Path
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
os.environ["AMD_ENABLE_LLPC"] = "1"
import gradio as gr
from PIL import Image
from models.stable_diffusion.resources import resource_path, prompt_examples
@@ -12,26 +19,7 @@ nodlogo_loc = resource_path("logos/nod-logo.png")
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
demo_css = """
.gradio-container {background-color: black}
.container {background-color: black !important; padding-top:20px !important; }
#ui_title {padding: 10px !important; }
#top_logo {background-color: transparent; border-radius: 0 !important; border: 0; }
#demo_title {background-color: black; border-radius: 0 !important; border: 0; padding-top: 50px; padding-bottom: 0px; width: 460px !important;}
#demo_title_outer {border-radius: 0; }
#prompt_box_outer div:first-child {border-radius: 0 !important}
#prompt_box textarea {background-color:#1d1d1d !important}
#prompt_examples {margin:0 !important}
#prompt_examples svg {display: none !important;}
.gr-sample-textbox { border-radius: 1rem !important; border-color: rgb(31,41,55) !important; border-width:2px !important; }
#ui_body {background-color: #111111 !important; padding: 10px !important; border-radius: 0.5em !important;}
#img_result+div {display: none !important;}
footer {display: none !important;}
"""
demo_css = Path(__file__).parent.joinpath("demo.css").resolve()
with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
@@ -141,6 +129,13 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
lines=4,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
prompt.submit(
stable_diff_inf,
@@ -175,8 +170,8 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
shark_web.queue()
shark_web.launch(
share=False,
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=8080,
server_port=args.server_port,
)

View File

@@ -16,6 +16,7 @@ from models.stable_diffusion.stable_args import args
from models.stable_diffusion.schedulers import (
SharkEulerDiscreteScheduler,
)
import gc
model_config = {
@@ -81,16 +82,25 @@ class ModelCache:
self.version = None
self.schedulers = None
self.tokenizer = None
self.vae = None
self.clip = None
self.unet = None
def set_models(self, device_key):
if self.device != device_key or self.variant != args.variant:
self.device = device_key
self.variant = args.variant
self.version = args.version
args.device = device_key.split("=>", 1)[0].strip()
args.device = device_key.split("=>", 1)[1].strip()
args.max_length = 64
args.use_tuned = True
set_init_device_flags()
del self.schedulers
del self.tokenizer
del self.vae
del self.unet
del self.clip
gc.collect()
self.schedulers = get_schedulers(args.version)
self.tokenizer = get_tokenizer(args.version)
self.vae = get_vae()

View File

@@ -1,14 +1,18 @@
import torch
import os
from PIL import Image
import torchvision.transforms as T
from tqdm.auto import tqdm
from models.stable_diffusion.cache_objects import model_cache
from models.stable_diffusion.stable_args import args
from models.stable_diffusion.utils import disk_space_check
from random import randint
import numpy as np
import time
import sys
from datetime import datetime as dt
from csv import DictWriter
import re
from pathlib import Path
if args.clear_all:
@@ -65,6 +69,55 @@ def set_ui_params(
args.variant = variant
# save output images and the inputs correspoding to it.
def save_output_img(output_img):
output_path = args.output_dir if args.output_dir else Path.cwd()
disk_space_check(output_path, lim=5)
generated_imgs_path = Path(output_path, "generated_imgs")
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_history.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(
out_img_path,
quality=95,
subsampling=0,
optimize=True,
progressive=True,
)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
output_img.save(out_img_path, "PNG")
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"saving image as png. Supported formats png / jpg"
)
new_entry = {
"VARIANT": args.variant,
"VERSION": args.version,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SEED": args.seed,
"CFG_SCALE": float(args.guidance_scale),
"PRECISION": args.precision,
"STEPS": args.steps,
"OUTPUT": out_img_path,
}
with open(csv_path, "a") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
dictwriter_obj.writerow(new_entry)
csv_obj.close()
def stable_diff_inf(
prompt: str,
negative_prompt: str,
@@ -103,6 +156,7 @@ def stable_diff_inf(
width = 768
# get all cached data.
disk_space_check(Path.cwd())
model_cache.set_models(device_key)
tokenizer = model_cache.tokenizer
scheduler = model_cache.schedulers[args.scheduler]
@@ -213,16 +267,17 @@ def stable_diff_inf(
print(f"\nTotal image generation time: {total_time}sec")
# generate outputs to web.
transform = T.ToPILImage()
pil_images = [
transform(image) for image in torch.from_numpy(images).to(torch.uint8)
]
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nvariant={args.variant}, scheduler={args.scheduler}, device={device_key}"
text_output += f"\nvariant={args.variant}, version={args.version}, scheduler={args.scheduler}"
text_output += f"\ndevice={device_key}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={height}x{width}"
text_output += f"\nAverage step time: {avg_ms:.4f}ms/it"
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(pil_images[0])
return pil_images[0], text_output

View File

@@ -1,14 +1,10 @@
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from models.stable_diffusion.utils import compile_through_fx
from models.stable_diffusion.resources import models_config
from models.stable_diffusion.stable_args import args
import torch
model_config = {
"v2_1": "stabilityai/stable-diffusion-2-1",
"v2_1base": "stabilityai/stable-diffusion-2-1-base",
"v1_4": "CompVis/stable-diffusion-v1-4",
}
# clip has 2 variants of max length 77 or 64.
model_clip_max_length = 64 if args.max_length == 64 else 77
@@ -17,14 +13,6 @@ if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
elif args.variant == "openjourney":
model_clip_max_length = 64
model_variant = {
"stablediffusion": "SD",
"anythingv3": "Linaqruf/anything-v3.0",
"dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
"openjourney": "prompthero/openjourney",
"analogdiffusion": "wavymulder/Analog-Diffusion",
}
model_input = {
"v2_1": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
@@ -58,45 +46,34 @@ model_input = {
},
}
# revision param for from_pretrained defaults to "main" => fp32
model_revision = {
"stablediffusion": "fp16" if args.precision == "fp16" else "main",
"anythingv3": "diffusers",
"analogdiffusion": "main",
"openjourney": "main",
"dreamlike": "main",
}
version = args.version if args.variant == "stablediffusion" else "v1_4"
def get_configs():
model_id_key = f"{args.variant}/{version}"
revision_key = f"{args.variant}/{args.precision}"
try:
model_id = models_config[0][model_id_key]
revision = models_config[1][revision_key]
except KeyError:
raise Exception(
f"No entry for {model_id_key} or {revision_key} in the models configuration"
)
return model_id, revision
def get_clip_mlir(model_name="clip_text", extra_args=[]):
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
if args.variant == "stablediffusion":
if args.version != "v1_4":
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
)
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
text_encoder = CLIPTextModel.from_pretrained(
model_variant[args.variant],
subfolder="text_encoder",
revision=model_revision[args.variant],
)
else:
raise ValueError(f"{args.variant} not yet added")
model_id, revision = get_configs()
class CLIPText(torch.nn.Module):
def __init__(self):
super().__init__()
self.text_encoder = text_encoder
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
revision=revision,
)
def forward(self, input):
return self.text_encoder(input)[0]
@@ -104,23 +81,44 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
model_input[args.version]["clip"],
model_input[version]["clip"],
model_name=model_name,
extra_args=extra_args,
)
return shark_clip
def get_shark_module(model_key, module, model_name, extra_args):
if args.precision == "fp16":
module = module.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[version][model_key]
]
)
else:
inputs = model_input[version][model_key]
shark_module = compile_through_fx(
module,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_module
def get_base_vae_mlir(model_name="vae", extra_args=[]):
model_id, revision = get_configs()
class BaseVaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
model_id,
subfolder="vae",
revision=model_revision[args.variant],
revision=revision,
)
def forward(self, input):
@@ -128,52 +126,19 @@ def get_base_vae_mlir(model_name="vae", extra_args=[]):
return (x / 2 + 0.5).clamp(0, 1)
vae = BaseVaeModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
raise ValueError(f"{args.variant} not yet added")
shark_vae = compile_through_fx(
vae,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_vae
return get_shark_module("vae", vae, model_name, extra_args)
def get_vae_mlir(model_name="vae", extra_args=[]):
model_id, revision = get_configs()
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
model_id,
subfolder="vae",
revision=model_revision[args.variant],
revision=revision,
)
def forward(self, input):
@@ -184,52 +149,19 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
return x.round()
vae = VaeModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
raise ValueError(f"{args.variant} not yet added")
shark_vae = compile_through_fx(
vae,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_vae
return get_shark_module("vae", vae, model_name, extra_args)
def get_unet_mlir(model_name="unet", extra_args=[]):
model_id, revision = get_configs()
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
model_id,
subfolder="unet",
revision=model_revision[args.variant],
revision=revision,
)
self.in_channels = self.unet.in_channels
self.train(False)
@@ -247,39 +179,4 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
return noise_pred
unet = UnetModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
else:
inputs = model_input[args.version]["unet"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input["v1_4"]["unet"]
]
)
else:
inputs = model_input["v1_4"]["unet"]
else:
raise ValueError(f"{args.variant} is not yet added")
shark_unet = compile_through_fx(
unet,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_unet
return get_shark_module("unet", unet, model_name, extra_args)

View File

@@ -33,7 +33,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
]
except KeyError:
raise Exception(
f"{bucket}/{model_key} is not present in the models database"
f" there is no entry for {model_key} in the models database"
)
if (

View File

@@ -29,3 +29,13 @@ if os.path.exists(models_loc):
if len(models_db) != 3:
sys.exit("Error: Unable to load models database.")
models_config = []
modelconfig_loc = resource_path("resources/model_config.json")
if os.path.exists(modelconfig_loc):
with open(modelconfig_loc, encoding="utf-8") as fopen:
models_config = json.load(fopen)
if len(models_config) != 2:
sys.exit("Error: Unable to load models configuration.")

View File

@@ -0,0 +1,21 @@
[
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]

View File

@@ -12,7 +12,6 @@
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_1dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",

View File

@@ -117,6 +117,20 @@ p.add_argument(
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
@@ -223,4 +237,20 @@ p.add_argument(
help="flag for removing the pregress bar animation during image generation",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
##############################################################################
args = p.parse_args()

View File

@@ -180,7 +180,9 @@ def set_init_device_flags():
args.device = "cpu"
# set max_length based on availability.
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
if args.version == "v1_4":
args.max_length = 77
elif args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
args.max_length = 77
elif args.variant == "openjourney":
args.max_length = 64
@@ -189,6 +191,7 @@ def set_init_device_flags():
if (
args.variant in ["openjourney", "dreamlike"]
or args.precision != "fp16"
or args.version == "v1_4"
or "vulkan" not in args.device
or "rdna3" not in args.iree_vulkan_target_triple
):
@@ -217,7 +220,7 @@ def get_available_devices():
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{driver_name}://{i} => {device['name']}")
device_list.append(f"{device['name']} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
@@ -227,5 +230,14 @@ def get_available_devices():
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
# available_devices.append("cpu")
return available_devices
def disk_space_check(path, lim=20):
from shutil import disk_usage
du = disk_usage(path)
free = du.free / (1024 * 1024 * 1024)
if free <= lim:
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")

View File

@@ -26,8 +26,10 @@ datas += collect_data_files('shark')
datas += [
( 'models/stable_diffusion/resources/prompts.json', 'resources' ),
( 'models/stable_diffusion/resources/model_db.json', 'resources' ),
( 'models/stable_diffusion/resources/model_config.json', 'resources' ),
( 'models/stable_diffusion/logos/*', 'logos' )
]
datas += [('demo.css', '.')]
binaries = []

240
web/telegram_bot.py Normal file
View File

@@ -0,0 +1,240 @@
import logging
import os
from models.stable_diffusion.main import stable_diff_inf
from models.stable_diffusion.utils import get_available_devices
from dotenv import load_dotenv
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram import BotCommand
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
from io import BytesIO
import random
log = logging.getLogger("TG.Bot")
logging.basicConfig()
log.warning("Start")
load_dotenv()
os.environ["AMD_ENABLE_LLPC"] = "0"
TG_TOKEN = os.getenv("TG_TOKEN")
SELECTED_MODEL = "stablediffusion"
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
STEPS = 30
NEGATIVE_PROMPT = (
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
" legs, fused fingers, too many fingers"
)
GUIDANCE_SCALE = 6
available_devices = get_available_devices()
models_list = [
"stablediffusion",
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]
sheds_list = [
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
]
def image_to_bytes(image):
bio = BytesIO()
bio.name = "image.jpeg"
image.save(bio, "JPEG")
bio.seek(0)
return bio
def get_try_again_markup():
keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]]
reply_markup = InlineKeyboardMarkup(keyboard)
return reply_markup
def generate_image(prompt):
seed = random.randint(1, 10000)
log.warning(SELECTED_MODEL)
log.warning(STEPS)
image, text = stable_diff_inf(
prompt=prompt,
negative_prompt=NEGATIVE_PROMPT,
steps=STEPS,
guidance_scale=GUIDANCE_SCALE,
seed=seed,
scheduler_key=SELECTED_SCHEDULER,
variant=SELECTED_MODEL,
device_key=available_devices[0],
)
return image, seed
async def generate_and_send_photo(
update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
progress_msg = await update.message.reply_text(
"Generating image...", reply_to_message_id=update.message.message_id
)
im, seed = generate_image(prompt=update.message.text)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{update.message.text}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=update.message.message_id,
)
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
query = update.callback_query
if query.data in models_list:
global SELECTED_MODEL
SELECTED_MODEL = query.data
await query.answer()
await query.edit_message_text(text=f"Selected model: {query.data}")
return
if query.data in sheds_list:
global SELECTED_SCHEDULER
SELECTED_SCHEDULER = query.data
await query.answer()
await query.edit_message_text(text=f"Selected scheduler: {query.data}")
return
replied_message = query.message.reply_to_message
await query.answer()
progress_msg = await query.message.reply_text(
"Generating image...", reply_to_message_id=replied_message.message_id
)
if query.data == "TRYAGAIN":
prompt = replied_message.text
im, seed = generate_image(prompt)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{prompt}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=replied_message.message_id,
)
async def select_model_handler(update, context):
text = "Select model"
keyboard = []
for model in models_list:
keyboard.append(
[
InlineKeyboardButton(text=model, callback_data=model),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def select_scheduler_handler(update, context):
text = "Select schedule"
keyboard = []
for shed in sheds_list:
keyboard.append(
[
InlineKeyboardButton(text=shed, callback_data=shed),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def set_steps_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_steps ")[1]
global STEPS
STEPS = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_steps 30"
)
await update.message.reply_text(input_args)
async def set_negative_prompt_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_negative_prompt ")[1]
global NEGATIVE_PROMPT
NEGATIVE_PROMPT = input_args
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_negative_prompt ugly, bad art, mutated"
)
await update.message.reply_text(input_args)
async def set_guidance_scale_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_guidance_scale ")[1]
global GUIDANCE_SCALE
GUIDANCE_SCALE = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_guidance_scale 7"
)
await update.message.reply_text(input_args)
async def setup_bot_commands(application: Application) -> None:
await application.bot.set_my_commands(
[
BotCommand("select_model", "to select model"),
BotCommand("select_scheduler", "to select scheduler"),
BotCommand("set_steps", "to set steps"),
BotCommand("set_guidance_scale", "to set guidance scale"),
BotCommand("set_negative_prompt", "to set negative prompt"),
]
)
app = (
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
)
app.add_handler(CommandHandler("select_model", select_model_handler))
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
app.add_handler(CommandHandler("set_steps", set_steps_handler))
app.add_handler(
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
)
app.add_handler(
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
)
app.add_handler(
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
)
app.add_handler(CallbackQueryHandler(button))
log.warning("Start bot")
app.run_polling()