Compare commits

..

1 Commits

Author SHA1 Message Date
Stanley Winata
c47218c972 WIP: Intel-GPU integration 2022-08-10 11:24:00 -07:00
251 changed files with 3920 additions and 26007 deletions

View File

@@ -1,37 +0,0 @@
# See: https://github.com/llvm/torch-mlir/issues/1374
name: Publish releases page
on:
workflow_dispatch:
jobs:
scrape_and_publish_releases:
name: "Scrape and publish releases"
runs-on: ubuntu-latest
# Don't run this in everyone's forks.
if: github.repository == 'nod-ai/SHARK'
steps:
- name: Checking out repository
uses: actions/checkout@v2
with:
token: ${{ secrets.NODAI_INVOCATION_TOKEN }}
- name: Run scrape releases script
run: python ./build_tools/scrape_releases.py nod-ai SHARK > /tmp/index.html
shell: bash
- run: git fetch --all
- run: git switch github-pages
- run: git config --global user.email "none@none.com"
- run: git config --global user.name "nod-ai"
- run: mv /tmp/index.html package-index/index.html
- run: git add package-index/index.html
# Only try to make a commit if the file has changed.
- run: git diff --cached --exit-code || git commit -m "Update releases."
- name: GitHub Push
uses: ad-m/github-push-action@v0.6.0
with:
github_token: ${{ secrets.NODAI_INVOCATION_TOKEN }}
branch: github-pages

View File

@@ -9,91 +9,13 @@ on:
workflow_dispatch:
jobs:
windows-build:
runs-on: 7950X
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Compute version
shell: powershell
run: |
$package_version = $(Get-Date -UFormat "%Y%m%d")+"."+${{ github.run_number }}
$package_version_ = $(Get-Date -UFormat "%Y%m%d")+"_"+${{ github.run_number }}
$tag_name=$package_version
echo "package_version=$package_version" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
echo "package_version_=$package_version_" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
echo "tag_name=$tag_name" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
tag_name: ${{ env.tag_name }}
release_name: nod.ai SHARK ${{ env.tag_name }}
body: |
Automatic snapshot release of nod.ai SHARK.
draft: true
prerelease: false
- name: Build Package
shell: powershell
run: |
./setup_venv.ps1
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
# GHA windows VM OOMs so disable for now
#- name: Build and validate the SHARK Runtime package
# shell: powershell
# run: |
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
- uses: actions/upload-artifact@v2
with:
path: dist/*
- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./dist/*
- name: Publish Release
id: publish_release
uses: eregon/publish-release@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
linux-build:
build:
runs-on: a100
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
backend: [IREE, SHARK]
steps:
- uses: actions/checkout@v3
@@ -109,56 +31,63 @@ jobs:
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Compute version
run: |
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
tag_name="${package_version}"
echo "package_version=${package_version}" >> $GITHUB_ENV
echo "tag_name=${tag_name}" >> $GITHUB_ENV
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
tag_name: ${{ env.tag_name }}
release_name: nod.ai SHARK ${{ env.tag_name }}
body: |
Automatic snapshot release of nod.ai SHARK.
draft: true
prerelease: false
- name: Install dependencies
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
if [ -f requirements.txt ]; then pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/SHARK-Runtime/releases; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
- name: Build and validate the IREE package
if: ${{ matrix.backend == 'IREE' }}
continue-on-error: true
run: |
cd $GITHUB_WORKSPACE
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://iree-org.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
if !(grep -Fxq " failed" pytest_results.txt)
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/latest/
fi
rm -rf ./wheelhouse/nodai*
- name: Build and validate the SHARK Runtime package
if: ${{ matrix.backend == 'SHARK' }}
- name: Build and validate the package
run: |
cd $GITHUB_WORKSPACE
./setup_venv.sh
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/SHARK-Runtime/releases
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
pytest -k 'not benchmark' --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py --ignore=shark/tests/test_shark_importer.py --ignore=tank/tf/
- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./wheelhouse/nodai_*.whl
- name: Publish Release
id: publish_release
uses: eregon/publish-release@v1
env:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}

View File

@@ -6,31 +6,17 @@ name: Validate Models on Shark Runtime
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
suite: [cpu,cuda,vulkan]
os: [a100, MacStudio, ubuntu-latest]
suite: [cpu,gpu,vulkan]
python-version: ["3.10"]
include:
- os: ubuntu-latest
@@ -39,44 +25,27 @@ jobs:
- os: ubuntu-latest
suite: vulkan
- os: ubuntu-latest
suite: cuda
suite: gpu
- os: ubuntu-latest
suite: cpu
- os: MacStudio
suite: cuda
suite: gpu
- os: MacStudio
suite: cpu
- os: icelake
suite: vulkan
- os: icelake
suite: cuda
- os: a100
suite: cpu
- os: 7950x
suite: cpu
- os: 7950x
suite: cuda
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'
- name: Set Environment Variables
if: matrix.os != '7950x'
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest'
run: |
# See https://github.com/actions/setup-python/issues/433
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest'
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
@@ -84,9 +53,6 @@ jobs:
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- uses: actions/checkout@v2
if: matrix.os == '7950x'
- name: Install dependencies
if: matrix.suite == 'lint'
@@ -105,57 +71,26 @@ jobs:
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude lit.cfg.py
- name: Validate Models on CPU
- name: Validate CPU Models
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
pytest -k 'cpu' --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
- name: Validate Models on NVIDIA GPU
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
- name: Validate GPU Models
if: matrix.suite == 'gpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k vulkan
pytest -k "gpu" --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
- name: Validate Vulkan Models
if: matrix.suite == 'vulkan'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest --benchmark -k vulkan -s
type bench_results.csv
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python build_tools/stable_diffusion_testing.py --device=vulkan
pytest -k 'vulkan' --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py

7
.gitignore vendored
View File

@@ -31,6 +31,7 @@ MANIFEST
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
@@ -162,13 +163,7 @@ cython_debug/
# Shark related artefacts
*venv/
shark_tmp/
*.vmfb
.use-iree
tank/dict_configs.py
# ORT related artefacts
cache_models/
onnx_models/
# Generated images
generated_imgs/

218
LICENSE
View File

@@ -1,218 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---- LLVM Exceptions to the Apache 2.0 License ----
As an exception, if, as a result of your compiling your source code, portions
of this Software are embedded into an Object form of such source code, you
may redistribute such embedded portions in such Object form without complying
with the conditions of Sections 4(a), 4(b) and 4(d) of the License.
In addition, if you combine or link compiled forms of this Software with
software that is licensed under the GPLv2 ("Combined Software") and if a
court of competent jurisdiction determines that the patent provision (Section
3), the indemnity provision (Section 9) or other Section of the License
conflicts with the conditions of the GPLv2, you may retroactively and
prospectively choose to deem waived or otherwise exclude such Section(s) of
the License, but only in their entirety and only with respect to the Combined
Software.

389
README.md
View File

@@ -5,119 +5,25 @@ High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerator
[![Nightly Release](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
[![Validate torch-models on Shark Runtime](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
## Communication Channels
## Installation (Windows, Linux and macOS)
## Check out the code
```shell
git clone https://github.com/nod-ai/SHARK.git
cd SHARK
```
## Setup your Python VirtualEnvironment and Dependencies
### Windows 10/11 Users
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
* Install Git for Windows from [here](https://git-scm.com/download/win)
#### Allow the install script to run in Powershell
```powershell
set-executionpolicy remotesigned
```
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
```powershell
./setup_venv.ps1 #You can re-run this script to get the latest version
```
### Linux / macOS Users
```shell
./setup_venv.sh
source shark.venv/bin/activate
```
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
### Run Stable Diffusion on your device - WebUI
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
```
#### Linux / macOS Users
```shell
(shark.venv) > cd apps/stable_diffusion/web
(shark.venv) > python index.py
```
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
### Run Stable Diffusion on your device - Commandline
#### Install your hardware drivers
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mril-iree)
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
Other users please ensure you have your latest vendor drivers and Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home) and if you are using vulkan check `vulkaninfo` works in a terminal window
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
```shell
python3.10 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
The output on a 7900XTX would like:
```shell
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
```
Here are some samples generated:
![tajmahal, snow, sunflowers, oil on canvas_0](https://user-images.githubusercontent.com/74956/204934186-141f7e43-6eb2-4e89-a99c-4704d20444b3.jpg)
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
## Installation
<details>
<summary>Binary Installation</summary>
<summary>Installation (Linux and macOS)</summary>
### Setup a new pip Virtual Environment
This step sets up a new VirtualEnv for Python
```shell
python --version #Check you have 3.10 on Linux, macOS or Windows Powershell
python --version #Check you have 3.7->3.10 on Linux or 3.10 on macOS
python -m venv shark_venv
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
source shark_venv/bin/activate
# If you are using conda create and activate a new conda env
@@ -132,21 +38,16 @@ python -m pip install --upgrade pip
This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10
```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://github.com/nod-ai/SHARK/releases -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/shark-runtime/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
### Run shark tank model tests.
```shell
pytest tank/test_models.py
```
See tank/README.md for a more detailed walkthrough of our pytest suite and CLI.
If you are on an Intel macOS machine you need this [workaround](https://github.com/nod-ai/SHARK/issues/102) for an upstream issue.
### Download and run Resnet50 sample
```shell
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/resnet50_script.py
#Install deps for test script
pip install --pre torch torchvision torchaudio tqdm pillow gsutil --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install --pre torch torchvision torchaudio tqdm pillow --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python ./resnet50_script.py --device="cpu" #use cuda or vulkan or metal
```
@@ -160,78 +61,78 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
</details>
<details>
<summary>Development, Testing and Benchmarks</summary>
<summary>Source Installation</summary>
If you want to use Python3.10 and with TF Import tools you can use the environment variables like:
Set `USE_IREE=1` to use upstream IREE
```
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
```
## Check out the code
### Run any of the hundreds of SHARK tank models via the test framework
```shell
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
# Or a pytest
pytest tank/test_models.py -k "MiniLM"
git clone https://github.com/nod-ai/SHARK.git
```
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
## Setup your Python VirtualEnvironment and Dependencies
```shell
# Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
source shark.venv/bin/activate
```
For example if you want to use Python3.10 and upstream IREE with TF Import tools you can use the environment variables like:
```
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 USE_IREE=1 ./setup_venv.sh
```
If you are a Torch-mlir developer or an IREE developer and want to test local changes you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
with Python bindings and set your PYTHONPATH as mentioned [here](https://google.github.io/iree/bindings/python/)
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
for Torch-MLIR.
### How to use your locally built Torch-MLIR with SHARK
### Run a demo script
```shell
1.) Run `./setup_venv.sh in SHARK` and activate `shark.venv` virtual env.
2.) Run `pip uninstall torch-mlir`.
3.) Go to your local Torch-MLIR directory.
4.) Activate mlir_venv virtual envirnoment.
5.) Run `pip uninstall -r requirements.txt`.
6.) Run `pip install -r requirements.txt`.
7.) Build Torch-MLIR.
8.) Activate shark.venv virtual environment from the Torch-MLIR directory.
8.) Run `export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples` in the Torch-MLIR directory.
9.) Go to the SHARK directory.
```
Now the SHARK will use your locally build Torch-MLIR repo.
## Benchmarking Dispatches
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your command line argument.
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`
if you want to instead incorporate this into a python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `SharkInference`, and the benchmarks will be generated when compiled. E.G:
```
shark_module = SharkInference(
mlir_model,
func_name,
device=args.device,
mlir_dialect="tm_tensor",
dispatch_benchmarks="all",
dispatch_benchmarks_dir="results"
)
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
# Or a pytest
pytest tank/tf/hf_masked_lm/albert-base-v2_test.py::AlbertBaseModuleTest::test_module_static_cpu
```
Output will include:
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
- An .mlir file containing the dispatch benchmark
- A compiled .vmfb file containing the dispatch benchmark
- An .mlir file containing just the hal executable
- A compiled .vmfb file of the hal executable
- A .txt file containing benchmark output
See tank/README.md for instructions on how to run model tests and benchmarks from the SHARK tank.
</details>
<details>
<summary>Testing</summary>
### Run all model tests on CPU/GPU/VULKAN/Metal
```shell
pytest tank
# If on Linux for quicker results:
pytest tank -n auto
```
### Running specific tests
```shell
# Run tests for a specific model:
pytest tank/<MODEL_NAME> #i.e., pytest tank/bert-base-uncased
# Run tests for a specific case:
pytest tank/<MODEL_NAME>/<MODEL_TEST>.py::<MODEL>ModuleTest::<CASE>
# i.e., pytest tank/bert-base-uncased/bert-base-uncased_test.py::BertModuleTest::test_module_static_cpu
# For frontends other than pytorch, if available for a model, add frontend to filename: tank/bert-base-uncased/bert-base-uncased_tf_test.py
# Run all tests, including tests for benchmarking and SHARK modules:
# From base SHARK directory,
pytest
```
### Run all model benchmark tests on CPU/GPU/VULKAN/Metal
```shell
pytest benchmarks
```
</details>
<details>
<summary>API Reference</summary>
@@ -282,26 +183,160 @@ result = shark_module.forward((arg0, arg1))
```
</details>
## Supported and Validated Models
SHARK is maintained to support the latest innovations in ML Models:
<details>
<summary>PyTorch Models</summary>
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------|----------|-------------|
| BERT | :green_heart: | :green_heart: | :green_heart: |
| DistilBERT | :green_heart: | :green_heart: | :green_heart: |
| GPT2 | :green_heart: | :green_heart: | :green_heart: |
| BLOOM | :green_heart: | :green_heart: | :green_heart: |
| Stable Diffusion | :green_heart: | :green_heart: | :green_heart: |
| Vision Transformer | :green_heart: | :green_heart: | :green_heart: |
| ResNet50 | :green_heart: | :green_heart: | :green_heart: |
### Huggingface PyTorch Models
For a complete list of the models supported in SHARK, please refer to [tank/README.md](https://github.com/nod-ai/SHARK/blob/main/tank/README.md).
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
| Albert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
| BigBird | :green_heart: (AOT) | | | |
| DistilBERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
| GPT2 | :broken_heart: (AOT) | | | |
| MobileBert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
## Communication Channels
### Torchvision Models
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|--------------------|----------------------|----------|----------|-------------|
| AlexNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| DenseNet121 | :green_heart: (Script) | | | |
| MNasNet1_0 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| MobileNetV2 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| MobileNetV3 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| Unet | :broken_heart: (Script) | | | |
| Resnet18 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| Resnet50 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| Resnet101 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| Resnext50_32x4d | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| ShuffleNet_v2 | :broken_heart: (Script) | | | |
| SqueezeNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| EfficientNet | :green_heart: (Script) | | | |
| Regnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| Resnest | :broken_heart: (Script) | | | |
| Vision Transformer | :green_heart: (Script) | | | |
| VGG 16 | :green_heart: (Script) | :green_heart: | :green_heart: | |
| Wide Resnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
| RAFT | :broken_heart: (JIT) | | | |
For more information refer to [MODEL TRACKING SHEET](https://docs.google.com/spreadsheets/d/15PcjKeHZIrB5LfDyuw7DGEEE8XnQEX2aX8lm8qbxV8A/edit#gid=0)
### PyTorch Training Models
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :broken_heart: | :broken_heart: | | |
| FullyConnected | :green_heart: | :green_heart: | | |
</details>
<details>
<summary>JAX Models</summary>
### JAX Models
| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| DALL-E | :broken_heart: | :broken_heart: | | |
| FullyConnected | :green_heart: | :green_heart: | | |
</details>
<details>
<summary>TFLite Models</summary>
### TFLite Models
| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :broken_heart: | :broken_heart: | | |
| FullyConnected | :green_heart: | :green_heart: | | |
| albert | :green_heart: | :green_heart: | | |
| asr_conformer | :green_heart: | :green_heart: | | |
| bird_classifier | :green_heart: | :green_heart: | | |
| cartoon_gan | :green_heart: | :green_heart: | | |
| craft_text | :green_heart: | :green_heart: | | |
| deeplab_v3 | :green_heart: | :green_heart: | | |
| densenet | :green_heart: | :green_heart: | | |
| east_text_detector | :green_heart: | :green_heart: | | |
| efficientnet_lite0_int8 | :green_heart: | :green_heart: | | |
| efficientnet | :green_heart: | :green_heart: | | |
| gpt2 | :green_heart: | :green_heart: | | |
| image_stylization | :green_heart: | :green_heart: | | |
| inception_v4 | :green_heart: | :green_heart: | | |
| inception_v4_uint8 | :green_heart: | :green_heart: | | |
| lightning_fp16 | :green_heart: | :green_heart: | | |
| lightning_i8 | :green_heart: | :green_heart: | | |
| lightning | :green_heart: | :green_heart: | | |
| magenta | :green_heart: | :green_heart: | | |
| midas | :green_heart: | :green_heart: | | |
| mirnet | :green_heart: | :green_heart: | | |
| mnasnet | :green_heart: | :green_heart: | | |
| mobilebert_edgetpu_s_float | :green_heart: | :green_heart: | | |
| mobilebert_edgetpu_s_quant | :green_heart: | :green_heart: | | |
| mobilebert | :green_heart: | :green_heart: | | |
| mobilebert_tf2_float | :green_heart: | :green_heart: | | |
| mobilebert_tf2_quant | :green_heart: | :green_heart: | | |
| mobilenet_ssd_quant | :green_heart: | :green_heart: | | |
| mobilenet_v1 | :green_heart: | :green_heart: | | |
| mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
| mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
| mobilenet_v2 | :green_heart: | :green_heart: | | |
| mobilenet_v2_uint8 | :green_heart: | :green_heart: | | |
| mobilenet_v3-large | :green_heart: | :green_heart: | | |
| mobilenet_v3-large_uint8 | :green_heart: | :green_heart: | | |
| mobilenet_v35-int8 | :green_heart: | :green_heart: | | |
| nasnet | :green_heart: | :green_heart: | | |
| person_detect | :green_heart: | :green_heart: | | |
| posenet | :green_heart: | :green_heart: | | |
| resnet_50_int8 | :green_heart: | :green_heart: | | |
| rosetta | :green_heart: | :green_heart: | | |
| spice | :green_heart: | :green_heart: | | |
| squeezenet | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v1 | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v2_fpnlite | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v2_fpnlite_uint8 | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
| ssd_mobilenet_v2 | :green_heart: | :green_heart: | | |
| ssd_spaghettinet_large | :green_heart: | :green_heart: | | |
| ssd_spaghettinet_large_uint8 | :green_heart: | :green_heart: | | |
| visual_wake_words_i8 | :green_heart: | :green_heart: | | |
</details>
<details>
<summary>TF Models</summary>
### Tensorflow Models (Inference)
| Hugging Face Models | tf-mhlo lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| albert-base-v2 | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| DistilBERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| CamemBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| ConvBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| Deberta | | | | |
| electra | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| funnel | | | | |
| layoutlm | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| longformer | | | | |
| mobile-bert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| remembert | | | | |
| tapas | | | | |
| flaubert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| xlm-roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
| mpnet | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
</details>
## Related Projects

View File

View File

@@ -1,44 +0,0 @@
Compile / Run Instructions:
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=mhlo for tf models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
```
Run / Benchmark Command (FP32 - NCHW):
(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.)
```shell
## Vulkan AMD:
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --device=vulkan --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
## CUDA:
iree-benchmark-module --module_file=/path/to/vmfb --entry_function=forward --device=cuda --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
## CPU:
iree-benchmark-module --module_file=/path/to/vmfb --entry_function=forward --device=local-task --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
```
Run via vulkan_gui for RGP Profiling:
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
```shell
./build/vulkan_gui/iree-vulkan-gui --module_file=/path/to/unet.vmfb --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
```

View File

@@ -1 +0,0 @@
from apps.stable_diffusion.scripts.txt2img import txt2img_inf

View File

@@ -1,240 +0,0 @@
import logging
import os
from models.stable_diffusion.main import stable_diff_inf
from models.stable_diffusion.utils import get_available_devices
from dotenv import load_dotenv
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram import BotCommand
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
from io import BytesIO
import random
log = logging.getLogger("TG.Bot")
logging.basicConfig()
log.warning("Start")
load_dotenv()
os.environ["AMD_ENABLE_LLPC"] = "0"
TG_TOKEN = os.getenv("TG_TOKEN")
SELECTED_MODEL = "stablediffusion"
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
STEPS = 30
NEGATIVE_PROMPT = (
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
" legs, fused fingers, too many fingers"
)
GUIDANCE_SCALE = 6
available_devices = get_available_devices()
models_list = [
"stablediffusion",
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]
sheds_list = [
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
]
def image_to_bytes(image):
bio = BytesIO()
bio.name = "image.jpeg"
image.save(bio, "JPEG")
bio.seek(0)
return bio
def get_try_again_markup():
keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]]
reply_markup = InlineKeyboardMarkup(keyboard)
return reply_markup
def generate_image(prompt):
seed = random.randint(1, 10000)
log.warning(SELECTED_MODEL)
log.warning(STEPS)
image, text = stable_diff_inf(
prompt=prompt,
negative_prompt=NEGATIVE_PROMPT,
steps=STEPS,
guidance_scale=GUIDANCE_SCALE,
seed=seed,
scheduler_key=SELECTED_SCHEDULER,
variant=SELECTED_MODEL,
device_key=available_devices[0],
)
return image, seed
async def generate_and_send_photo(
update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
progress_msg = await update.message.reply_text(
"Generating image...", reply_to_message_id=update.message.message_id
)
im, seed = generate_image(prompt=update.message.text)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{update.message.text}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=update.message.message_id,
)
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
query = update.callback_query
if query.data in models_list:
global SELECTED_MODEL
SELECTED_MODEL = query.data
await query.answer()
await query.edit_message_text(text=f"Selected model: {query.data}")
return
if query.data in sheds_list:
global SELECTED_SCHEDULER
SELECTED_SCHEDULER = query.data
await query.answer()
await query.edit_message_text(text=f"Selected scheduler: {query.data}")
return
replied_message = query.message.reply_to_message
await query.answer()
progress_msg = await query.message.reply_text(
"Generating image...", reply_to_message_id=replied_message.message_id
)
if query.data == "TRYAGAIN":
prompt = replied_message.text
im, seed = generate_image(prompt)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{prompt}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=replied_message.message_id,
)
async def select_model_handler(update, context):
text = "Select model"
keyboard = []
for model in models_list:
keyboard.append(
[
InlineKeyboardButton(text=model, callback_data=model),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def select_scheduler_handler(update, context):
text = "Select schedule"
keyboard = []
for shed in sheds_list:
keyboard.append(
[
InlineKeyboardButton(text=shed, callback_data=shed),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def set_steps_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_steps ")[1]
global STEPS
STEPS = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_steps 30"
)
await update.message.reply_text(input_args)
async def set_negative_prompt_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_negative_prompt ")[1]
global NEGATIVE_PROMPT
NEGATIVE_PROMPT = input_args
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_negative_prompt ugly, bad art, mutated"
)
await update.message.reply_text(input_args)
async def set_guidance_scale_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_guidance_scale ")[1]
global GUIDANCE_SCALE
GUIDANCE_SCALE = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_guidance_scale 7"
)
await update.message.reply_text(input_args)
async def setup_bot_commands(application: Application) -> None:
await application.bot.set_my_commands(
[
BotCommand("select_model", "to select model"),
BotCommand("select_scheduler", "to select scheduler"),
BotCommand("set_steps", "to set steps"),
BotCommand("set_guidance_scale", "to set guidance scale"),
BotCommand("set_negative_prompt", "to set negative prompt"),
]
)
app = (
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
)
app.add_handler(CommandHandler("select_model", select_model_handler))
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
app.add_handler(CommandHandler("set_steps", set_steps_handler))
app.add_handler(
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
)
app.add_handler(
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
)
app.add_handler(
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
)
app.add_handler(CallbackQueryHandler(button))
log.warning("Start bot")
app.run_polling()

View File

@@ -1,284 +0,0 @@
import os
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
import sys
import json
import torch
import re
import time
from pathlib import Path
from PIL import PngImagePlugin
from datetime import datetime as dt
from dataclasses import dataclass
from csv import DictWriter
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
)
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
# This has to come before importing cache objects
if args.clear_all:
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
if os.path.exists(yaml):
os.remove(yaml)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
# save output images and the inputs correspoding to it.
def save_output_img(output_img):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(output_path, "generated_imgs")
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {args.seed}, Size: {args.width}x{args.height}, Model: {args.hf_model_id}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
)
new_entry = {
"VARIANT": args.hf_model_id,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SEED": args.seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
}
with open(csv_path, "a") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
dictwriter_obj.writerow(new_entry)
csv_obj.close()
if args.save_metadata_to_json:
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
txt2img_obj = None
config_obj = None
schedulers = None
# Exposed to UI.
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_size: int,
scheduler: str,
model_id: str,
custom_model_id: str,
ckpt_loc: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
global txt2img_obj
global config_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
args.hf_model_id = custom_model_id if custom_model_id else model_id
args.ckpt_loc = "" if ckpt_loc == "None" else ckpt_loc
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
)
if config_obj != new_config_obj:
config_obj = new_config_obj
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.use_tuned = True
args.import_mlir = False
set_init_device_flags()
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
)
if not txt2img_obj:
sys.exit("text to image pipeline must not return a null value")
txt2img_obj.scheduler = schedulers[scheduler]
start_time = time.time()
txt2img_obj.log = ""
generated_imgs = txt2img_obj.generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
save_output_img(generated_imgs[0])
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
return generated_imgs, text_output
if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
)
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
args.seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0])
print(text_output)

View File

@@ -1,78 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
( 'web/logos/*', 'logos' )
]
binaries = []
block_cipher = None
a = Analysis(
['web/index.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -1,77 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
]
binaries = []
block_cipher = None
a = Analysis(
['scripts/txt2img.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -1,8 +0,0 @@
from apps.stable_diffusion.src.utils import (
args,
set_init_device_flags,
prompt_examples,
get_available_devices,
)
from apps.stable_diffusion.src.pipelines import Text2ImagePipeline
from apps.stable_diffusion.src.schedulers import get_schedulers

View File

@@ -1,11 +0,0 @@
from apps.stable_diffusion.src.models.model_wrappers import (
SharkifyStableDiffusionModel,
)
from apps.stable_diffusion.src.models.opt_params import (
get_vae,
get_unet,
get_clip,
get_tokenizer,
get_params,
get_variant_version,
)

View File

@@ -1,253 +0,0 @@
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from collections import defaultdict
import torch
import traceback
import re
import os, sys, functools, operator
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
base_models,
args,
get_vmfb_path_name,
)
# These shapes are parameter dependent.
def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape = []
for i in range(len(shape)):
if shape[i] == "max_len":
new_shape.append(max_len)
elif shape[i] == "height":
new_shape.append(height)
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
if "batch_size" in shape[i]:
mul_val = int(shape[i].split("*")[0])
new_shape.append(batch_size * mul_val)
else:
new_shape.append(shape[i])
return new_shape
# Get the input info for various models i.e. "unet", "clip", "vae".
def get_input_info(model_info, max_len, width, height, batch_size):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
for inp in model_info[k]:
shape = model_info[k][inp]["shape"]
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, max_len, width, height, batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.model_id = model_id if custom_weights == "" else custom_weights
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
self.use_tuned = use_tuned
# We need a better naming convention for the .vmfbs because despite
# using the custom model variant the .vmfb names remain the same and
# it'll always pick up the compiled .vmfb instead of compiling the
# custom model.
# So, currently, we add `self.model_id` in the `self.model_name` of
# .vmfb file.
# TODO: Have a better way of naming the vmfbs using self.model_name.
import re
model_name = re.sub(r"\W+", "_", self.model_id)
if model_name[0] == "_":
model_name = model_name[1:]
self.model_name = self.model_name + "_" + model_name
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 384):
sys.exit("width should be greater than 384 and multiple of 8")
if not (height % 8 == 0 and height >= 384):
sys.exit("height should be greater than 384 and multiple of 8")
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
)
self.base_vae = base_vae
def forward(self, input):
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0
return x.round()
vae = VaeModel()
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
vae_name = "base_vae" if self.base_vae else "vae"
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=vae_name + self.model_name,
extra_args=get_opt_flags("vae", precision=self.precision),
)
return shark_vae
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(
self, latent, timestep, text_embedding, guidance_scale
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel()
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet = compile_through_fx(
unet,
inputs,
model_name="unet" + self.model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_unet
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
)
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name="clip" + self.model_name,
extra_args=get_opt_flags("clip", precision="fp32"),
)
return shark_clip
def __call__(self):
model_name = ["clip", "base_vae" if self.base_vae else "vae", "unet"]
vmfb_path = [
get_vmfb_path_name(model + self.model_name)[0]
for model in model_name
]
for model_id in base_models:
self.inputs = get_input_info(
base_models[model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
try:
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = functools.reduce(
operator.__and__, vmfb_present
)
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(len(vmfb_path)):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
print("Retrying with a different base model configuration")
continue
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please use `enable_stack_trace` and create an issue at https://github.com/nod-ai/SHARK/issues"
)

View File

@@ -1,117 +0,0 @@
import sys
from transformers import CLIPTokenizer
from apps.stable_diffusion.src.utils import models_db, args, get_shark_model
hf_model_variant_map = {
"Linaqruf/anything-v3.0": ["anythingv3", "v2_1base"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"],
"prompthero/openjourney": ["openjourney", "v2_1base"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
}
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
def get_params(bucket_key, model_key, model, is_tuned, precision):
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
try:
bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
iree_flags += models_db[2][model][is_tuned][precision][
"default_compilation_flags"
]
except KeyError:
raise Exception(
f"{bucket_key}/{model_key} is not present in the models database"
)
if (
"specified_compilation_flags"
in models_db[2][model][is_tuned][precision]
):
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
if (
device
not in models_db[2][model][is_tuned][precision][
"specified_compilation_flags"
]
):
device = "default_device"
iree_flags += models_db[2][model][is_tuned][precision][
"specified_compilation_flags"
][device]
return bucket, model_name, iree_flags
def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "unet", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
is_base = "/base" if args.use_base_vae else ""
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
variant, version = get_variant_version(args.hf_model_id)
bucket_key = f"{variant}/untuned"
model_key = (
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
)
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "clip", "untuned", "fp32"
)
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
)
return tokenizer

View File

@@ -1,3 +0,0 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)

View File

@@ -1,134 +0,0 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
return all_imgs

View File

@@ -1,208 +0,0 @@
import torch
from transformers import CLIPTokenizer
from PIL import Image
from tqdm.auto import tqdm
import time
from typing import Union
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
get_clip,
get_unet,
get_tokenizer,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
preprocessCKPT,
)
class StableDiffusionPipeline:
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
text_input = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# Get unconditional embeddings as well
uncond_input = self.tokenizer(
neg_prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
if use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
if use_base_vae:
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def produce_img_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = self.scheduler.scale_model_input(latents, t)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
@classmethod
def from_pretrained(
cls,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
import_mlir: bool,
model_id: str,
ckpt_loc: str,
precision: str,
max_length: int,
batch_size: int,
height: int,
width: int,
use_base_vae: bool,
):
if import_mlir:
if ckpt_loc != "":
assert ckpt_loc.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
ckpt_loc = preprocessCKPT()
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
return cls(
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
)

View File

@@ -1,4 +0,0 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)

View File

@@ -1,51 +0,0 @@
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
def get_schedulers(model_id):
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerDiscrete"
] = SharkEulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers

View File

@@ -1,143 +0,0 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_shark_model,
args,
)
import torch
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
):
super().__init__(
num_train_timesteps,
beta_start,
beta_end,
beta_schedule,
trained_betas,
prediction_type,
)
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
}
example_latent = model_input["euler"]["latent"]
example_output = model_input["euler"]["output"]
if args.precision == "fp16":
example_latent = example_latent.half()
example_output = example_output.half()
example_sigma = model_input["euler"]["sigma"]
example_dt = model_input["euler"]["dt"]
class ScalingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma, latent, dt):
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
return latent + derivative * dt
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.import_mlir:
scaling_model = ScalingModel()
self.scaling_model = compile_through_fx(
scaling_model,
(example_latent, example_sigma),
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
self.step_model = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
else:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
return self.scaling_model(
"forward",
(
sample,
sigma,
),
send_to_host=False,
)
def step(self, noise_pred, timestep, latent):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
dt = self.sigmas[step_index + 1] - sigma
return self.step_model(
"forward",
(
noise_pred,
sigma,
latent,
dt,
),
send_to_host=False,
)

View File

@@ -1,24 +0,0 @@
from apps.stable_diffusion.src.utils.profiler import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.utils.resources import (
prompt_examples,
models_db,
base_models,
opt_flags,
resource_path,
)
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.utils import (
get_vmfb_path_name,
get_shark_model,
compile_through_fx,
set_iree_runtime_flags,
map_device_to_name_path,
set_init_device_flags,
get_available_devices,
get_opt_flags,
preprocessCKPT,
)

View File

@@ -1,18 +0,0 @@
from apps.stable_diffusion.src.utils.stable_args import args
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")
vulkan_device = iree.runtime.get_device(args.device)
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
return vulkan_device
return None
def end_profiling(device):
if device:
return device.end_profiling()

View File

@@ -1,37 +0,0 @@
import os
import json
import sys
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def get_json_file(path):
json_var = []
loc_json = resource_path(path)
if os.path.exists(loc_json):
with open(loc_json, encoding="utf-8") as fopen:
json_var = json.load(fopen)
if not json_var:
print(f"Unable to fetch {path}")
return json_var
# TODO: This shouldn't be called from here, every time the file imports
# it will run all the global vars.
prompt_examples = get_json_file("resources/prompts.json")
models_db = get_json_file("resources/model_db.json")
# The base_model contains the input configuration for the different
# models and also helps in providing information for the variants.
base_models = get_json_file("resources/base_model.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")

View File

@@ -1,98 +0,0 @@
{
"stabilityai/stable-diffusion-2-1": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"CompVis/stable-diffusion-v1-4": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
}
}

View File

@@ -1,21 +0,0 @@
[
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]

View File

@@ -1,177 +0,0 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/latest",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
},
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32"
],
"specified_compilation_flags": {
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
}
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
}
}
}
]

View File

@@ -1,101 +0,0 @@
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32"
],
"specified_compilation_flags": {
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
}
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": ["--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"]
}
},
"fp32": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
}
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
}
}
}

View File

@@ -1,8 +0,0 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]

View File

@@ -1,206 +0,0 @@
import os
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
download_model,
download_public_file,
WORKDIR,
)
from shark.parser import shark_args
from apps.stable_diffusion.src.utils.stable_args import args
def get_device():
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
return device
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
from apps.stable_diffusion.src.models import (
get_params,
get_variant_version,
)
version, variant = get_variant_version(args.hf_model_id)
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{variant}/untuned"
if args.annotation_model == "unet":
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
is_base = "/base" if args.use_base_vae else ""
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
)
return mlir_model, model_name
# Download the tuned config files from shark_tank
def load_winograd_configs():
device = get_device()
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
winograd_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
def load_lower_configs():
from apps.stable_diffusion.src.models import get_variant_version
version, variant = get_variant_version(args.hf_model_id)
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_version = version
if variant in ["anythingv3", "analogdiffusion"]:
args.max_length = 77
config_version = "v1_4"
if args.annotation_model == "vae":
args.max_length = 77
device = get_device()
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
full_gs_url = config_bucket + config_name
lowering_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading lowering config file from ", lowering_config_dir)
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with create_context() as ctx:
winograd_model = model_annotation(
ctx,
input_contents=input_mlir,
config_path=winograd_config_dir,
search_op="conv",
winograd=True,
)
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return winograd_model, out_file_path
# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
if use_winograd:
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
else:
dump_after = "iree-flow-pad-linalg-ops"
# Dump IR after padding/img2col/winograd passes
device_spec_args = ""
device = get_device()
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
gpu_flags = get_iree_gpu_args()
for flag in gpu_flags:
device_spec_args += flag + " "
elif device == "vulkan":
device_spec_args = (
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
print("Applying tuned configs on", model_name)
run_cmd(
f"iree-compile {input_mlir} "
"--iree-input-type=tm_tensor "
f"--iree-hal-target-backends={iree_target_map(device)} "
f"{device_spec_args}"
"--iree-stream-resource-index-bits=64 "
"--iree-vm-target-index-bits=64 "
"--iree-flow-enable-padding-linalg-ops "
"--iree-flow-linalg-ops-padding-size=32 "
"--iree-flow-enable-conv-img2col-transform "
f"--mlir-print-ir-after={dump_after} "
"--compile-to=flow "
f"2>{args.annotation_output}/dump_after_winograd.mlir "
)
# Annotate the model with lowering configs in the config file
with create_context() as ctx:
tuned_model = model_annotation(
ctx,
input_contents=f"{args.annotation_output}/dump_after_winograd.mlir",
config_path=lowering_config_dir,
search_op="all",
)
# Remove the intermediate mlir and save the final annotated model
os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return tuned_model, out_file_path
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model, model_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
model_path, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model, output_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
use_winograd = False
if model_from_tank:
mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
else:
# Just use this function to convert bytecode to string
orig_model, model_path = annotate_with_winograd(
mlir_model, "", model_name
)
mlir_model = model_path
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
print(f"Saved the annotated mlir in {output_path}.")
return tuned_model, output_path
if __name__ == "__main__":
mlir_model, model_name = load_model_from_tank()
sd_model_annotation(mlir_model, model_name, model_from_tank=True)

View File

@@ -1,337 +0,0 @@
import argparse
from pathlib import Path
def path_expand(s):
return Path(s).expanduser().resolve()
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-p",
"--prompts",
action="append",
default=[],
help="text of which images to be generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=[""],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
default=42,
help="the seed to use.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `run`.",
)
p.add_argument(
"--height",
type=int,
default=512,
help="the height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
help="the width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
)
##############################################################################
### Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
)
p.add_argument(
"--runs",
type=int,
default=1,
help="number of images to be generated with random seeds in single execution",
)
p.add_argument(
"--ckpt_loc",
type=str,
default="",
help="Path to SD's .ckpt file.",
)
p.add_argument(
"--hf_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--enable_stack_trace",
default=False,
action=argparse.BooleanOptionalAction,
help="Enable showing the stack trace when retrying the base model configuration",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree-vulkan-target-triple",
type=str,
default="",
help="Specify target triple for vulkan",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="4147483648",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
)
##############################################################################
### Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save a generation information json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
)
##############################################################################
### Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the pregress bar animation during image generation",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--use_winograd",
default=False,
action=argparse.BooleanOptionalAction,
help="Apply Winograd on selected conv ops.",
)
args, unknown = p.parse_known_args()

View File

@@ -1,395 +0,0 @@
import os
import gc
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
def get_vmfb_path_name(model_name):
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
return [vmfb_path, extended_name]
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
[vmfb_path, extended_name] = get_vmfb_path_name(model_name)
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), extended_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_model
from shark.parser import shark_args
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
)
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into a shark_module.
def compile_through_fx(
model,
inputs,
model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
extra_args=[],
):
from shark.parser import shark_args
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
)
if use_tuned:
model_name = model_name + "_tuned"
tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
if not os.path.exists(tuned_model_path):
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
tuned_model, tuned_model_path = sd_model_annotation(
mlir_module, model_name
)
del mlir_module, tuned_model
gc.collect()
with open(tuned_model_path, "rb") as f:
mlir_module = f.read()
f.close()
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def set_init_device_flags():
if "vulkan" in args.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "cpu" in args.device:
args.device = "cpu"
# set max_length based on availability.
if args.hf_model_id in [
"Linaqruf/anything-v3.0",
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]:
args.max_length = 77
elif args.hf_model_id == "prompthero/openjourney":
args.max_length = 64
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
if (
args.hf_model_id
in ["prompthero/openjourney", "dreamlike-art/dreamlike-diffusion-1.0"]
or args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif (
"vulkan" in args.device
and "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in [
"sm_80",
"sm_84",
"sm_86",
"sm_89",
]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {args.hf_model_id}/fp16/{args.device}.")
else:
print("Tuned models are currently not supported for this setting.")
# set import_mlir to True for unuploaded models.
if args.ckpt_loc != "":
args.import_mlir = True
elif args.hf_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.import_mlir = True
elif args.height != 512 or args.width != 512 or args.batch_size != 1:
args.import_mlir = True
# Utility to get list of devices available.
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{device['name']} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
return available_devices
def disk_space_check(path, lim=20):
from shutil import disk_usage
du = disk_usage(path)
free = du.free / (1024 * 1024 * 1024)
if free <= lim:
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
def get_opt_flags(model, precision="fp16"):
iree_flags = []
is_tuned = "tuned" if args.use_tuned else "untuned"
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
"default_compilation_flags"
]
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
if (
device
not in opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
]
):
device = "default_device"
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
return iree_flags
def preprocessCKPT():
path = Path(args.ckpt_loc)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
print(
"Created directory : ",
diffusers_directory_name,
" at -> ",
diffusers_path,
)
path_to_diffusers = complete_path_to_diffusers.as_posix()
from_safetensors = (
True if args.ckpt_loc.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
extract_ema = False
print("Loading pipeline from original stable diffusion checkpoint")
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=args.ckpt_loc,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
)
pipe.save_pretrained(path_to_diffusers)
print("Loading complete")
print("Custom model path is : ", path_to_diffusers)
return path_to_diffusers

View File

@@ -1,154 +0,0 @@
# Stable Diffusion optimized for AMD RDNA2/RDNA3 GPUs
Before you start, please be aware that this is beta software that relies on a special AMD driver. Like all StableDiffusion GUIs published so far, you need some technical expertise to set it up. We apologize in advance if you bump into issues. If that happens, please don't hesitate to ask our Discord community for help! If you still can't get it to work, we're sorry, and please be assured that we (Nod and AMD) are working hard to improve the user experience in coming months.
If it works well for you, please "star" the following GitHub projects... this is one of the best ways to help and spread the word!
* https://github.com/nod-ai/SHARK
* https://github.com/iree-org/iree
## Install this specific AMD Drivers (AMD latest may not have all the fixes).
### AMD KB Drivers for RDNA2 and RDNA3:
*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*
First, for RDNA2 users, download this special driver in a folder of your choice. We recommend you keep the installation files around, since you may need to re-install it later, if Windows Update decides to overwrite it:
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree
For RDNA3, the latest driver 23.1.2 supports MLIR/IREE as well: https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-1-2-kb
KNOWN ISSUES with this special AMD driver:
* `Windows Update` may (depending how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work, then a few days later, it slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix this problem, please check the installed driver version, and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
* Some people using this special driver experience mouse pointer accuracy issues, especially if using a larger-than-default mouse pointer. The clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".
## Installation
Download the latest Windows SHARK SD binary [469 here](https://github.com/nod-ai/SHARK/releases/download/20230124.469/shark_sd_20230124_469.exe) in a folder of your choice. If you want nighly builds, you can look for them on the GitHub releases page.
Notes:
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR which can be outdated if you run a new EXE from the same folder. You can use `--clean_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you:
* clear all the local artifacts with `--clear_all` OR
* clear the Vulkan shader cache: For Windows users this can be done by clearing the contents of `C:\Users\%username%\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
* clear the `huggingface` cache. In Windows, this is `C:\Users\%username%\.cache\huggingface`.
## Running
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE to start the web browser)
* The first run may take about 10-15 minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* If successful, you will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/?__theme=dark.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment. The application should stop.
* Please make sure to do the above step before you attempt to update the EXE to a new version.
# Results
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
Here are some samples generated:
![tajmahal, snow, sunflowers, oil on canvas_0](https://user-images.githubusercontent.com/74956/204934186-141f7e43-6eb2-4e89-a99c-4704d20444b3.jpg)
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
<details>
<summary>Advanced Installation </summary>
## Setup your Python Virtual Environment and Dependencies
<details>
<summary> Windows 10/11 Users </summary>
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
* Install Git for Windows from [here](https://git-scm.com/download/win)
#### Allow the install script to run in Powershell
```powershell
set-executionpolicy remotesigned
```
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
```powershell
git clone https://github.com/nod-ai/SHARK.git
cd SHARK
./setup_venv.ps1 #You can re-run this script to get the latest version
```
</details>
<details>
<summary>Linux</summary>
```shell
git clone https://github.com/nod-ai/SHARK.git
cd SHARK
./setup_venv.sh
source shark.venv/bin/activate
```
</details>
### Run Stable Diffusion on your device - WebUI
<details>
<summary>Windows 10/11 Users</summary>
```powershell
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
```
</details>
<details>
<summary>Linux Users</summary>
```shell
(shark.venv) > cd apps/stable_diffusion/web
(shark.venv) > python index.py
```
</details>
### Run Stable Diffusion on your device - Commandline
<details>
<summary>Windows 10/11 Users</summary>
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
</details>
<details>
<summary>Linux</summary>
```shell
python3.10 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
</details>
The output on a 7900XTX would like:
```shell
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
```
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
</details>
<details>
<summary>Discord link</summary>
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
</details>

View File

@@ -1,15 +0,0 @@
You need to pre-create your bot (https://core.telegram.org/bots#how-do-i-create-a-bot)
Then create in the directory web file .env
In it the record:
TG_TOKEN="your_token"
specifying your bot's token from previous step.
Then run telegram_bot.py with the same parameters that you use when running index.py, for example:
python telegram_bot.py --max_length=77 --vulkan_large_heap_block_size=0 --use_base_vae --local_tank_cache h:\shark\TEMP
Bot commands:
/select_model
/select_scheduler
/set_steps "integer number of steps"
/set_guidance_scale "integer number"
/set_negative_prompt "negative text"
Any other text triggers the creation of an image based on it.

View File

@@ -1,67 +0,0 @@
.gradio-container {
background-color: black
}
.container {
background-color: black !important;
padding-top: 20px !important;
}
#ui_title {
padding: 10px !important;
}
#top_logo {
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#demo_title {
background-color: black;
border-radius: 0 !important;
border: 0;
padding-top: 50px;
padding-bottom: 0px;
width: 460px !important;
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea {
background-color: #1d1d1d !important
}
#prompt_examples {
margin: 0 !important
}
#prompt_examples svg {
display: none !important;
}
.gr-sample-textbox {
border-radius: 1rem !important;
border-color: rgb(31, 41, 55) !important;
border-width: 2px !important;
}
#ui_body {
background-color: #111111 !important;
padding: 10px !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}

View File

@@ -1,272 +0,0 @@
import os
import sys
from pathlib import Path
import glob
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
import gradio as gr
from PIL import Image
from apps.stable_diffusion.src import (
prompt_examples,
args,
get_available_devices,
)
from apps.stable_diffusion.scripts import txt2img_inf
nodlogo_loc = resource_path("logos/nod-logo.png")
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
demo_css = resource_path("css/sd_dark_theme.css")
with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
logo2 = Image.open(sdlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
with gr.Column(scale=5, elem_id="demo_title_outer"):
gr.Image(
value=logo2,
show_label=False,
interactive=False,
elem_id="demo_title",
).style(width=150, height=100)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
model_id = gr.Dropdown(
label="Model ID",
value="stabilityai/stable-diffusion-2-1-base",
choices=[
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
],
)
custom_model_id = gr.Textbox(
placeholder="SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
)
with gr.Group():
ckpt_path = "models"
types = (
"*.ckpt",
"*.safetensors",
) # the tuple of file types
ckpt_files = ["None"]
for extn in types:
files = glob.glob(os.path.join(ckpt_path, extn))
ckpt_files.extend(files)
ckpt_loc = gr.Dropdown(
label="Place all checkpoints at SHARK/apps/stable_diffusion/web/models/",
value="None",
choices=ckpt_files,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value="cyberpunk forest by Salvador Dali",
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="trees, green",
lines=1,
elem_id="prompt_box",
)
with gr.Accordion(label="Advance Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
label="Scheduler",
value="SharkEulerDiscrete",
choices=[
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
],
)
batch_size = gr.Slider(
1, 4, value=1, step=1, label="Number of Images"
)
with gr.Row():
height = gr.Slider(
384, 786, value=512, step=8, label="Height"
)
width = gr.Slider(
384, 786, value=512, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=64,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=50, step=1, label="Steps"
)
guidance_scale = gr.Slider(
0,
50,
value=7.5,
step=0.1,
label="CFG Scale",
)
with gr.Row():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=False,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=False,
interactive=True,
)
with gr.Row():
seed = gr.Number(value=-1, precision=0, label="Seed")
available_devices = get_available_devices()
device = gr.Dropdown(
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => Math.floor(Math.random() * 4294967295)",
)
stable_diffusion = gr.Button("Generate Image")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2], height="auto")
std_output = gr.Textbox(
value="Nothing to show.",
lines=4,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
prompt.submit(
txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_size,
scheduler,
model_id,
custom_model_id,
ckpt_loc,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
stable_diffusion.click(
txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_size,
scheduler,
model_id,
custom_model_id,
ckpt_loc,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
shark_web.queue()
shark_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.0 KiB

View File

@@ -42,7 +42,7 @@ class TFHuggingFaceLanguage(tf.Module):
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
@tf.function(input_signature=tf_bert_input, jit_compile=True)
@tf.function(input_signature=tf_bert_input)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)

View File

@@ -1,45 +0,0 @@
import argparse
from PIL import Image
import numpy as np
import requests
import shutil
import os
import subprocess
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--newfile")
parser.add_argument(
"-g",
"--golden_url",
default="https://storage.googleapis.com/shark_tank/testdata/cyberpunk_fores_42_0_230119_021148.png",
)
def get_image(url, local_filename):
res = requests.get(url, stream=True)
if res.status_code == 200:
with open(local_filename, "wb") as f:
shutil.copyfileobj(res.raw, f)
def compare_images(new_filename, golden_filename):
new = np.array(Image.open(new_filename)) / 255.0
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.01:
subprocess.run(
["gsutil", "cp", new_filename, "gs://shark_tank/testdata/builder/"]
)
raise SystemExit("new and golden not close")
else:
print("SUCCESS")
if __name__ == "__main__":
args = parser.parse_args()
tempfile_name = os.path.join(os.getcwd(), "golden.png")
get_image(args.golden_url, tempfile_name)
compare_images(args.newfile, tempfile_name)

View File

@@ -1,5 +0,0 @@
#!/bin/bash
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python generate_sharktank.py --upload=False --ci_tank_dir=True

View File

@@ -1,37 +0,0 @@
"""Scrapes the github releases API to generate a static pip-install-able releases page.
See https://github.com/llvm/torch-mlir/issues/1374
"""
import argparse
import json
import requests
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("owner", type=str)
parser.add_argument("repo", type=str)
args = parser.parse_args()
# Get releases
response = requests.get(
f"https://api.github.com/repos/{args.owner}/{args.repo}/releases"
)
body = json.loads(response.content)
# Parse releases
releases = []
for row in body:
for asset in row["assets"]:
releases.append((asset["name"], asset["browser_download_url"]))
# Output HTML
html = """<!DOCTYPE html>
<html>
<body>
"""
for name, url in releases:
html += f" <a href='{url}'>{name}</a><br />\n"
html += """ </body>
</html>"""
print(html)

View File

@@ -1,7 +0,0 @@
rm -rf ./test_images
mkdir test_images
python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned
python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned --beta_models=True
python build_tools/image_comparison.py -n ./test_images/*.png
exit $?

View File

@@ -1,77 +0,0 @@
import os
import subprocess
from apps.stable_diffusion.src.utils.resources import (
get_json_file,
)
from shark.shark_downloader import download_public_file
from image_comparison import compare_images
import argparse
from glob import glob
import shutil
model_config_dicts = get_json_file(
os.path.join(
os.getcwd(),
"apps/stable_diffusion/src/utils/resources/model_config.json",
)
)
def test_loop(device="vulkan", beta=False, extra_flags=[]):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
os.mkdir("./test_images")
os.mkdir("./test_images/golden")
hf_model_names = model_config_dicts[0].values()
tuned_options = ["--no-use_tuned"] #'use_tuned']
devices = ["vulkan"]
if beta:
extra_flags.append("--beta_models=True")
for model_name in hf_model_names:
for use_tune in tuned_options:
command = [
"python",
"apps/stable_diffusion/scripts/txt2img.py",
"--device=" + device,
"--output_dir=./test_images/" + model_name,
"--hf_model_id=" + model_name,
use_tune,
]
command += extra_flags
generated_image = not subprocess.call(
command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
if generated_image:
os.makedirs(
"./test_images/golden/" + model_name, exist_ok=True
)
download_public_file(
"gs://shark_tank/testdata/golden/" + model_name,
"./test_images/golden/" + model_name,
)
comparison = [
"python",
"build_tools/image_comparison.py",
"--golden_url=gs://shark_tank/testdata/golden/"
+ model_name
+ "/*.png",
"--newfile=./test_images/" + model_name + "/*.png",
]
test_file = glob("./test_images/" + model_name + "/*.png")[0]
golden_path = "./test_images/golden/" + model_name + "/*.png"
golden_file = glob(golden_path)[0]
compare_images(test_file, golden_file)
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument(
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
)
if __name__ == "__main__":
args = parser.parse_args()
print(args)
test_loop(args.device, args.beta, [])

View File

@@ -1,5 +1,17 @@
def pytest_addoption(parser):
# Attaches SHARK command-line arguments to the pytest machinery.
parser.addoption(
"--save_mlir",
action="store_true",
default="False",
help="Pass option to save input MLIR",
)
parser.addoption(
"--save_vmfb",
action="store_true",
default="False",
help="Pass option to save IREE output .vmfb",
)
parser.addoption(
"--benchmark",
action="store_true",
@@ -7,56 +19,8 @@ def pytest_addoption(parser):
help="Pass option to benchmark and write results.csv",
)
parser.addoption(
"--onnx_bench",
"--save_temps",
action="store_true",
default="False",
help="Add ONNX benchmark results to pytest benchmarks.",
)
parser.addoption(
"--tf32",
action="store_true",
default="False",
help="Use TensorFloat-32 calculations.",
)
parser.addoption(
"--save_repro",
action="store_true",
default="False",
help="Pass option to save reproduction artifacts to SHARK/shark_tmp/test_case/",
)
parser.addoption(
"--save_fails",
action="store_true",
default="False",
help="Save reproduction artifacts for a test case only if it fails. Default is False.",
)
parser.addoption(
"--ci",
action="store_true",
default="False",
help="Enables uploading of reproduction artifacts upon test case failure during iree-compile or validation. Must be passed with --ci_sha option ",
)
parser.addoption(
"--update_tank",
action="store_true",
default="False",
help="Update local shark tank with latest artifacts.",
)
parser.addoption(
"--ci_sha",
action="store",
default="None",
help="Passes the github SHA of the CI workflow to include in google storage directory for reproduction artifacts.",
)
parser.addoption(
"--local_tank_cache",
action="store",
default="",
help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
)
parser.addoption(
"--tank_url",
type=str,
default="gs://shark_tank/latest",
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
help="Saves IREE reproduction artifacts for filing upstream issues.",
)

3
cpp/.gitignore vendored
View File

@@ -1,3 +0,0 @@
*.mlir
*.vmfb
*.ini

View File

@@ -1,52 +0,0 @@
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
cmake_minimum_required(VERSION 3.21...3.23)
#-------------------------------------------------------------------------------
# Project configuration
#-------------------------------------------------------------------------------
project(iree-samples C CXX)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
#-------------------------------------------------------------------------------
# Core project dependency
#-------------------------------------------------------------------------------
message(STATUS "Fetching core IREE repo (this may take a few minutes)...")
# Note: for log output, set -DFETCHCONTENT_QUIET=OFF,
# see https://gitlab.kitware.com/cmake/cmake/-/issues/18238#note_440475
include(FetchContent)
FetchContent_Declare(
iree
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
GIT_TAG shark
GIT_SUBMODULES_RECURSE OFF
GIT_SHALLOW OFF
GIT_PROGRESS ON
USES_TERMINAL_DOWNLOAD ON
)
# Extend module path to find MLIR CMake modules.
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_BINARY_DIR}/lib/cmake/mlir")
# Disable core project features not needed for these out of tree samples.
set(IREE_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(IREE_BUILD_SAMPLES OFF CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(iree)
FetchContent_GetProperties(iree SOURCE_DIR IREE_SOURCE_DIR)
#-------------------------------------------------------------------------------
# Individual samples
#-------------------------------------------------------------------------------
add_subdirectory(vulkan_gui)

View File

@@ -1,82 +0,0 @@
# SHARK C/C++ Samples
These C/C++ samples can be built using CMake. The samples depend on the main
SHARK-Runtime project's C/C++ sources, including both the runtime and the compiler.
Individual samples may require additional dependencies. Watch CMake's output
for information about which you are missing for individual samples.
On Windows we recommend using https://github.com/microsoft/vcpkg to download packages for
your system. The general setup flow looks like
*Install and activate SHARK*
```bash
source shark.venv/bin/activate #follow main repo instructions to setup your venv
```
*Install Dependencies*
```bash
vcpkg install [library] --triplet [your platform]
vcpkg integrate install
# Then pass `-DCMAKE_TOOLCHAIN_FILE=[check logs for path]` when configuring CMake
```
In Ubuntu Linux you can install
```bash
sudo apt install libsdl2-dev
```
*Build*
```bash
cd cpp
cmake -GNinja -B build/
cmake --build build/
```
*Prepare the model*
```bash
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvm-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
```
*Prepare the input*
```bash
python save_img.py
```
Note that this requires tensorflow, e.g.
```bash
python -m pip install tensorflow
```
*Run the vulkan_gui*
```bash
./build/vulkan_gui/iree-samples-resnet-vulkan-gui
```
## Other models
A tool for benchmarking other models is built and can be invoked with a command like the following
```bash
./build/vulkan_gui/iree-vulkan-gui --module-file=path/to/.vmfb --function_input=...
```
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
```bash
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
```
VAE and Autoencoder are also available
```bash
# VAE
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
# CLIP Autoencoder
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
```

Binary file not shown.

Before

Width:  |  Height:  |  Size: 26 KiB

View File

@@ -1,18 +0,0 @@
import numpy as np
import tensorflow as tf
from shark.shark_inference import SharkInference
def load_and_preprocess_image(fname: str):
image = tf.io.read_file(fname)
image = tf.image.decode_image(image, channels=3)
image = tf.image.resize(image, (224, 224))
image = image[tf.newaxis, :]
# preprocessing pipeline
input_tensor = tf.keras.applications.resnet50.preprocess_input(image)
return input_tensor
data = load_and_preprocess_image("dog_imagenet.jpg").numpy()
data.tofile("dog.bin")

View File

@@ -1,84 +0,0 @@
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
if(NOT IREE_TARGET_BACKEND_LLVM_CPU OR
NOT IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF)
message(STATUS "Missing LLVM backend and/or embeddded elf loader, skipping vision_inference sample")
return()
endif()
# vcpkg install stb
# tested with version 2021-09-10
find_package(Stb)
if(NOT Stb_FOUND)
message(STATUS "Could not find Stb, skipping vision inference sample")
return()
endif()
# Compile mnist.mlir to mnist.vmfb.
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
set(_COMPILE_ARGS)
list(APPEND _COMPILE_ARGS "--iree-input-type=mhlo")
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
list(APPEND _COMPILE_ARGS "-o")
list(APPEND _COMPILE_ARGS "mnist.vmfb")
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb
COMMAND ${_COMPILE_TOOL_EXECUTABLE} ${_COMPILE_ARGS}
DEPENDS ${_COMPILE_TOOL_EXECUTABLE} "${IREE_SOURCE_DIR}/samples/models/mnist.mlir"
)
# Embed mnist.vmfb into a C file as mnist_bytecode_module_c.[h/c]
set(_EMBED_DATA_EXECUTABLE $<TARGET_FILE:generate_embed_data>)
set(_EMBED_ARGS)
list(APPEND _EMBED_ARGS "--output_header=mnist_bytecode_module_c.h")
list(APPEND _EMBED_ARGS "--output_impl=mnist_bytecode_module_c.c")
list(APPEND _EMBED_ARGS "--identifier=iree_samples_vision_inference_mnist_bytecode_module")
list(APPEND _EMBED_ARGS "--flatten")
list(APPEND _EMBED_ARGS "${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb")
add_custom_command(
OUTPUT "mnist_bytecode_module_c.h" "mnist_bytecode_module_c.c"
COMMAND ${_EMBED_DATA_EXECUTABLE} ${_EMBED_ARGS}
DEPENDS ${_EMBED_DATA_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb
)
# Define a library target for mnist_bytecode_module_c.
add_library(iree_samples_vision_inference_mnist_bytecode_module_c OBJECT)
target_sources(iree_samples_vision_inference_mnist_bytecode_module_c
PRIVATE
mnist_bytecode_module_c.h
mnist_bytecode_module_c.c
)
# Define the sample executable.
set(_NAME "iree-run-mnist-module")
add_executable(${_NAME} "")
target_sources(${_NAME}
PRIVATE
"image_util.h"
"image_util.c"
"iree-run-mnist-module.c"
)
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "iree-run-mnist-module")
target_include_directories(${_NAME} PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
)
target_include_directories(${_NAME} PRIVATE
${Stb_INCLUDE_DIR}
)
target_link_libraries(${_NAME}
iree_base_base
iree_base_tracing
iree_hal_hal
iree_runtime_runtime
iree_samples_vision_inference_mnist_bytecode_module_c
)
# Define a target that copies the test image into the build directory.
add_custom_target(iree_samples_vision_inference_test_image
COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/mnist_test.png" "${CMAKE_CURRENT_BINARY_DIR}/mnist_test.png")
add_dependencies(${_NAME} iree_samples_vision_inference_test_image)
message(STATUS "Configured vision_inference sample successfully")

View File

@@ -1,8 +0,0 @@
# Vision Inference Sample (C code)
This sample demonstrates how to run a MNIST handwritten digit detection vision
model on an image using IREE's C API.
A similar sample is implemented using a Python script and IREE's command line
tools over in the primary iree repository at
https://github.com/iree-org/iree/tree/main/samples/vision_inference

View File

@@ -1,224 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "image_util.h"
#include <math.h>
#include "iree/base/internal/flags.h"
#include "iree/base/tracing.h"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
const uint8_t* pixel_data, iree_host_size_t buffer_length,
const float* input_range, iree_host_size_t range_length,
float* out_buffer) {
IREE_TRACE_ZONE_BEGIN(z0);
if (range_length != 2) {
IREE_TRACE_ZONE_END(z0);
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"range defined as 2-element [min, max] array.");
}
float input_scale = fabsf(input_range[1] - input_range[0]) / 2.0f;
float input_offset = (input_range[0] + input_range[1]) / 2.0f;
const float kUint8Mean = 127.5f;
for (int i = 0; i < buffer_length; ++i) {
out_buffer[i] =
(((float)(pixel_data[i])) - kUint8Mean) / kUint8Mean * input_scale +
input_offset;
}
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
iree_status_t iree_tools_utils_load_pixel_data_impl(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length) {
int img_dims[3];
if (stbi_info(filename.data, img_dims, &(img_dims[1]), &(img_dims[2])) == 0) {
return iree_make_status(IREE_STATUS_NOT_FOUND, "can't load image %.*s",
(int)filename.size, filename.data);
}
if (!(element_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 ||
element_type == IREE_HAL_ELEMENT_TYPE_SINT_8 ||
element_type == IREE_HAL_ELEMENT_TYPE_UINT_8)) {
char element_type_str[16];
IREE_RETURN_IF_ERROR(iree_hal_format_element_type(
element_type, sizeof(element_type_str), element_type_str, NULL));
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"element type %s not supported", element_type_str);
}
switch (shape_rank) {
case 2: { // Assume tensor <height x width>
if (img_dims[2] != 1 || (shape[0] != img_dims[1]) ||
(shape[1] != img_dims[0])) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"image size: %dx%dx%d, expected: %" PRIdim "x%" PRIdim, img_dims[0],
img_dims[1], img_dims[2], shape[1], shape[0]);
}
break;
}
case 3: { // Assume tensor <height x width x channel>
if (shape[0] != img_dims[1] || shape[1] != img_dims[0] ||
shape[2] != img_dims[2]) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"image size: %dx%dx%d, expected: %" PRIdim
"x%" PRIdim "x%" PRIdim,
img_dims[0], img_dims[1], img_dims[2], shape[1],
shape[0], shape[2]);
}
break;
}
case 4: { // Assume tensor <batch x height x width x channel>
if (shape[1] != img_dims[1] || shape[2] != img_dims[0] ||
shape[3] != img_dims[2]) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"image size: %dx%dx%d, expected: %" PRIdim
"x%" PRIdim "x%" PRIdim,
img_dims[0], img_dims[1], img_dims[2], shape[2],
shape[1], shape[3]);
}
break;
}
default:
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"Input buffer shape rank %" PRIhsz " not supported", shape_rank);
}
// Drop the alpha channel if present.
int req_ch = (img_dims[2] >= 3) ? 3 : 0;
*out_pixel_data = stbi_load(filename.data, img_dims, &(img_dims[1]),
&(img_dims[2]), req_ch);
if (*out_pixel_data == NULL) {
return iree_make_status(IREE_STATUS_NOT_FOUND, "can't load image %.*s",
(int)filename.size, filename.data);
}
*out_buffer_length =
img_dims[0] * img_dims[1] * (img_dims[2] > 3 ? 3 : img_dims[2]);
return iree_ok_status();
}
iree_status_t iree_tools_utils_load_pixel_data(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t result = iree_tools_utils_load_pixel_data_impl(
filename, shape, shape_rank, element_type, out_pixel_data,
out_buffer_length);
IREE_TRACE_ZONE_END(z0);
return result;
}
iree_status_t iree_tools_utils_buffer_view_from_image(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
iree_hal_allocator_t* allocator, iree_hal_buffer_view_t** out_buffer_view) {
IREE_TRACE_ZONE_BEGIN(z0);
*out_buffer_view = NULL;
if (element_type != IREE_HAL_ELEMENT_TYPE_SINT_8 &&
element_type != IREE_HAL_ELEMENT_TYPE_UINT_8) {
IREE_TRACE_ZONE_END(z0);
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"element type should be i8 or u8");
}
iree_status_t result;
uint8_t* pixel_data = NULL;
iree_host_size_t buffer_length;
result = iree_tools_utils_load_pixel_data(
filename, shape, shape_rank, element_type, &pixel_data, &buffer_length);
if (iree_status_is_ok(result)) {
iree_host_size_t element_byte =
iree_hal_element_dense_byte_count(element_type);
// SINT_8 and UINT_8 perform direct buffer wrap.
result = iree_hal_buffer_view_allocate_buffer(
allocator, shape_rank, shape, element_type,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
.access = IREE_HAL_MEMORY_ACCESS_READ,
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
IREE_HAL_BUFFER_USAGE_TRANSFER,
},
iree_make_const_byte_span(pixel_data, element_byte * buffer_length),
out_buffer_view);
}
stbi_image_free(pixel_data);
IREE_TRACE_ZONE_END(z0);
return result;
}
typedef struct iree_tools_utils_buffer_view_load_params_t {
const uint8_t* pixel_data;
iree_host_size_t pixel_data_length;
const float* input_range;
iree_host_size_t input_range_length;
} iree_tools_utils_buffer_view_load_params_t;
static iree_status_t iree_tools_utils_buffer_view_load_image_rescaled(
iree_hal_buffer_mapping_t* mapping, void* user_data) {
iree_tools_utils_buffer_view_load_params_t* params =
(iree_tools_utils_buffer_view_load_params_t*)user_data;
return iree_tools_utils_pixel_rescaled_to_buffer(
params->pixel_data, params->pixel_data_length, params->input_range,
params->input_range_length, (float*)mapping->contents.data);
}
iree_status_t iree_tools_utils_buffer_view_from_image_rescaled(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
iree_hal_allocator_t* allocator, const float* input_range,
iree_host_size_t input_range_length,
iree_hal_buffer_view_t** out_buffer_view) {
IREE_TRACE_ZONE_BEGIN(z0);
*out_buffer_view = NULL;
if (element_type != IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
IREE_TRACE_ZONE_END(z0);
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"element type should be f32");
}
// Classic row-major image layout.
iree_hal_encoding_type_t encoding_type =
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
// Load pixel data from the file into a new host memory allocation (the only
// interface stb_image provides). A real application would want to use the
// generation callback to directly decode the image into the target mapped
// device buffer.
uint8_t* pixel_data = NULL;
iree_host_size_t buffer_length = 0;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_tools_utils_load_pixel_data(filename, shape, shape_rank,
element_type, &pixel_data,
&buffer_length));
iree_tools_utils_buffer_view_load_params_t params = {
.pixel_data = pixel_data,
.pixel_data_length = buffer_length,
.input_range = input_range,
.input_range_length = input_range_length,
};
iree_status_t status = iree_hal_buffer_view_generate_buffer(
allocator, shape_rank, shape, element_type, encoding_type,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
IREE_HAL_BUFFER_USAGE_TRANSFER |
IREE_HAL_BUFFER_USAGE_MAPPING,
},
iree_tools_utils_buffer_view_load_image_rescaled, &params,
out_buffer_view);
stbi_image_free(pixel_data);
IREE_TRACE_ZONE_END(z0);
return status;
}

View File

@@ -1,77 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
#define IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/buffer_view.h"
#if __cplusplus
extern "C" {
#endif // __cplusplus
// Loads the image at |filename| into |out_pixel_data| and sets
// |out_buffer_length| to its length.
//
// The image dimension must match the width, height, and channel in|shape|,
// while 2 <= |shape_rank| <= 4 to match the image tensor format.
//
// The file must be in a format supported by stb_image.h.
// The returned |out_pixel_data| buffer must be released by the caller.
iree_status_t iree_tools_utils_load_pixel_data(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length);
// Parse the content in an image file in |filename| into a HAL buffer view
// |out_buffer_view|. |out_buffer_view| properties are defined by |shape|,
// |shape_rank|, and |element_type|, while being allocated by |allocator|.
//
// The |element_type| has to be SINT_8 or UINT_8. For FLOAT_32, use
// |iree_tools_utils_buffer_view_from_image_rescaled| instead.
//
// The returned |out_buffer_view| must be released by the caller.
iree_status_t iree_tools_utils_buffer_view_from_image(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
iree_hal_allocator_t* allocator, iree_hal_buffer_view_t** out_buffer_view);
// Parse the content in an image file in |filename| into a HAL buffer view
// |out_buffer_view|. |out_buffer_view| properties are defined by |shape|,
// |shape_rank|, and |element_type|, while being allocated by |allocator|.
// The value in |out_buffer_view| is rescaled with |input_range|.
//
// The |element_type| has to be FLOAT_32, For SINT_8 or UINT_8, use
// |iree_tools_utils_buffer_view_from_image| instead.
//
// The returned |out_buffer_view| must be released by the caller.
iree_status_t iree_tools_utils_buffer_view_from_image_rescaled(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
iree_hal_allocator_t* allocator, const float* input_range,
iree_host_size_t input_range_length,
iree_hal_buffer_view_t** out_buffer_view);
// Normalize uint8_t |pixel_data| of the size |buffer_length| to float buffer
// |out_buffer| with the range |input_range|.
//
// float32_x = (uint8_x - 127.5) / 127.5 * input_scale + input_offset, where
// input_scale = abs(|input_range[0]| - |input_range[1]| / 2
// input_offset = |input_range[0]| + |input_range[1]| / 2
//
// |out_buffer| needs to be allocated before the call.
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
const uint8_t* pixel_data, iree_host_size_t pixel_count,
const float* input_range, iree_host_size_t input_range_length,
float* out_buffer);
#if __cplusplus
}
#endif // __cplusplus
#endif // IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_

View File

@@ -1,121 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// This sample uses image_util to load a hand-written image as an
// iree_hal_buffer_view_t then passes it to the bytecode module built from
// mnist.mlir on the CPU backend with the local-task driver.
#include <float.h>
#include "image_util.h"
#include "iree/runtime/api.h"
#include "mnist_bytecode_module_c.h"
iree_status_t Run(const iree_string_view_t image_path) {
iree_runtime_instance_options_t instance_options;
iree_runtime_instance_options_initialize(IREE_API_VERSION_LATEST,
&instance_options);
iree_runtime_instance_options_use_all_available_drivers(&instance_options);
iree_runtime_instance_t* instance = NULL;
IREE_RETURN_IF_ERROR(iree_runtime_instance_create(
&instance_options, iree_allocator_system(), &instance));
// TODO(#5724): move device selection into the compiled modules.
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device(
instance, iree_make_cstring_view("local-task"), &device));
// Create one session per loaded module to hold the module state.
iree_runtime_session_options_t session_options;
iree_runtime_session_options_initialize(&session_options);
iree_runtime_session_t* session = NULL;
IREE_RETURN_IF_ERROR(iree_runtime_session_create_with_device(
instance, &session_options, device,
iree_runtime_instance_host_allocator(instance), &session));
iree_hal_device_release(device);
const struct iree_file_toc_t* module_file =
iree_samples_vision_inference_mnist_bytecode_module_create();
IREE_RETURN_IF_ERROR(iree_runtime_session_append_bytecode_module_from_memory(
session, iree_make_const_byte_span(module_file->data, module_file->size),
iree_allocator_null()));
iree_runtime_call_t call;
IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(
session, iree_make_cstring_view("module.predict"), &call));
// Prepare the input hal buffer view with image_util library.
// The input of the mmist model is single 28x28 pixel image as a
// tensor<1x28x28x1xf32>, with pixels in [0.0, 1.0].
iree_hal_buffer_view_t* buffer_view = NULL;
iree_hal_dim_t buffer_shape[] = {1, 28, 28, 1};
iree_hal_element_type_t hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32;
float input_range[2] = {0.0f, 1.0f};
IREE_RETURN_IF_ERROR(
iree_tools_utils_buffer_view_from_image_rescaled(
image_path, buffer_shape, IREE_ARRAYSIZE(buffer_shape),
hal_element_type, iree_hal_device_allocator(device), input_range,
IREE_ARRAYSIZE(input_range), &buffer_view),
"load image");
IREE_RETURN_IF_ERROR(
iree_runtime_call_inputs_push_back_buffer_view(&call, buffer_view));
iree_hal_buffer_view_release(buffer_view);
IREE_RETURN_IF_ERROR(iree_runtime_call_invoke(&call, /*flags=*/0));
// Get the result buffers from the invocation.
iree_hal_buffer_view_t* ret_buffer_view = NULL;
IREE_RETURN_IF_ERROR(
iree_runtime_call_outputs_pop_front_buffer_view(&call, &ret_buffer_view));
// Read back the results. The output of the mnist model is a 1x10 prediction
// confidence values for each digit in [0, 9].
float predictions[1 * 10] = {0.0f};
IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
iree_hal_buffer_view_buffer(ret_buffer_view), 0, predictions,
sizeof(predictions), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout()));
iree_hal_buffer_view_release(ret_buffer_view);
// Get the highest index from the output.
float result_val = FLT_MIN;
int result_idx = 0;
for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(predictions); ++i) {
if (predictions[i] > result_val) {
result_val = predictions[i];
result_idx = i;
}
}
fprintf(stdout, "Detected number: %d\n", result_idx);
iree_runtime_call_deinitialize(&call);
iree_runtime_session_release(session);
iree_runtime_instance_release(instance);
return iree_ok_status();
}
int main(int argc, char** argv) {
if (argc > 2) {
fprintf(stderr, "Usage: iree-run-mnist-module <image file>\n");
return -1;
}
iree_string_view_t image_path;
if (argc == 1) {
image_path = iree_make_cstring_view("mnist_test.png");
} else {
image_path = iree_make_cstring_view(argv[1]);
}
iree_status_t result = Run(image_path);
if (!iree_status_is_ok(result)) {
iree_status_fprint(stderr, result);
iree_status_ignore(result);
return -1;
}
iree_status_ignore(result);
return 0;
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 261 B

View File

@@ -1,116 +0,0 @@
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
if(NOT IREE_TARGET_BACKEND_VULKAN_SPIRV OR
NOT IREE_HAL_DRIVER_VULKAN)
message(STATUS "Missing Vulkan backend and/or driver, skipping vulkan_gui sample")
return()
endif()
# This target statically links against Vulkan.
# One way to achieve this is by installing the Vulkan SDK from
# https://vulkan.lunarg.com/.
include(FindVulkan)
if(NOT Vulkan_FOUND)
message(STATUS "Could not find Vulkan, skipping vulkan_gui sample")
return()
endif()
# vcpkg install sdl2[vulkan]
# tested with versions 2.0.14#4 - 2.0.22#1
find_package(SDL2)
if(NOT SDL2_FOUND)
message(STATUS "Could not find SDL2, skipping vulkan_gui sample")
return()
endif()
FetchContent_Declare(
imgui
GIT_REPOSITORY https://github.com/ocornut/imgui
GIT_TAG master
)
FetchContent_MakeAvailable(imgui)
# Dear ImGui
set(IMGUI_DIR ${CMAKE_BINARY_DIR}/_deps/imgui-src)
message("Looking for Imgui in ${IMGUI_DIR}")
include_directories(${IMGUI_DIR} ${IMGUI_DIR}/backends ..)
function(iree_vulkan_sample)
cmake_parse_arguments(
_RULE
""
"NAME"
"SRCS"
${ARGN}
)
# Define the sample executable.
set(_NAME "${_RULE_NAME}")
set(SRCS "${_RULE_SRCS}")
add_executable(${_NAME} "")
target_sources(${_NAME}
PRIVATE
${SRCS}
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
"${IMGUI_DIR}/imgui.cpp"
"${IMGUI_DIR}/imgui_draw.cpp"
"${IMGUI_DIR}/imgui_demo.cpp"
"${IMGUI_DIR}/imgui_tables.cpp"
"${IMGUI_DIR}/imgui_widgets.cpp"
)
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_NAME}")
target_include_directories(${_NAME} PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
)
target_link_libraries(${_NAME}
SDL2::SDL2
Vulkan::Vulkan
iree_runtime_runtime
iree_base_internal_main
iree_hal_drivers_vulkan_registration_registration
iree_modules_hal_hal
iree_vm_vm
iree_vm_bytecode_module
iree_vm_cc
iree_tooling_vm_util_cc
iree_tooling_context_util
)
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
else()
set(_GUI_LINKOPTS "")
endif()
target_link_options(${_NAME}
PRIVATE
${_GUI_LINKOPTS}
)
endfunction()
iree_vulkan_sample(
NAME
iree-samples-resnet-vulkan-gui
SRCS
vulkan_resnet_inference_gui.cc
)
iree_vulkan_sample(
NAME
iree-vulkan-gui
SRCS
vulkan_inference_gui.cc
)
message(STATUS "Configured vulkan_gui sample successfully")

View File

@@ -1,4 +0,0 @@
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = "arith.mulf"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

File diff suppressed because it is too large Load Diff

View File

@@ -1,957 +0,0 @@
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Vulkan Graphics + IREE API Integration Sample.
#include <SDL.h>
#include <SDL_vulkan.h>
#include <imgui.h>
#include <imgui_impl_sdl.h>
#include <imgui_impl_vulkan.h>
#include <vulkan/vulkan.h>
#include <cstring>
#include <set>
#include <vector>
#include <fstream>
#include <array>
#include <cstdio>
#include <cstdlib>
#include <iterator>
#include <string>
#include <utility>
#include "iree/hal/drivers/vulkan/api.h"
// IREE's C API:
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/drivers/vulkan/registration/driver_module.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "iree/vm/ref_cc.h"
// iree-run-module
#include "iree/base/internal/flags.h"
#include "iree/base/status_cc.h"
#include "iree/base/tracing.h"
#include "iree/modules/hal/types.h"
#include "iree/tooling/comparison.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/vm_util_cc.h"
// Other dependencies (helpers, etc.)
#include "iree/base/internal/main.h"
#define IMGUI_UNLIMITED_FRAME_RATE
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
IREE_FLAG(string, entry_function, "",
"Name of a function contained in the module specified by module_file "
"to run.");
// TODO(benvanik): move --function_input= flag into a util.
static iree_status_t parse_function_io(iree_string_view_t flag_name,
void* storage,
iree_string_view_t value) {
auto* list = (std::vector<std::string>*)storage;
list->push_back(std::string(value.data, value.size));
return iree_ok_status();
}
static void print_function_io(iree_string_view_t flag_name, void* storage,
FILE* file) {
auto* list = (std::vector<std::string>*)storage;
if (list->empty()) {
fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
} else {
for (size_t i = 0; i < list->size(); ++i) {
fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
list->at(i).c_str());
}
}
}
static std::vector<std::string> FLAG_function_inputs;
IREE_FLAG_CALLBACK(
parse_function_io, print_function_io, &FLAG_function_inputs, function_input,
"An input (a) value or (b) buffer of the format:\n"
" (a) scalar value\n"
" value\n"
" e.g.: --function_input=\"3.14\"\n"
" (b) buffer:\n"
" [shape]xtype=[value]\n"
" e.g.: --function_input=\"2x2xi32=1 2 3 4\"\n"
"Optionally, brackets may be used to separate the element values:\n"
" 2x2xi32=[[1 2][3 4]]\n"
"Raw binary files can be read to provide buffer contents:\n"
" 2x2xi32=@some/file.bin\n"
"numpy npy files (from numpy.save) can be read to provide 1+ values:\n"
" @some.npy\n"
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");
typedef struct iree_file_toc_t {
const char* name; // the file's original name
char* data; // beginning of the file
size_t size; // length of the file
} iree_file_toc_t;
bool load_file(const char* filename, char** pOut, size_t* pSize)
{
FILE* f = fopen(filename, "rb");
if (f == NULL)
{
fprintf(stderr, "Can't open %s\n", filename);
return false;
}
fseek(f, 0L, SEEK_END);
*pSize = ftell(f);
fseek(f, 0L, SEEK_SET);
*pOut = (char*)malloc(*pSize);
size_t size = fread(*pOut, *pSize, 1, f);
fclose(f);
return size != 0;
}
static VkAllocationCallbacks* g_Allocator = NULL;
static VkInstance g_Instance = VK_NULL_HANDLE;
static VkPhysicalDevice g_PhysicalDevice = VK_NULL_HANDLE;
static VkDevice g_Device = VK_NULL_HANDLE;
static uint32_t g_QueueFamily = (uint32_t)-1;
static VkQueue g_Queue = VK_NULL_HANDLE;
static VkPipelineCache g_PipelineCache = VK_NULL_HANDLE;
static VkDescriptorPool g_DescriptorPool = VK_NULL_HANDLE;
static ImGui_ImplVulkanH_Window g_MainWindowData;
static uint32_t g_MinImageCount = 2;
static bool g_SwapChainRebuild = false;
static int g_SwapChainResizeWidth = 0;
static int g_SwapChainResizeHeight = 0;
static void check_vk_result(VkResult err) {
if (err == 0) return;
fprintf(stderr, "VkResult: %d\n", err);
abort();
}
// Returns the names of the Vulkan layers used for the given IREE
// |extensibility_set| and |features|.
std::vector<const char*> GetIreeLayers(
iree_hal_vulkan_extensibility_set_t extensibility_set,
iree_hal_vulkan_features_t features) {
iree_host_size_t required_count;
iree_hal_vulkan_query_extensibility_set(
features, extensibility_set, /*string_capacity=*/0, &required_count,
/*out_string_values=*/NULL);
std::vector<const char*> layers(required_count);
iree_hal_vulkan_query_extensibility_set(features, extensibility_set,
layers.size(), &required_count,
layers.data());
return layers;
}
// Returns the names of the Vulkan extensions used for the given IREE
// |extensibility_set| and |features|.
std::vector<const char*> GetIreeExtensions(
iree_hal_vulkan_extensibility_set_t extensibility_set,
iree_hal_vulkan_features_t features) {
iree_host_size_t required_count;
iree_hal_vulkan_query_extensibility_set(
features, extensibility_set, /*string_capacity=*/0, &required_count,
/*out_string_values=*/NULL);
std::vector<const char*> extensions(required_count);
iree_hal_vulkan_query_extensibility_set(features, extensibility_set,
extensions.size(), &required_count,
extensions.data());
return extensions;
}
// Returns the names of the Vulkan extensions used for the given IREE
// |vulkan_features|.
std::vector<const char*> GetDeviceExtensions(
VkPhysicalDevice physical_device,
iree_hal_vulkan_features_t vulkan_features) {
std::vector<const char*> iree_required_extensions = GetIreeExtensions(
IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED,
vulkan_features);
std::vector<const char*> iree_optional_extensions = GetIreeExtensions(
IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
vulkan_features);
uint32_t extension_count = 0;
check_vk_result(vkEnumerateDeviceExtensionProperties(
physical_device, nullptr, &extension_count, nullptr));
std::vector<VkExtensionProperties> extension_properties(extension_count);
check_vk_result(vkEnumerateDeviceExtensionProperties(
physical_device, nullptr, &extension_count, extension_properties.data()));
// Merge extensions lists, including optional and required for simplicity.
std::set<const char*> ext_set;
ext_set.insert("VK_KHR_swapchain");
ext_set.insert(iree_required_extensions.begin(),
iree_required_extensions.end());
for (int i = 0; i < iree_optional_extensions.size(); ++i) {
const char* optional_extension = iree_optional_extensions[i];
for (int j = 0; j < extension_count; ++j) {
if (strcmp(optional_extension, extension_properties[j].extensionName) ==
0) {
ext_set.insert(optional_extension);
break;
}
}
}
std::vector<const char*> extensions(ext_set.begin(), ext_set.end());
return extensions;
}
std::vector<const char*> GetInstanceLayers(
iree_hal_vulkan_features_t vulkan_features) {
// Query the layers that IREE wants / needs.
std::vector<const char*> required_layers = GetIreeLayers(
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_REQUIRED, vulkan_features);
std::vector<const char*> optional_layers = GetIreeLayers(
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, vulkan_features);
// Query the layers that are available on the Vulkan ICD.
uint32_t layer_property_count = 0;
check_vk_result(
vkEnumerateInstanceLayerProperties(&layer_property_count, NULL));
std::vector<VkLayerProperties> layer_properties(layer_property_count);
check_vk_result(vkEnumerateInstanceLayerProperties(&layer_property_count,
layer_properties.data()));
// Match between optional/required and available layers.
std::vector<const char*> layers;
for (const char* layer_name : required_layers) {
bool found = false;
for (const auto& layer_property : layer_properties) {
if (std::strcmp(layer_name, layer_property.layerName) == 0) {
found = true;
layers.push_back(layer_name);
break;
}
}
if (!found) {
fprintf(stderr, "Required layer %s not available\n", layer_name);
abort();
}
}
for (const char* layer_name : optional_layers) {
for (const auto& layer_property : layer_properties) {
if (std::strcmp(layer_name, layer_property.layerName) == 0) {
layers.push_back(layer_name);
break;
}
}
}
return layers;
}
std::vector<const char*> GetInstanceExtensions(
SDL_Window* window, iree_hal_vulkan_features_t vulkan_features) {
// Ask SDL for its list of required instance extensions.
uint32_t sdl_extensions_count = 0;
SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count, NULL);
std::vector<const char*> sdl_extensions(sdl_extensions_count);
SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count,
sdl_extensions.data());
std::vector<const char*> iree_required_extensions = GetIreeExtensions(
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED,
vulkan_features);
std::vector<const char*> iree_optional_extensions = GetIreeExtensions(
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL,
vulkan_features);
// Merge extensions lists, including optional and required for simplicity.
std::set<const char*> ext_set;
ext_set.insert(sdl_extensions.begin(), sdl_extensions.end());
ext_set.insert(iree_required_extensions.begin(),
iree_required_extensions.end());
ext_set.insert(iree_optional_extensions.begin(),
iree_optional_extensions.end());
std::vector<const char*> extensions(ext_set.begin(), ext_set.end());
return extensions;
}
void SetupVulkan(iree_hal_vulkan_features_t vulkan_features,
const char** instance_layers, uint32_t instance_layers_count,
const char** instance_extensions,
uint32_t instance_extensions_count,
const VkAllocationCallbacks* allocator, VkInstance* instance,
uint32_t* queue_family_index,
VkPhysicalDevice* physical_device, VkQueue* queue,
VkDevice* device, VkDescriptorPool* descriptor_pool) {
VkResult err;
// Create Vulkan Instance
{
VkInstanceCreateInfo create_info = {};
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
create_info.enabledLayerCount = instance_layers_count;
create_info.ppEnabledLayerNames = instance_layers;
create_info.enabledExtensionCount = instance_extensions_count;
create_info.ppEnabledExtensionNames = instance_extensions;
err = vkCreateInstance(&create_info, allocator, instance);
check_vk_result(err);
}
// Select GPU
{
uint32_t gpu_count;
err = vkEnumeratePhysicalDevices(*instance, &gpu_count, NULL);
check_vk_result(err);
IM_ASSERT(gpu_count > 0);
VkPhysicalDevice* gpus =
(VkPhysicalDevice*)malloc(sizeof(VkPhysicalDevice) * gpu_count);
err = vkEnumeratePhysicalDevices(*instance, &gpu_count, gpus);
check_vk_result(err);
// Use the first reported GPU for simplicity.
*physical_device = gpus[0];
VkPhysicalDeviceProperties properties;
vkGetPhysicalDeviceProperties(*physical_device, &properties);
fprintf(stdout, "Selected Vulkan device: '%s'\n", properties.deviceName);
free(gpus);
}
// Select queue family. We want a single queue with graphics and compute for
// simplicity, but we could also discover and use separate queues for each.
{
uint32_t count;
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, NULL);
VkQueueFamilyProperties* queues = (VkQueueFamilyProperties*)malloc(
sizeof(VkQueueFamilyProperties) * count);
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, queues);
for (uint32_t i = 0; i < count; i++) {
if (queues[i].queueFlags &
(VK_QUEUE_GRAPHICS_BIT | VK_QUEUE_COMPUTE_BIT)) {
*queue_family_index = i;
break;
}
}
free(queues);
IM_ASSERT(*queue_family_index != (uint32_t)-1);
}
// Create Logical Device (with 1 queue)
{
std::vector<const char*> device_extensions =
GetDeviceExtensions(*physical_device, vulkan_features);
const float queue_priority[] = {1.0f};
VkDeviceQueueCreateInfo queue_info = {};
queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
queue_info.queueFamilyIndex = *queue_family_index;
queue_info.queueCount = 1;
queue_info.pQueuePriorities = queue_priority;
VkDeviceCreateInfo create_info = {};
create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
create_info.queueCreateInfoCount = 1;
create_info.pQueueCreateInfos = &queue_info;
create_info.enabledExtensionCount =
static_cast<uint32_t>(device_extensions.size());
create_info.ppEnabledExtensionNames = device_extensions.data();
// Enable timeline semaphores.
VkPhysicalDeviceFeatures2 features2;
memset(&features2, 0, sizeof(features2));
features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
create_info.pNext = &features2;
VkPhysicalDeviceTimelineSemaphoreFeatures semaphore_features;
memset(&semaphore_features, 0, sizeof(semaphore_features));
semaphore_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_TIMELINE_SEMAPHORE_FEATURES;
semaphore_features.pNext = features2.pNext;
features2.pNext = &semaphore_features;
semaphore_features.timelineSemaphore = VK_TRUE;
err = vkCreateDevice(*physical_device, &create_info, allocator, device);
check_vk_result(err);
vkGetDeviceQueue(*device, *queue_family_index, 0, queue);
}
// Create Descriptor Pool
{
VkDescriptorPoolSize pool_sizes[] = {
{VK_DESCRIPTOR_TYPE_SAMPLER, 1000},
{VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 1000},
{VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 1000},
{VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1000},
{VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER, 1000},
{VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER, 1000},
{VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1000},
{VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1000},
{VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC, 1000},
{VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC, 1000},
{VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT, 1000}};
VkDescriptorPoolCreateInfo pool_info = {};
pool_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
pool_info.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
pool_info.maxSets = 1000 * IREE_ARRAYSIZE(pool_sizes);
pool_info.poolSizeCount = (uint32_t)IREE_ARRAYSIZE(pool_sizes);
pool_info.pPoolSizes = pool_sizes;
err =
vkCreateDescriptorPool(*device, &pool_info, allocator, descriptor_pool);
check_vk_result(err);
}
}
void SetupVulkanWindow(ImGui_ImplVulkanH_Window* wd,
const VkAllocationCallbacks* allocator,
VkInstance instance, uint32_t queue_family_index,
VkPhysicalDevice physical_device, VkDevice device,
VkSurfaceKHR surface, int width, int height,
uint32_t min_image_count) {
wd->Surface = surface;
// Check for WSI support
VkBool32 res;
vkGetPhysicalDeviceSurfaceSupportKHR(physical_device, queue_family_index,
wd->Surface, &res);
if (res != VK_TRUE) {
fprintf(stderr, "Error no WSI support on physical device 0\n");
exit(-1);
}
// Select Surface Format
const VkFormat requestSurfaceImageFormat[] = {
VK_FORMAT_B8G8R8A8_UNORM, VK_FORMAT_R8G8B8A8_UNORM,
VK_FORMAT_B8G8R8_UNORM, VK_FORMAT_R8G8B8_UNORM};
const VkColorSpaceKHR requestSurfaceColorSpace =
VK_COLORSPACE_SRGB_NONLINEAR_KHR;
wd->SurfaceFormat = ImGui_ImplVulkanH_SelectSurfaceFormat(
physical_device, wd->Surface, requestSurfaceImageFormat,
(size_t)IREE_ARRAYSIZE(requestSurfaceImageFormat),
requestSurfaceColorSpace);
// Select Present Mode
#ifdef IMGUI_UNLIMITED_FRAME_RATE
VkPresentModeKHR present_modes[] = {VK_PRESENT_MODE_MAILBOX_KHR,
VK_PRESENT_MODE_IMMEDIATE_KHR,
VK_PRESENT_MODE_FIFO_KHR};
#else
VkPresentModeKHR present_modes[] = {VK_PRESENT_MODE_FIFO_KHR};
#endif
wd->PresentMode = ImGui_ImplVulkanH_SelectPresentMode(
physical_device, wd->Surface, &present_modes[0],
IREE_ARRAYSIZE(present_modes));
// Create SwapChain, RenderPass, Framebuffer, etc.
IM_ASSERT(min_image_count >= 2);
ImGui_ImplVulkanH_CreateOrResizeWindow(instance, physical_device, device, wd,
queue_family_index, allocator, width,
height, min_image_count);
// Set clear color.
ImVec4 clear_color = ImVec4(0.45f, 0.55f, 0.60f, 1.00f);
memcpy(&wd->ClearValue.color.float32[0], &clear_color, 4 * sizeof(float));
}
void RenderFrame(ImGui_ImplVulkanH_Window* wd, VkDevice device, VkQueue queue) {
VkResult err;
VkSemaphore image_acquired_semaphore =
wd->FrameSemaphores[wd->SemaphoreIndex].ImageAcquiredSemaphore;
VkSemaphore render_complete_semaphore =
wd->FrameSemaphores[wd->SemaphoreIndex].RenderCompleteSemaphore;
err = vkAcquireNextImageKHR(device, wd->Swapchain, UINT64_MAX,
image_acquired_semaphore, VK_NULL_HANDLE,
&wd->FrameIndex);
check_vk_result(err);
ImGui_ImplVulkanH_Frame* fd = &wd->Frames[wd->FrameIndex];
{
err = vkWaitForFences(
device, 1, &fd->Fence, VK_TRUE,
UINT64_MAX); // wait indefinitely instead of periodically checking
check_vk_result(err);
err = vkResetFences(device, 1, &fd->Fence);
check_vk_result(err);
}
{
err = vkResetCommandPool(device, fd->CommandPool, 0);
check_vk_result(err);
VkCommandBufferBeginInfo info = {};
info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
err = vkBeginCommandBuffer(fd->CommandBuffer, &info);
check_vk_result(err);
}
{
VkRenderPassBeginInfo info = {};
info.sType = VK_STRUCTURE_TYPE_RENDER_PASS_BEGIN_INFO;
info.renderPass = wd->RenderPass;
info.framebuffer = fd->Framebuffer;
info.renderArea.extent.width = wd->Width;
info.renderArea.extent.height = wd->Height;
info.clearValueCount = 1;
info.pClearValues = &wd->ClearValue;
vkCmdBeginRenderPass(fd->CommandBuffer, &info, VK_SUBPASS_CONTENTS_INLINE);
}
// Record Imgui Draw Data and draw funcs into command buffer
ImGui_ImplVulkan_RenderDrawData(ImGui::GetDrawData(), fd->CommandBuffer);
// Submit command buffer
vkCmdEndRenderPass(fd->CommandBuffer);
{
VkPipelineStageFlags wait_stage =
VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT;
VkSubmitInfo info = {};
info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
info.waitSemaphoreCount = 1;
info.pWaitSemaphores = &image_acquired_semaphore;
info.pWaitDstStageMask = &wait_stage;
info.commandBufferCount = 1;
info.pCommandBuffers = &fd->CommandBuffer;
info.signalSemaphoreCount = 1;
info.pSignalSemaphores = &render_complete_semaphore;
err = vkEndCommandBuffer(fd->CommandBuffer);
check_vk_result(err);
err = vkQueueSubmit(queue, 1, &info, fd->Fence);
check_vk_result(err);
}
}
void PresentFrame(ImGui_ImplVulkanH_Window* wd, VkQueue queue) {
VkSemaphore render_complete_semaphore =
wd->FrameSemaphores[wd->SemaphoreIndex].RenderCompleteSemaphore;
VkPresentInfoKHR info = {};
info.sType = VK_STRUCTURE_TYPE_PRESENT_INFO_KHR;
info.waitSemaphoreCount = 1;
info.pWaitSemaphores = &render_complete_semaphore;
info.swapchainCount = 1;
info.pSwapchains = &wd->Swapchain;
info.pImageIndices = &wd->FrameIndex;
VkResult err = vkQueuePresentKHR(queue, &info);
check_vk_result(err);
wd->SemaphoreIndex =
(wd->SemaphoreIndex + 1) %
wd->ImageCount; // Now we can use the next set of semaphores
}
static void CleanupVulkan() {
vkDestroyDescriptorPool(g_Device, g_DescriptorPool, g_Allocator);
vkDestroyDevice(g_Device, g_Allocator);
vkDestroyInstance(g_Instance, g_Allocator);
}
static void CleanupVulkanWindow() {
ImGui_ImplVulkanH_DestroyWindow(g_Instance, g_Device, &g_MainWindowData,
g_Allocator);
}
namespace iree {
extern "C" int iree_main(int argc, char** argv) {
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
if (argc > 1) {
// Avoid iree-run-module spinning endlessly on stdin if the user uses single
// dashes for flags.
printf(
"[ERROR] unexpected positional argument (expected none)."
" Did you use pass a flag with a single dash ('-')?"
" Use '--' instead.\n");
return 1;
}
// --------------------------------------------------------------------------
// Create a window.
if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
fprintf(stderr, "Failed to initialize SDL\n");
abort();
return 1;
}
// Setup window
// clang-format off
SDL_WindowFlags window_flags = (SDL_WindowFlags)(
SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
// clang-format on
SDL_Window* window = SDL_CreateWindow(
"IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED,
SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
if (window == nullptr)
{
const char* sdl_err = SDL_GetError();
fprintf(stderr, "Error, SDL_CreateWindow returned: %s\n", sdl_err);
abort();
return 1;
}
// Setup Vulkan
iree_hal_vulkan_features_t iree_vulkan_features =
static_cast<iree_hal_vulkan_features_t>(
IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS |
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features);
std::vector<const char*> extensions =
GetInstanceExtensions(window, iree_vulkan_features);
SetupVulkan(iree_vulkan_features, layers.data(),
static_cast<uint32_t>(layers.size()), extensions.data(),
static_cast<uint32_t>(extensions.size()), g_Allocator,
&g_Instance, &g_QueueFamily, &g_PhysicalDevice, &g_Queue,
&g_Device, &g_DescriptorPool);
// Create Window Surface
VkSurfaceKHR surface;
VkResult err;
if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) {
fprintf(stderr, "Failed to create Vulkan surface.\n");
abort();
return 1;
}
// Create Framebuffers
int w, h;
SDL_GetWindowSize(window, &w, &h);
ImGui_ImplVulkanH_Window* wd = &g_MainWindowData;
SetupVulkanWindow(wd, g_Allocator, g_Instance, g_QueueFamily,
g_PhysicalDevice, g_Device, surface, w, h, g_MinImageCount);
// Setup Dear ImGui context
IMGUI_CHECKVERSION();
ImGui::CreateContext();
ImGuiIO& io = ImGui::GetIO();
(void)io;
ImGui::StyleColorsDark();
// Setup Platform/Renderer bindings
ImGui_ImplSDL2_InitForVulkan(window);
ImGui_ImplVulkan_InitInfo init_info = {};
init_info.Instance = g_Instance;
init_info.PhysicalDevice = g_PhysicalDevice;
init_info.Device = g_Device;
init_info.QueueFamily = g_QueueFamily;
init_info.Queue = g_Queue;
init_info.PipelineCache = g_PipelineCache;
init_info.DescriptorPool = g_DescriptorPool;
init_info.Allocator = g_Allocator;
init_info.MinImageCount = g_MinImageCount;
init_info.ImageCount = wd->ImageCount;
init_info.CheckVkResultFn = check_vk_result;
ImGui_ImplVulkan_Init(&init_info, wd->RenderPass);
// Upload Fonts
{
// Use any command queue
VkCommandPool command_pool = wd->Frames[wd->FrameIndex].CommandPool;
VkCommandBuffer command_buffer = wd->Frames[wd->FrameIndex].CommandBuffer;
err = vkResetCommandPool(g_Device, command_pool, 0);
check_vk_result(err);
VkCommandBufferBeginInfo begin_info = {};
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
err = vkBeginCommandBuffer(command_buffer, &begin_info);
check_vk_result(err);
ImGui_ImplVulkan_CreateFontsTexture(command_buffer);
VkSubmitInfo end_info = {};
end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
end_info.commandBufferCount = 1;
end_info.pCommandBuffers = &command_buffer;
err = vkEndCommandBuffer(command_buffer);
check_vk_result(err);
err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
check_vk_result(err);
err = vkDeviceWaitIdle(g_Device);
check_vk_result(err);
ImGui_ImplVulkan_DestroyFontUploadObjects();
}
// Demo state.
bool show_iree_window = true;
// --------------------------------------------------------------------------
// Setup IREE.
// Check API version.
iree_api_version_t actual_version;
iree_status_t status =
iree_api_version_check(IREE_API_VERSION_LATEST, &actual_version);
if (iree_status_is_ok(status)) {
fprintf(stdout, "IREE runtime API version: %d\n", actual_version);
} else {
fprintf(stderr, "Unsupported runtime API version: %d\n", actual_version);
abort();
}
// Create a runtime Instance.
iree_vm_instance_t* iree_instance = nullptr;
IREE_CHECK_OK(
iree_vm_instance_create(iree_allocator_system(), &iree_instance));
// Register HAL drivers and VM module types.
IREE_CHECK_OK(iree_hal_vulkan_driver_module_register(
iree_hal_driver_registry_default()));
IREE_CHECK_OK(iree_hal_module_register_all_types(iree_instance));
// Create IREE Vulkan Driver and Device, sharing our VkInstance/VkDevice.
fprintf(stdout, "Creating Vulkan driver/device\n");
// Load symbols from our static `vkGetInstanceProcAddr` for IREE to use.
iree_hal_vulkan_syms_t* iree_vk_syms = nullptr;
IREE_CHECK_OK(iree_hal_vulkan_syms_create(
reinterpret_cast<void*>(&vkGetInstanceProcAddr), iree_allocator_system(),
&iree_vk_syms));
// Create the driver sharing our VkInstance.
iree_hal_driver_t* iree_vk_driver = nullptr;
iree_string_view_t driver_identifier = iree_make_cstring_view("vulkan");
iree_hal_vulkan_driver_options_t driver_options;
driver_options.api_version = VK_API_VERSION_1_0;
driver_options.requested_features = static_cast<iree_hal_vulkan_features_t>(
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance(
driver_identifier, &driver_options, iree_vk_syms, g_Instance,
iree_allocator_system(), &iree_vk_driver));
// Create a device sharing our VkDevice and queue.
// We could also create a separate (possibly low priority) compute queue for
// IREE, and/or provide a dedicated transfer queue.
iree_string_view_t device_identifier = iree_make_cstring_view("vulkan");
iree_hal_vulkan_queue_set_t compute_queue_set;
compute_queue_set.queue_family_index = g_QueueFamily;
compute_queue_set.queue_indices = 1 << 0;
iree_hal_vulkan_queue_set_t transfer_queue_set;
transfer_queue_set.queue_indices = 0;
iree_hal_device_t* iree_vk_device = nullptr;
IREE_CHECK_OK(iree_hal_vulkan_wrap_device(
device_identifier, &driver_options.device_options, iree_vk_syms,
g_Instance, g_PhysicalDevice, g_Device, &compute_queue_set,
&transfer_queue_set, iree_allocator_system(), &iree_vk_device));
// Create a HAL module using the HAL device.
iree_vm_module_t* hal_module = nullptr;
IREE_CHECK_OK(iree_hal_module_create(iree_instance, iree_vk_device,
IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &hal_module));
// Load bytecode module
//iree_file_toc_t module_file_toc;
//const char network_model[] = "resnet50_tf.vmfb";
//fprintf(stdout, "Loading: %s\n", network_model);
//if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
//{
// abort();
// return 1;
//}
//fprintf(stdout, "module size: %zu\n", module_file_toc.size);
iree_vm_module_t* bytecode_module = nullptr;
iree_status_t module_status = iree_tooling_load_module_from_flags(
iree_instance, iree_allocator_system(), &bytecode_module);
if (!iree_status_is_ok(module_status))
return -1;
//IREE_CHECK_OK(iree_vm_bytecode_module_create(
// iree_instance,
// iree_const_byte_span_t{
// reinterpret_cast<const uint8_t*>(module_file_toc.data),
// module_file_toc.size},
// iree_allocator_null(), iree_allocator_system(), &bytecode_module));
//// Query for details about what is in the loaded module.
//iree_vm_module_signature_t bytecode_module_signature =
// iree_vm_module_signature(bytecode_module);
//fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
// bytecode_module_signature.export_function_count);
//for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
// iree_vm_function_t function;
// IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
// bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
// auto function_name = iree_vm_function_name(&function);
// auto function_signature = iree_vm_function_signature(&function);
// fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
// (int)function_name.size, function_name.data,
// (int)function_signature.calling_convention.size,
// function_signature.calling_convention.data);
//}
// Allocate a context that will hold the module state across invocations.
iree_vm_context_t* iree_context = nullptr;
std::vector<iree_vm_module_t*> modules = {hal_module, bytecode_module};
IREE_CHECK_OK(iree_vm_context_create_with_modules(
iree_instance, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(),
iree_allocator_system(), &iree_context));
fprintf(stdout, "Context with modules is ready for use\n");
// Lookup the entry point function.
iree_vm_function_t main_function;
const char kMainFunctionName[] = "module.forward";
IREE_CHECK_OK(iree_vm_context_resolve_function(
iree_context,
iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1},
&main_function));
iree_string_view_t main_function_name = iree_vm_function_name(&main_function);
fprintf(stdout, "Resolved main function named '%.*s'\n",
(int)main_function_name.size, main_function_name.data);
// --------------------------------------------------------------------------
// Write inputs into mappable buffers.
iree_hal_allocator_t* allocator =
iree_hal_device_allocator(iree_vk_device);
//iree_hal_memory_type_t input_memory_type =
// static_cast<iree_hal_memory_type_t>(
// IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
// IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
//iree_hal_buffer_usage_t input_buffer_usage =
// static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
//iree_hal_buffer_params_t buffer_params;
//buffer_params.type = input_memory_type;
//buffer_params.usage = input_buffer_usage;
//buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
// Wrap input buffers in buffer views.
vm::ref<iree_vm_list_t> inputs;
iree_status_t input_status = ParseToVariantList(
allocator,
iree::span<const std::string>{FLAG_function_inputs.data(),
FLAG_function_inputs.size()},
iree_allocator_system(), &inputs);
if (!iree_status_is_ok(input_status))
return -1;
//vm::ref<iree_vm_list_t> inputs;
//IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
//iree_hal_buffer_view_t* input0_buffer_view = nullptr;
//constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
//IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
// allocator,
// /*shape_rank=*/4, /*shape=*/input_buffer_shape,
// IREE_HAL_ELEMENT_TYPE_FLOAT_32,
// IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
// iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
// &input0_buffer_view));
//auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
//IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
// Prepare outputs list to accept results from the invocation.
vm::ref<iree_vm_list_t> outputs;
constexpr iree_hal_dim_t kOutputCount = 1000;
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, kOutputCount * sizeof(float), iree_allocator_system(), &outputs));
// --------------------------------------------------------------------------
// Main loop.
bool done = false;
while (!done) {
SDL_Event event;
while (SDL_PollEvent(&event)) {
if (event.type == SDL_QUIT) {
done = true;
}
ImGui_ImplSDL2_ProcessEvent(&event);
if (event.type == SDL_QUIT) done = true;
if (event.type == SDL_WINDOWEVENT &&
event.window.event == SDL_WINDOWEVENT_RESIZED &&
event.window.windowID == SDL_GetWindowID(window)) {
g_SwapChainResizeWidth = (int)event.window.data1;
g_SwapChainResizeHeight = (int)event.window.data2;
g_SwapChainRebuild = true;
}
}
if (g_SwapChainRebuild) {
g_SwapChainRebuild = false;
ImGui_ImplVulkan_SetMinImageCount(g_MinImageCount);
ImGui_ImplVulkanH_CreateOrResizeWindow(
g_Instance, g_PhysicalDevice, g_Device, &g_MainWindowData,
g_QueueFamily, g_Allocator, g_SwapChainResizeWidth,
g_SwapChainResizeHeight, g_MinImageCount);
g_MainWindowData.FrameIndex = 0;
}
// Start the Dear ImGui frame
ImGui_ImplVulkan_NewFrame();
ImGui_ImplSDL2_NewFrame(window);
ImGui::NewFrame();
// Custom window.
{
ImGui::Begin("IREE Vulkan Integration Demo", &show_iree_window);
ImGui::Separator();
// ImGui Inputs for two input tensors.
// Run computation whenever any of the values changes.
static bool dirty = true;
if (dirty) {
// Synchronously invoke the function.
IREE_CHECK_OK(iree_vm_invoke(iree_context, main_function,
IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, inputs.get(),
outputs.get(), iree_allocator_system()));
// we want to run continuously so we can use tools like RenderDoc, RGP, etc...
dirty = true;
}
// Framerate counter.
ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
ImGui::End();
}
// Rendering
ImGui::Render();
RenderFrame(wd, g_Device, g_Queue);
PresentFrame(wd, g_Queue);
}
// --------------------------------------------------------------------------
// --------------------------------------------------------------------------
// Cleanup
iree_vm_module_release(hal_module);
iree_vm_module_release(bytecode_module);
iree_vm_context_release(iree_context);
iree_hal_device_release(iree_vk_device);
iree_hal_allocator_release(allocator);
iree_hal_driver_release(iree_vk_driver);
iree_hal_vulkan_syms_release(iree_vk_syms);
iree_vm_instance_release(iree_instance);
err = vkDeviceWaitIdle(g_Device);
check_vk_result(err);
ImGui_ImplVulkan_Shutdown();
ImGui_ImplSDL2_Shutdown();
ImGui::DestroyContext();
CleanupVulkanWindow();
CleanupVulkan();
SDL_DestroyWindow(window);
SDL_Quit();
// --------------------------------------------------------------------------
return 0;
}
} // namespace iree

File diff suppressed because it is too large Load Diff

View File

@@ -1,27 +0,0 @@
# Dataset annotation tool
SHARK annotator for adding or modifying prompts of dataset images
## Set up
Activate SHARK Python virtual environment and install additional packages
```shell
source ../shark.venv/bin/activate
pip install -r requirements.txt
```
## Run annotator
```shell
python annotation_tool.py
```
<img width="1280" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214521137-7ef6ae10-7cd8-46e6-b270-b6c0445157f1.png">
* Select a dataset from `Dataset` dropdown list
* Select an image from `Image` dropdown list
* Image and the existing prompt will be loaded
* Select a prompt from `Prompt` dropdown list to modify or "Add new" to add a prompt
* Click `Save` to save changes, click `Delete` to delete prompt
* Click `Back` or `Next` to switch image, you could also select other images from `Image`
* Click `Finish` when finishing annotation or before switching dataset

View File

@@ -1,247 +0,0 @@
import gradio as gr
import json
import jsonlines
import os
from args import args
from pathlib import Path
from PIL import Image
from utils import get_datasets
shark_root = Path(__file__).parent.parent
demo_css = shark_root.joinpath("web/demo.css").resolve()
nodlogo_loc = shark_root.joinpath(
"web/models/stable_diffusion/logos/nod-logo.png"
)
with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
prompt_data = dict()
with gr.Row(elem_id="ui_body"):
# TODO: add multiselect dataset, there is a gradio version conflict
dataset = gr.Dropdown(label="Dataset", choices=datasets)
image_name = gr.Dropdown(label="Image", choices=[])
with gr.Row(elem_id="ui_body"):
# TODO: add ability to search image by typing
with gr.Column(scale=1, min_width=600):
image = gr.Image(type="filepath").style(height=512)
with gr.Column(scale=1, min_width=600):
prompts = gr.Dropdown(
label="Prompts",
choices=[],
)
prompt = gr.Textbox(
label="Editor",
lines=3,
)
with gr.Row():
save = gr.Button("Save")
delete = gr.Button("Delete")
with gr.Row():
back_image = gr.Button("Back")
next_image = gr.Button("Next")
finish = gr.Button("Finish")
def filter_datasets(dataset):
if dataset is None:
return gr.Dropdown.update(value=None, choices=[])
# create the dataset dir if doesn't exist and download prompt file
dataset_path = str(shark_root) + "/dataset/" + dataset
if not os.path.exists(dataset_path):
os.mkdir(dataset_path)
# read prompt jsonlines file
prompt_data.clear()
if dataset in ds_w_prompts:
prompt_gs_path = args.gs_url + "/" + dataset + "/metadata.jsonl"
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
for line in reader.iter(type=dict, skip_invalid=True):
prompt_data[line["file_name"]] = (
[line["text"]]
if type(line["text"]) is str
else line["text"]
)
return gr.Dropdown.update(choices=images[dataset])
dataset.change(fn=filter_datasets, inputs=dataset, outputs=image_name)
def display_image(dataset, image_name):
if dataset is None or image_name is None:
return gr.Image.update(value=None), gr.Dropdown.update(value=None)
# download and load the image
img_gs_path = args.gs_url + "/" + dataset + "/" + image_name
img_sub_path = "/".join(image_name.split("/")[:-1])
img_dst_path = (
str(shark_root) + "/dataset/" + dataset + "/" + img_sub_path + "/"
)
if not os.path.exists(img_dst_path):
os.mkdir(img_dst_path)
os.system(f'gsutil cp "{img_gs_path}" "{img_dst_path}"')
img = Image.open(img_dst_path + image_name.split("/")[-1])
if image_name not in prompt_data.keys():
prompt_data[image_name] = []
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Image.update(value=img), gr.Dropdown.update(
choices=prompt_choices
)
image_name.change(
fn=display_image,
inputs=[dataset, image_name],
outputs=[image, prompts],
)
def edit_prompt(prompts):
if prompts == "Add new":
return gr.Textbox.update(value=None)
return gr.Textbox.update(value=prompts)
prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt)
def save_prompt(dataset, image_name, prompts, prompt):
if (
dataset is None
or image_name is None
or prompts is None
or prompt is None
):
return
if prompts == "Add new":
prompt_data[image_name].append(prompt)
else:
idx = prompt_data[image_name].index(prompts)
prompt_data[image_name][idx] = prompt
prompt_path = (
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
)
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
if not value:
continue
v = value if len(value) > 1 else value[0]
f.write(json.dumps({"file_name": key, "text": v}))
f.write("\n")
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Dropdown.update(choices=prompt_choices, value=None)
save.click(
fn=save_prompt,
inputs=[dataset, image_name, prompts, prompt],
outputs=prompts,
)
def delete_prompt(dataset, image_name, prompts):
if dataset is None or image_name is None or prompts is None:
return
if prompts == "Add new":
return
prompt_data[image_name].remove(prompts)
prompt_path = (
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
)
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
if not value:
continue
v = value if len(value) > 1 else value[0]
f.write(json.dumps({"file_name": key, "text": v}))
f.write("\n")
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Dropdown.update(choices=prompt_choices, value=None)
delete.click(
fn=delete_prompt,
inputs=[dataset, image_name, prompts],
outputs=prompts,
)
def get_back_image(dataset, image_name):
if dataset is None or image_name is None:
return
# remove local image
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
os.system(f'rm "{img_path}"')
# get the index for the back image
idx = images[dataset].index(image_name)
if idx == 0:
return gr.Dropdown.update(value=None)
return gr.Dropdown.update(value=images[dataset][idx - 1])
back_image.click(
fn=get_back_image, inputs=[dataset, image_name], outputs=image_name
)
def get_next_image(dataset, image_name):
if dataset is None or image_name is None:
return
# remove local image
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
os.system(f'rm "{img_path}"')
# get the index for the next image
idx = images[dataset].index(image_name)
if idx == len(images[dataset]) - 1:
return gr.Dropdown.update(value=None)
return gr.Dropdown.update(value=images[dataset][idx + 1])
next_image.click(
fn=get_next_image, inputs=[dataset, image_name], outputs=image_name
)
def finish_annotation(dataset):
if dataset is None:
return
# upload prompt and remove local data
dataset_path = str(shark_root) + "/dataset/" + dataset
dataset_gs_path = args.gs_url + "/" + dataset + "/"
os.system(
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
)
os.system(f'rm -rf "{dataset_path}"')
return gr.Dropdown.update(value=None)
finish.click(fn=finish_annotation, inputs=dataset, outputs=dataset)
if __name__ == "__main__":
shark_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

View File

@@ -1,34 +0,0 @@
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Dataset Annotator flags
##############################################################################
p.add_argument(
"--gs_url",
type=str,
required=True,
help="URL to datasets in GS bucket",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
##############################################################################
args = p.parse_args()

View File

@@ -1,3 +0,0 @@
# SHARK Annotator
gradio==3.15.0
jsonlines

View File

@@ -1,29 +0,0 @@
from google.cloud import storage
def get_datasets(gs_url):
datasets = set()
images = dict()
ds_w_prompts = []
storage_client = storage.Client()
bucket_name = gs_url.split("/")[2]
source_blob_name = "/".join(gs_url.split("/")[3:])
blobs = storage_client.list_blobs(bucket_name, prefix=source_blob_name)
for blob in blobs:
dataset_name = blob.name.split("/")[1]
if dataset_name == "":
continue
datasets.add(dataset_name)
if dataset_name not in images.keys():
images[dataset_name] = []
# check if image or jsonl
file_sub_path = "/".join(blob.name.split("/")[2:])
if "/" in file_sub_path:
images[dataset_name] += [file_sub_path]
elif "metadata.jsonl" in file_sub_path:
ds_w_prompts.append(dataset_name)
return list(datasets), images, ds_w_prompts

View File

@@ -2,28 +2,32 @@
"""SHARK Tank"""
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
# will generate local shark tank folder like this:
# HOME
# /.local
# /shark_tank
# /albert_lite_base
# /...model_name...
# /SHARK
# /gen_shark_tank
# /albert_lite_base
# /...model_name...
#
import os
import csv
import argparse
from shark.shark_importer import SharkImporter
from shark.parser import shark_args
import subprocess as sp
import tensorflow as tf
import hashlib
import numpy as np
from pathlib import Path
from apps.stable_diffusion.src.models import (
model_wrappers as mw,
)
from apps.stable_diffusion.src.utils.stable_args import (
args,
)
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
# All generated models and metadata will be saved under this directory.
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
def create_hash(file_name):
@@ -36,12 +40,8 @@ def create_hash(file_name):
def save_torch_model(torch_model_list):
from tank.model_utils import (
get_hf_model,
get_vision_model,
get_hf_img_cls_model,
get_fp16_model,
)
from tank.model_utils import get_hf_model
from tank.model_utils import get_vision_model
with open(torch_model_list) as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")
@@ -50,46 +50,16 @@ def save_torch_model(torch_model_list):
torch_model_name = row[0]
tracing_required = row[1]
model_type = row[2]
is_dynamic = row[3]
tracing_required = False if tracing_required == "False" else True
is_dynamic = False if is_dynamic == "False" else True
model = None
input = None
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.use_tuned = False
args.local_tank_cache = WORKDIR
precision_values = ["fp16"]
seq_lengths = [64, 77]
for precision_value in precision_values:
args.precision = precision_value
for length in seq_lengths:
model = mw.SharkifyStableDiffusionModel(
model_id=torch_model_name,
custom_weights="",
precision=precision_value,
max_len=length,
width=512,
height=512,
use_base_vae=False,
debug=True,
sharktank_dir=WORKDIR,
generate_vmfb=False,
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
@@ -114,33 +84,17 @@ def save_torch_model(torch_model_list):
)
np.save(os.path.join(torch_model_dir, "hash"), np.array(mlir_hash))
# Generate torch dynamic models.
if is_dynamic:
mlir_importer.import_debug(
is_dynamic=True,
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name + "_dynamic",
)
mlir_importer.import_debug(
is_dynamic=True,
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name + "_dynamic",
)
def save_tf_model(tf_model_list):
from tank.model_utils_tf import (
get_causal_image_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
)
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
from tank.model_utils_tf import get_causal_lm_model
from tank.model_utils_tf import get_causal_image_model
with open(tf_model_list) as csvfile:
tf_reader = csv.reader(csvfile, delimiter=",")
@@ -151,15 +105,11 @@ def save_tf_model(tf_model_list):
model = None
input = None
print(f"Generating artifacts for model {tf_model_name}")
print(model_type)
if model_type == "hf":
model, input, _ = get_causal_lm_model(tf_model_name)
if model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
if model_type == "keras":
model, input, _ = get_keras_model(tf_model_name)
if model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name)
tf_model_name = tf_model_name.replace("/", "_")
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
@@ -236,42 +186,29 @@ def is_valid_file(arg):
if __name__ == "__main__":
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--torch_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/torch_model_list.csv",
# help="""Contains the file with torch_model name and args.
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
# )
# parser.add_argument(
# "--tf_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tf_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--tflite_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tflite/tflite_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--ci_tank_dir",
# type=bool,
# default=False,
# )
# parser.add_argument("--upload", type=bool, default=False)
# old_args = parser.parse_args()
home = str(Path.home())
if args.ci_tank_dir == True:
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
else:
WORKDIR = os.path.join(home, ".local/shark_tank/")
parser = argparse.ArgumentParser()
parser.add_argument(
"--torch_model_csv",
type=lambda x: is_valid_file(x),
default="./tank/pytorch/torch_model_list.csv",
help="""Contains the file with torch_model name and args.
Please see: https://github.com/nod-ai/SHARK/blob/main/tank/pytorch/torch_model_list.csv""",
)
parser.add_argument(
"--tf_model_csv",
type=lambda x: is_valid_file(x),
default="./tank/tf/tf_model_list.csv",
help="Contains the file with tf model name and args.",
)
parser.add_argument(
"--tflite_model_csv",
type=lambda x: is_valid_file(x),
default="./tank/tflite/tflite_model_list.csv",
help="Contains the file with tf model name and args.",
)
parser.add_argument("--upload", type=bool, default=False)
args = parser.parse_args()
if args.torch_model_csv:
save_torch_model(args.torch_model_csv)
@@ -280,3 +217,7 @@ if __name__ == "__main__":
if args.tflite_model_csv:
save_tflite_model(args.tflite_model_csv)
if args.upload:
print("uploading files to gs://shark_tank/")
os.system("gsutil cp -r ./gen_shark_tank/* gs://shark_tank/")

View File

@@ -4,9 +4,9 @@ requires = [
"wheel",
"packaging",
"numpy>=1.22.4",
"torch-mlir>=20221021.633",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
"numpy==1.22.4",
"torch-mlir>=20220428.420",
"iree-compiler>=20220427.13",
"iree-runtime>=20220427.13",
]
build-backend = "setuptools.build_meta"

View File

@@ -1,3 +1,3 @@
[pytest]
addopts = --verbose -p no:warnings
norecursedirs = inference tank/tflite examples benchmarks shark
norecursedirs = inference tank/tflite

View File

@@ -1,4 +1,4 @@
-f https://download.pytorch.org/whl/nightly/cpu/
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
numpy
@@ -19,17 +19,13 @@ tensorflow-macos
tensorflow-metal
#tf-models-nightly
#tensorflow-text-nightly
transformers
transformers==4.18.0
tensorflow-probability
#jax[cpu]
# tflitehub dependencies.
Pillow
# web dependecies.
gradio
altair
# Testing and support.
#lit
#pyyaml

View File

@@ -2,9 +2,8 @@
--pre
numpy==1.22.4
torch
torchvision
pytorch-triton
tabulate
tqdm
@@ -15,12 +14,10 @@ iree-tools-tf
# TensorFlow and JAX.
gin-config
tensorflow==2.10.1
keras==2.10
tensorflow
#tf-models-nightly
#tensorflow-text-nightly
transformers
diffusers
transformers==4.18.0
#tensorflow-probability
#jax[cpu]
@@ -31,13 +28,6 @@ Pillow
# Testing and support.
lit
pyyaml
python-dateutil
sacremoses
# web dependecies.
gradio
altair
scipy
#ONNX and ORT for benchmarking
#--extra-index-url https://test.pypi.org/simple/

View File

@@ -5,25 +5,9 @@ wheel
tqdm
# SHARK Downloader
google-cloud-storage
gsutil
# Testing
pytest
pytest-xdist
pytest-forked
Pillow
parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers
scipy
ftfy
gradio
altair
omegaconf
safetensors
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller

View File

@@ -2,18 +2,11 @@ from setuptools import find_packages
from setuptools import setup
import os
import glob
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
backend_deps = []
if "NO_BACKEND" in os.environ.keys():
backend_deps = [
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
]
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.4"
setup(
name="nodai-SHARK",
@@ -34,12 +27,12 @@ setup(
"Operating System :: OS Independent",
],
packages=find_packages(exclude=("examples")),
python_requires=">=3.9",
data_files=glob.glob("apps/stable_diffusion/resources/**"),
python_requires=">=3.7",
install_requires=[
"numpy",
"PyYAML",
"torch-mlir>=20221021.633",
]
+ backend_deps,
"torch-mlir>=20220428.420",
"iree-compiler>=20220427.13",
"iree-runtime>=20220427.13",
],
)

View File

@@ -1,45 +0,0 @@
param([string]$arguments)
if ($arguments -eq "--update-src"){
git pull
}
#Write-Host "Installing python"
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow
#Write-Host "python installation completed successfully"
#Write-Host "Reload environment variables"
#$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
#Write-Host "Reloaded environment variables"
# redirect stderr into stdout
$p = &{python -V} 2>&1
# check if an ErrorRecord was returned
$version = if($p -is [System.Management.Automation.ErrorRecord])
{
# grab the version string from the error message
$p.Exception.Message
}
else
{
# otherwise return as is
$p
}
Write-Host "Python version found is"
Write-Host $p
Write-Host "Installing Build Dependencies"
python -m venv .\shark.venv\
.\shark.venv\Scripts\activate
pip install -r requirements.txt
pip install --pre torch-mlir torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
Write-Host "Build and installation completed successfully"
Write-Host "Source your venv with ./shark.venv/Scripts/activate"

View File

@@ -7,8 +7,6 @@
# VENV_DIR=myshark.venv #create a venv called myshark.venv
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
# IMPORTER=1 #Install importer deps
# BENCHMARK=1 #Install benchmark deps
# NO_BACKEND=1 #Don't install iree or shark backend
# if you run the script from a conda env it will install in your conda env
TD="$(cd $(dirname $0) && pwd)"
@@ -76,16 +74,11 @@ fi
$PYTHON -m pip install --upgrade pip || die "Could not upgrade pip"
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
if [ "$torch_mlir_bin" = true ]; then
if [[ $(uname -s) = 'Darwin' ]]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
$PYTHON -m pip install --find-links https://github.com/llvm/torch-mlir/releases torch-mlir --extra-index-url https://download.pytorch.org/whl/nightly/cpu
if [ $? -eq 0 ];then
echo "Successfully Installed torch-mlir"
else
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
if [ $? -eq 0 ];then
echo "Successfully Installed torch-mlir"
else
echo "Could not install torch-mlir" >&2
fi
echo "Could not install torch-mlir" >&2
fi
else
echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
@@ -94,56 +87,34 @@ else
exit 1
fi
if [[ -z "${USE_IREE}" ]]; then
rm .use-iree
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
RUNTIME="nod-ai/SHARK-Runtime"
else
touch ./.use-iree
RUNTIME="https://iree-org.github.io/iree/pip-release-links.html"
fi
if [[ -z "${NO_BACKEND}" ]]; then
echo "Installing ${RUNTIME}..."
$PYTHON -m pip install --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
else
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
RUNTIME="google/iree"
fi
echo "Installing ${RUNTIME}..."
$PYTHON -m pip install --find-links https://github.com/${RUNTIME}/releases iree-compiler iree-runtime
if [[ ! -z "${IMPORTER}" ]]; then
echo "${Yellow}Installing importer tools.."
if [[ $(uname -s) = 'Linux' ]]; then
echo "${Yellow}Linux detected.. installing Linux importer tools"
#Always get the importer tools from upstream IREE
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://iree-org.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer.txt" -f https://github.com/${RUNTIME}/releases --extra-index-url https://test.pypi.org/simple/ --extra-index-url https://download.pytorch.org/whl/nightly/cu116
elif [[ $(uname -s) = 'Darwin' ]]; then
echo "${Yellow}macOS detected.. installing macOS importer tools"
#Conda seems to have some problems installing these packages and hope they get resolved upstream.
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer-macos.txt" -f ${RUNTIME} --extra-index-url https://download.pytorch.org/whl/nightly/cpu
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer-macos.txt" -f https://github.com/${RUNTIME}/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
fi
fi
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
$PYTHON -m pip install -e . --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases -f https://github.com/${RUNTIME}/releases
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)
TORCH_VERSION=${T_VER:9:17}
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
TV_VERSION=${TV_VER:9:18}
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl
$PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu116
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu117."
echo "Successfully Installed torch + cu116."
else
echo "Could not install torch + cu117." >&2
fi
fi
if [[ ! -z "${ONNX}" ]]; then
echo "${Yellow}Installing ONNX and onnxruntime for benchmarks..."
$PYTHON -m pip install onnx onnxruntime psutil
if [ $? -eq 0 ];then
echo "Successfully installed ONNX and ONNX runtime."
else
echo "Could not install ONNX." >&2
echo "Could not install torch + cu116." >&2
fi
fi

View File

@@ -1,70 +0,0 @@
import torchdynamo
import torch
import torch_mlir
from shark.sharkdynamo.utils import make_shark_compiler
import warnings, logging
warnings.simplefilter("ignore")
torchdynamo.config.log_level = logging.ERROR
torchdynamo.reset()
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
)
def foo(t):
return 2 * t
example_input = torch.rand((2, 3))
x = foo(example_input)
print(x)
torchdynamo.reset()
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
)
def foo(a, b):
x = a / (a + 1)
if b.sum() < 0:
b = b * -1
return x * b
print(foo(torch.rand((2, 3)), -torch.rand((2, 3))))
torchdynamo.reset()
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
)
def foo(a):
for i in range(10):
a += 1.0
return a
print(foo(torch.rand((1, 2))))
torchdynamo.reset()
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
)
def test_unsupported_types(t, y):
return t, 2 * y
str_input = "hello"
tensor_input = torch.randn(2)
print(test_unsupported_types(str_input, tensor_input))

View File

@@ -36,9 +36,7 @@
" from torchdynamo.optimizations.backends import create_backend\n",
" from torchdynamo.optimizations.subgraph import SubGraph\n",
"except ModuleNotFoundError:\n",
" print(\n",
" \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
" )\n",
" print(\"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\")\n",
" exit()\n",
"\n",
"# torch-mlir imports for compiling\n",
@@ -99,9 +97,7 @@
"\n",
" for node in fx_g.graph.nodes:\n",
" if node.op == \"output\":\n",
" assert (\n",
" len(node.args) == 1\n",
" ), \"Output node must have a single argument\"\n",
" assert len(node.args) == 1, \"Output node must have a single argument\"\n",
" node_arg = node.args[0]\n",
" if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
" node.args = (node_arg[0],)\n",
@@ -120,12 +116,8 @@
" if len(args) == 1 and isinstance(args[0], list):\n",
" args = args[0]\n",
"\n",
" linalg_module = compile(\n",
" ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
" )\n",
" callable, _ = get_iree_compiled_module(\n",
" linalg_module, \"cuda\", func_name=\"forward\"\n",
" )\n",
" linalg_module = compile(ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS)\n",
" callable, _ = get_iree_compiled_module(linalg_module, \"cuda\", func_name=\"forward\")\n",
"\n",
" def forward(*inputs):\n",
" return callable(*inputs)\n",
@@ -220,7 +212,6 @@
" assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
" return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
"\n",
"\n",
"@torchdynamo.optimize(\"torch_mlir\")\n",
"def toy_example2(*args):\n",
" a, b = args\n",

View File

@@ -1,73 +0,0 @@
import torch
import numpy as np
model = torch.hub.load(
"pytorch/vision:v0.10.0", "squeezenet1_0", pretrained=True
)
model.eval()
# from PIL import Image
# from torchvision import transforms
# import urllib
#
# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
# try: urllib.URLopener().retrieve(url, filename)
# except: urllib.request.urlretrieve(url, filename)
#
#
# input_image = Image.open(filename)
# preprocess = transforms.Compose([
# transforms.Resize(256),
# transforms.CenterCrop(224),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])
# input_tensor = preprocess(input_image)
# input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
# print(input_batch.shape) # size = [1, 3, 224, 224]
# The above is code for generating sample inputs from an image. We can just use
# random values for accuracy testing though
input_batch = torch.randn(1, 3, 224, 224)
# Focus on CPU for now
if False and torch.cuda.is_available():
input_batch = input_batch.to("cuda")
model.to("cuda")
with torch.no_grad():
output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
golden_confidences = output[0]
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
golden_probabilities = torch.nn.functional.softmax(
golden_confidences, dim=0
).numpy()
golden_confidences = golden_confidences.numpy()
from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor
input_detached_clone = input_batch.clone()
eager_input_batch = TorchMLIRLockstepTensor(input_detached_clone)
print("getting torch-mlir result")
output = model(eager_input_batch)
static_output = output.elem
confidences = static_output[0]
probabilities = torch.nn.functional.softmax(
torch.from_numpy(confidences), dim=0
).numpy()
print("The obtained result via shark is: ", confidences)
print("The golden result is:", golden_confidences)
np.testing.assert_allclose(
golden_confidences, confidences, rtol=1e-02, atol=1e-03
)
np.testing.assert_allclose(
golden_probabilities, probabilities, rtol=1e-02, atol=1e-03
)

View File

@@ -22,7 +22,7 @@ class CLIPModule(tf.Module):
input_ids=x, attention_mask=y, pixel_values=z
)
@tf.function(input_signature=clip_vit_inputs, jit_compile=True)
@tf.function(input_signature=clip_vit_inputs)
def forward(self, input_ids, attention_mask, pixel_values):
return self.m.predict(
input_ids, attention_mask, pixel_values

View File

@@ -1,15 +0,0 @@
## Running ESRGAN
```
1. pip install numpy opencv-python
2. mkdir InputImages
(this is where all the input images will reside in)
3. mkdir OutputImages
(this is where the model will generate all the images)
4. mkdir models
(save the .pth checkpoint file here)
5. python esrgan.py
```
- Download [RRDB_ESRGAN_x4.pth](https://drive.google.com/drive/u/0/folders/17VYV_SoZZesU6mbxz2dMAIccSSlqLecY) and place it in the `models` directory as mentioned above in step 4.
- Credits : [ESRGAN](https://github.com/xinntao/ESRGAN)

View File

@@ -1,239 +0,0 @@
from ast import arg
import os.path as osp
import glob
import cv2
import numpy as np
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from shark.shark_inference import SharkInference
import torch_mlir
import tempfile
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
"""Residual in Residual Dense Block"""
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(
self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
)
fea = self.lrelu(
self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
)
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
############### Parsing args #####################
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
p.add_argument("--device", type=str, default="cpu", help="the device to use")
p.add_argument(
"--mlir_loc",
type=str,
default=None,
help="location of the model's mlir file",
)
args = p.parse_args()
###################################################
def inference(input_m):
return model(input_m)
def load_mlir(mlir_loc):
import os
if mlir_loc == None:
return None
print(f"Trying to load the model from {mlir_loc}.")
with open(os.path.join(mlir_loc)) as f:
mlir_module = f.read()
return mlir_module
def compile_through_fx(model, inputs, mlir_loc=None):
module = load_mlir(mlir_loc)
if module == None:
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(inputs)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
print("Torchscript graph generated successfully")
module = torch_mlir.compile(
ts_g,
inputs,
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
mlir_model = str(module)
func_name = "forward"
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
)
shark_module.compile()
return shark_module
model_path = "models/RRDB_ESRGAN_x4.pth" # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
# device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu
device = torch.device("cpu")
test_img_folder = "InputImages/*"
model = RRDBNet(3, 3, 64, 23, gc=32)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)
print("Model path {:s}. \nTesting...".format(model_path))
if __name__ == "__main__":
idx = 0
for path in glob.glob(test_img_folder):
idx += 1
base = osp.splitext(osp.basename(path))[0]
print(idx, base)
# read images
img = cv2.imread(path, cv2.IMREAD_COLOR)
img = img * 1.0 / 255
img = torch.from_numpy(
np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))
).float()
img_LR = img.unsqueeze(0)
img_LR = img_LR.to(device)
with torch.no_grad():
shark_module = compile_through_fx(inference, img_LR)
shark_output = shark_module.forward((img_LR,))
shark_output = torch.from_numpy(shark_output)
shark_output = (
shark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
)
esrgan_output = (
model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
)
# SHARK OUTPUT
shark_output = np.transpose(shark_output[[2, 1, 0], :, :], (1, 2, 0))
shark_output = (shark_output * 255.0).round()
cv2.imwrite(
"OutputImages/{:s}_rlt_shark_output.png".format(base), shark_output
)
print("Generated SHARK's output")
# ESRGAN OUTPUT
esrgan_output = np.transpose(esrgan_output[[2, 1, 0], :, :], (1, 2, 0))
esrgan_output = (esrgan_output * 255.0).round()
cv2.imwrite(
"OutputImages/{:s}_rlt_esrgan_output.png".format(base),
esrgan_output,
)
print("Generated ESRGAN's output")

View File

@@ -18,23 +18,14 @@ class AlbertModule(torch.nn.Module):
self.model.eval()
def forward(self, input_ids, attention_mask):
return self.model(
input_ids=input_ids, attention_mask=attention_mask
).logits
return self.model(input_ids=input_ids, attention_mask=attention_mask).logits
if __name__ == "__main__":
# Prepping Data
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
text = "This [MASK] is very tasty."
encoded_inputs = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
return_tensors="pt",
)
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
encoded_inputs = tokenizer(text, padding='max_length', truncation=True, max_length=MAX_SEQUENCE_LENGTH, return_tensors="pt")
inputs = (encoded_inputs["input_ids"],encoded_inputs["attention_mask"])
mlir_importer = SharkImporter(
AlbertModule(),
inputs,
@@ -43,46 +34,26 @@ if __name__ == "__main__":
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=False, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, mlir_dialect="linalg"
)
shark_module = SharkInference(minilm_mlir, func_name, mlir_dialect="linalg")
shark_module.compile()
token_logits = torch.tensor(shark_module.forward(inputs))
mask_id = torch.where(
encoded_inputs["input_ids"] == tokenizer.mask_token_id
)[1]
mask_id = torch.where(encoded_inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_id, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
for token in top_5_tokens:
print(
f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
)
print(f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'")
while True:
try:
new_text = input("Give me a sentence with [MASK] to fill: ")
encoded_inputs = tokenizer(
new_text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
return_tensors="pt",
)
inputs = (
encoded_inputs["input_ids"],
encoded_inputs["attention_mask"],
)
encoded_inputs = tokenizer(new_text, padding='max_length', truncation=True, max_length=MAX_SEQUENCE_LENGTH, return_tensors="pt")
inputs = (encoded_inputs["input_ids"],encoded_inputs["attention_mask"])
token_logits = torch.tensor(shark_module.forward(inputs))
mask_id = torch.where(
encoded_inputs["input_ids"] == tokenizer.mask_token_id
)[1]
mask_id = torch.where(encoded_inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_id, :]
top_5_tokens = (
torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
)
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
for token in top_5_tokens:
print(
f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
)
print(f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'")
except KeyboardInterrupt:
print("Exiting program.")
break

View File

@@ -18,17 +18,15 @@ BATCH_SIZE = 1
# Create a set of inputs
t5_inputs = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32)
]
class AlbertModule(tf.Module):
def __init__(self):
super(AlbertModule, self).__init__()
self.m = TFAutoModelForMaskedLM.from_pretrained("albert-base-v2")
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
self.m.predict = lambda x,y: self.m(input_ids=x, attention_mask=y)
@tf.function(input_signature=t5_inputs, jit_compile=True)
@tf.function(input_signature=t5_inputs)
def forward(self, input_ids, attention_mask):
return self.m.predict(input_ids, attention_mask)
@@ -38,14 +36,8 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
# text = "This is a great [MASK]."
text = "This [MASK] is very tasty."
encoded_inputs = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
return_tensors="tf",
)
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
encoded_inputs = tokenizer(text, padding='max_length', truncation=True, max_length=MAX_SEQUENCE_LENGTH, return_tensors="tf")
inputs = (encoded_inputs["input_ids"],encoded_inputs["attention_mask"])
mlir_importer = SharkImporter(
AlbertModule(),
inputs,
@@ -59,42 +51,22 @@ if __name__ == "__main__":
output_idx = 0
data_idx = 1
token_logits = shark_module.forward(inputs)[output_idx][data_idx]
mask_id = np.where(
tf.squeeze(encoded_inputs["input_ids"]) == tokenizer.mask_token_id
)
mask_id = np.where(tf.squeeze(encoded_inputs["input_ids"]) == tokenizer.mask_token_id)
mask_token_logits = token_logits[0, mask_id, :]
top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[0:5]
for token in top_5_tokens:
print(
f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
)
print(f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'")
while True:
try:
new_text = input("Give me a sentence with [MASK] to fill: ")
encoded_inputs = tokenizer(
new_text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
return_tensors="tf",
)
inputs = (
encoded_inputs["input_ids"],
encoded_inputs["attention_mask"],
)
encoded_inputs = tokenizer(new_text, padding='max_length', truncation=True, max_length=MAX_SEQUENCE_LENGTH, return_tensors="tf")
inputs = (encoded_inputs["input_ids"],encoded_inputs["attention_mask"])
token_logits = shark_module.forward(inputs)[output_idx][data_idx]
mask_id = np.where(
tf.squeeze(encoded_inputs["input_ids"])
== tokenizer.mask_token_id
)
mask_id = np.where(tf.squeeze(encoded_inputs["input_ids"]) == tokenizer.mask_token_id)
mask_token_logits = token_logits[0, mask_id, :]
top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[
0:5
]
top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[0:5]
for token in top_5_tokens:
print(
f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
)
print(f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'")
except KeyboardInterrupt:
print("Exiting program.")
sys.exit()

View File

@@ -0,0 +1,35 @@
from PIL import Image
import requests
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from iree.compiler import tf as tfc
from iree.compiler import compile_str
from iree import runtime as ireert
import os
import numpy as np
MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1
if __name__ == "__main__":
# Prepping Data
model = AutoModelForMaskedLM.from_pretrained("albert-base-v2")
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
text = "This [MASK] is very tasty."
inputs = tokenizer(text, padding='max_length', truncation=True, max_length=MAX_SEQUENCE_LENGTH, return_tensors="pt")
token_logits = model(**inputs).logits
print(token_logits)
# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# print(mask_token_logits)
# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
print(np.argsort(mask_token_logits.detach().numpy()))
# print(top_5_tokens)
for token in top_5_tokens:
print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")

View File

@@ -1,14 +0,0 @@
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_model
mlir_model, func_name, inputs, golden_out = download_model(
"bloom", frontend="torch"
)
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
)
shark_module.compile()
result = shark_module.forward(inputs)
print("The obtained result via shark is: ", result)
print("The golden result is:", golden_out)

View File

@@ -19,7 +19,7 @@ class GPT2Module(tf.Module):
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
@tf.function(input_signature=gpt2_inputs, jit_compile=True)
@tf.function(input_signature=gpt2_inputs)
def forward(self, input_ids, attention_mask):
return self.m.predict(input_ids, attention_mask)

View File

@@ -26,7 +26,7 @@ class BertModule(tf.Module):
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
@tf.function(input_signature=bert_input, jit_compile=True)
@tf.function(input_signature=bert_input)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)

View File

@@ -1,15 +1,14 @@
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_model
from shark.shark_downloader import download_torch_model
mlir_model, func_name, inputs, golden_out = download_model(
"microsoft/MiniLM-L12-H384-uncased",
frontend="torch",
mlir_model, func_name, inputs, golden_out = download_torch_model(
"microsoft/MiniLM-L12-H384-uncased"
)
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
mlir_model, func_name, mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)

View File

@@ -26,7 +26,7 @@ class BertModule(tf.Module):
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
@tf.function(input_signature=bert_input, jit_compile=True)
@tf.function(input_signature=bert_input)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)

View File

@@ -23,7 +23,7 @@ input = torch.randn(1, 3, 224, 224)
mlir_importer = SharkImporter(
ResnestModule(),
(input,),
(input),
frontend="torch",
)
@@ -33,7 +33,9 @@ mlir_importer = SharkImporter(
print(golden_out)
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
shark_module = SharkInference(
vision_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward((input,))
result = shark_module.forward((input))
print("Obtained result", result)

View File

@@ -1,76 +0,0 @@
from shark.shark_inference import SharkInference
from shark.parser import shark_args
import torch
import numpy as np
import sys
import torchvision.models as models
import torch_mlir
torch.manual_seed(0)
class VisionModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = models.resnet50(pretrained=True)
self.train(False)
def forward(self, input):
return self.model.forward(input)
model = VisionModule()
test_input = torch.randn(1, 3, 224, 224)
actual_out = model(test_input)
test_input_fp16 = test_input.to(device=torch.device("cuda"), dtype=torch.half)
model_fp16 = model.half()
model_fp16.eval()
model_fp16.to("cuda")
actual_out_fp16 = model_fp16(test_input_fp16)
ts_g = torch.jit.trace(model_fp16, [test_input_fp16])
module = torch_mlir.compile(
ts_g,
(test_input_fp16),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=True,
verbose=False,
)
# from contextlib import redirect_stdout
# with open('resnet50_fp16_linalg_ir.mlir', 'w') as f:
# with redirect_stdout(f):
# print(module.operation.get_asm())
mlir_model = module
func_name = "forward"
shark_module = SharkInference(
mlir_model, func_name, device="cuda", mlir_dialect="linalg"
)
shark_module.compile()
def shark_result(x):
x_ny = x.cpu().detach().numpy()
inputs = (x_ny,)
result = shark_module.forward(inputs)
return torch.from_numpy(result)
observed_out = shark_result(test_input_fp16)
print("Golden result:", actual_out_fp16)
print("SHARK result:", observed_out)
actual_out_fp16 = actual_out_fp16.to(device=torch.device("cpu"))
print(
torch.testing.assert_allclose(
actual_out_fp16, observed_out, rtol=1e-2, atol=1e-2
)
)

Some files were not shown because too many files have changed in this diff Show More