mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-12 07:18:27 -05:00
Compare commits
7 Commits
20230108.4
...
github-pag
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d9c62e547c | ||
|
|
d84a86f6d2 | ||
|
|
dadd6640fb | ||
|
|
23501d34a1 | ||
|
|
9b9eef1d22 | ||
|
|
e4b156f3b4 | ||
|
|
ce26492a10 |
2
.github/workflows/gh-pages-releases.yml
vendored
2
.github/workflows/gh-pages-releases.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- run: git fetch --all
|
||||
- run: git switch github-pages
|
||||
- run: git config --global user.email "none@none.com"
|
||||
- run: git config --global user.name "nod-ai"
|
||||
- run: git config --global user.name "nod-team"
|
||||
- run: mv /tmp/index.html package-index/index.html
|
||||
- run: git add package-index/index.html
|
||||
|
||||
|
||||
140
.github/workflows/nightly.yml
vendored
140
.github/workflows/nightly.yml
vendored
@@ -9,80 +9,7 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
windows-build:
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Compute version
|
||||
shell: powershell
|
||||
run: |
|
||||
$package_version = $(Get-Date -UFormat "%Y%m%d")+"."+${{ github.run_number }}
|
||||
$package_version_ = $(Get-Date -UFormat "%Y%m%d")+"_"+${{ github.run_number }}
|
||||
$tag_name=$package_version
|
||||
echo "package_version=$package_version" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
echo "package_version_=$package_version_" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
echo "tag_name=$tag_name" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ env.tag_name }}
|
||||
release_name: nod.ai SHARK ${{ env.tag_name }}
|
||||
body: |
|
||||
Automatic snapshot release of nod.ai SHARK.
|
||||
draft: true
|
||||
prerelease: false
|
||||
|
||||
- name: Build Package
|
||||
shell: powershell
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
pyinstaller web/shark_sd.spec
|
||||
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
|
||||
|
||||
# GHA windows VM OOMs so disable for now
|
||||
#- name: Build and validate the SHARK Runtime package
|
||||
# shell: powershell
|
||||
# run: |
|
||||
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
- uses: actions/upload-artifact@v2
|
||||
with:
|
||||
path: dist/*
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ./dist/*
|
||||
|
||||
- name: Publish Release
|
||||
id: publish_release
|
||||
uses: eregon/publish-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
|
||||
linux-build:
|
||||
build:
|
||||
|
||||
runs-on: a100
|
||||
strategy:
|
||||
@@ -105,13 +32,40 @@ jobs:
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
|
||||
- name: Compute version
|
||||
run: |
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
tag_name="${package_version}"
|
||||
echo "package_version=${package_version}" >> $GITHUB_ENV
|
||||
echo "tag_name=${tag_name}" >> $GITHUB_ENV
|
||||
- name: Set Environment Variables
|
||||
run: |
|
||||
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ env.tag_name }}
|
||||
release_name: nod.ai SHARK ${{ env.tag_name }}
|
||||
body: |
|
||||
Automatic snapshot release of nod.ai SHARK.
|
||||
draft: true
|
||||
prerelease: false
|
||||
- name: Find Torch-MLIR Release
|
||||
run: |
|
||||
TM_HTML_URL="$(python3 -c "import urllib.request, json, sys; u=json.loads(urllib.request.urlopen('https://api.github.com/repos/llvm/torch-mlir/releases/latest').read().decode()).get('html_url', False); print(u) if u else sys.exit(1);")"
|
||||
TM_RELEASE_DIR=${TM_HTML_URL/"tag"/"expanded_assets"}
|
||||
echo "TM_RELEASE_DIR=${TM_RELEASE_DIR}" >> $GITHUB_ENV
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
echo "Torch-MLIR Release DIR is ${{ env.TM_RELEASE_DIR }}"
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/nod-ai/SHARK-Runtime/releases; fi
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
@@ -120,26 +74,25 @@ jobs:
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
|
||||
- name: Build and validate the IREE package
|
||||
if: ${{ matrix.backend == 'IREE' }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
|
||||
source iree.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://iree-org.github.io/iree/pip-release-links.html
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/iree-org/iree/releases
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" tank/test_models.py |
|
||||
tail -n 1 |
|
||||
tee -a pytest_results.txt
|
||||
if !(grep -Fxq " failed" pytest_results.txt)
|
||||
then
|
||||
export SHA=$(git log -1 --format='%h')
|
||||
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
|
||||
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/latest/
|
||||
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/$SHA
|
||||
gsutil -m cp -r gs://shark_tank/$SHA/* gs://shark_tank/latest/
|
||||
fi
|
||||
rm -rf ./wheelhouse/nodai*
|
||||
|
||||
@@ -151,10 +104,29 @@ jobs:
|
||||
source shark.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/nod-ai/SHARK-Runtime/releases
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" tank/test_models.py |
|
||||
tail -n 1 |
|
||||
tee -a pytest_results.txt
|
||||
|
||||
- name: Upload Release Assets
|
||||
if: ${{ matrix.backend == 'SHARK' }}
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ${GITHUB_WORKSPACE}/wheelhouse/nodai_*.whl
|
||||
|
||||
- name: Publish Release
|
||||
if: ${{ matrix.backend == 'SHARK' }}
|
||||
id: publish_release
|
||||
uses: eregon/publish-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
|
||||
39
.github/workflows/test-models.yml
vendored
39
.github/workflows/test-models.yml
vendored
@@ -6,24 +6,10 @@ name: Validate Models on Shark Runtime
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'shark/examples/**'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'shark/examples/**'
|
||||
workflow_dispatch:
|
||||
|
||||
# Ensure that only a single job or workflow using the same
|
||||
# concurrency group will run at a time. This would cancel
|
||||
# any in-progress jobs in the same github workflow and github
|
||||
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-validate:
|
||||
strategy:
|
||||
@@ -46,6 +32,8 @@ jobs:
|
||||
suite: cuda
|
||||
- os: MacStudio
|
||||
suite: cpu
|
||||
- os: MacStudio
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
@@ -102,7 +90,7 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cpu --update_tank
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cpu
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
|
||||
|
||||
@@ -112,25 +100,14 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cuda --update_tank
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cuda
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
|
||||
|
||||
- name: Validate Vulkan Models (MacOS)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
|
||||
- name: Validate Vulkan Models
|
||||
if: matrix.suite == 'vulkan'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
export DYLD_LIBRARY_PATH=/usr/local/lib/
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest -s --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os != 'MacStudio'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k vulkan
|
||||
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -31,6 +31,7 @@ MANIFEST
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
@@ -162,14 +163,7 @@ cython_debug/
|
||||
# Shark related artefacts
|
||||
*venv/
|
||||
shark_tmp/
|
||||
*.vmfb
|
||||
.use-iree
|
||||
tank/dict_configs.py
|
||||
|
||||
# ORT related artefacts
|
||||
cache_models/
|
||||
onnx_models/
|
||||
|
||||
#web logging
|
||||
web/logs/
|
||||
web/stored_results/stable_diffusion/
|
||||
|
||||
4
.gitmodules
vendored
4
.gitmodules
vendored
@@ -1,4 +0,0 @@
|
||||
[submodule "inference/thirdparty/shark-runtime"]
|
||||
path = inference/thirdparty/shark-runtime
|
||||
url =https://github.com/nod-ai/SHARK-Runtime.git
|
||||
branch = shark-06032022
|
||||
@@ -1,3 +0,0 @@
|
||||
[style]
|
||||
based_on_style = google
|
||||
column_limit = 80
|
||||
218
LICENSE
218
LICENSE
@@ -1,218 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
||||
---- LLVM Exceptions to the Apache 2.0 License ----
|
||||
|
||||
As an exception, if, as a result of your compiling your source code, portions
|
||||
of this Software are embedded into an Object form of such source code, you
|
||||
may redistribute such embedded portions in such Object form without complying
|
||||
with the conditions of Sections 4(a), 4(b) and 4(d) of the License.
|
||||
|
||||
In addition, if you combine or link compiled forms of this Software with
|
||||
software that is licensed under the GPLv2 ("Combined Software") and if a
|
||||
court of competent jurisdiction determines that the patent provision (Section
|
||||
3), the indemnity provision (Section 9) or other Section of the License
|
||||
conflicts with the conditions of the GPLv2, you may retroactively and
|
||||
prospectively choose to deem waived or otherwise exclude such Section(s) of
|
||||
the License, but only in their entirety and only with respect to the Combined
|
||||
Software.
|
||||
336
README.md
336
README.md
@@ -1,336 +0,0 @@
|
||||
# SHARK
|
||||
|
||||
High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerators and Heterogeneous Clusters
|
||||
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
|
||||
|
||||
|
||||
## Installation (Windows, Linux and macOS)
|
||||
|
||||
## Check out the code
|
||||
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
cd SHARK
|
||||
```
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
|
||||
|
||||
* Install Git for Windows from [here](https://git-scm.com/download/win)
|
||||
|
||||
#### Allow the install script to run in Powershell
|
||||
```powershell
|
||||
set-executionpolicy remotesigned
|
||||
```
|
||||
|
||||
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
|
||||
```powershell
|
||||
./setup_venv.ps1 #You can re-run this script to get the latest version
|
||||
```
|
||||
|
||||
### Linux / macOS Users
|
||||
|
||||
```shell
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
```
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - WebUI
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\Users\nod\SHARK> cd web
|
||||
(shark.venv) PS C:\Users\nod\SHARK\web> python index.py
|
||||
```
|
||||
#### Linux Users
|
||||
```shell
|
||||
(shark.venv) > cd web
|
||||
(shark.venv) > python index.py
|
||||
```
|
||||
|
||||
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
|
||||
|
||||
|
||||
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
|
||||
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - Commandline
|
||||
|
||||
#### Install your hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mril-iree)
|
||||
* [macOS Users] Download and install the latest Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home)
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
Other users please ensure you have your latest vendor drivers and Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home) and if you are using vulkan check `vulkaninfo` works in a terminal window
|
||||
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> python .\shark\examples\shark_inference\stable_diffusion\main.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
```
|
||||
|
||||
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
|
||||
|
||||
The output on a 6900XT would like:
|
||||
|
||||
```shell
|
||||
44it [00:08, 5.14it/s]i = 44 t = 120 (191ms)
|
||||
45it [00:08, 5.15it/s]i = 45 t = 100 (191ms)
|
||||
46it [00:08, 5.16it/s]i = 46 t = 80 (191ms)
|
||||
47it [00:09, 5.16it/s]i = 47 t = 60 (193ms)
|
||||
48it [00:09, 5.15it/s]i = 48 t = 40 (195ms)
|
||||
49it [00:09, 5.12it/s]i = 49 t = 20 (196ms)
|
||||
50it [00:09, 5.14it/s]
|
||||
Average step time: 192.8154182434082ms/it
|
||||
Total image generation runtime (s): 10.390909433364868
|
||||
(shark.venv) PS C:\g\shark>
|
||||
```
|
||||
|
||||
Here are some samples generated:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
|
||||
|
||||
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
|
||||
|
||||
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Binary Installation</summary>
|
||||
|
||||
### Setup a new pip Virtual Environment
|
||||
|
||||
This step sets up a new VirtualEnv for Python
|
||||
|
||||
```shell
|
||||
python --version #Check you have 3.10 on Linux, macOS or Windows Powershell
|
||||
python -m venv shark_venv
|
||||
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
|
||||
|
||||
# If you are using conda create and activate a new conda env
|
||||
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
python -m pip install --upgrade pip
|
||||
```
|
||||
|
||||
*macOS Metal* users please install https://sdk.lunarg.com/sdk/download/latest/mac/vulkan-sdk.dmg and enable "System wide install"
|
||||
|
||||
### Install SHARK
|
||||
|
||||
This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10
|
||||
|
||||
```shell
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
```
|
||||
|
||||
### Run shark tank model tests.
|
||||
```shell
|
||||
pytest tank/test_models.py
|
||||
```
|
||||
See tank/README.md for a more detailed walkthrough of our pytest suite and CLI.
|
||||
|
||||
### Download and run Resnet50 sample
|
||||
|
||||
```shell
|
||||
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/resnet50_script.py
|
||||
#Install deps for test script
|
||||
pip install --pre torch torchvision torchaudio tqdm pillow gsutil --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python ./resnet50_script.py --device="cpu" #use cuda or vulkan or metal
|
||||
```
|
||||
|
||||
### Download and run BERT (MiniLM) sample
|
||||
```shell
|
||||
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/minilm_jit.py
|
||||
#Install deps for test script
|
||||
pip install transformers torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Development, Testing and Benchmarks</summary>
|
||||
|
||||
If you want to use Python3.10 and with TF Import tools you can use the environment variables like:
|
||||
Set `USE_IREE=1` to use upstream IREE
|
||||
```
|
||||
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
|
||||
```
|
||||
|
||||
### Run any of the hundreds of SHARK tank models via the test framework
|
||||
```shell
|
||||
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
# Or a pytest
|
||||
pytest tank/test_models.py -k "MiniLM"
|
||||
```
|
||||
|
||||
|
||||
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
|
||||
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
|
||||
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
|
||||
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
|
||||
for Torch-MLIR.
|
||||
|
||||
### How to use your locally built Torch-MLIR with SHARK
|
||||
```shell
|
||||
1.) Run `./setup_venv.sh in SHARK` and activate `shark.venv` virtual env.
|
||||
2.) Run `pip uninstall torch-mlir`.
|
||||
3.) Go to your local Torch-MLIR directory.
|
||||
4.) Activate mlir_venv virtual envirnoment.
|
||||
5.) Run `pip uninstall -r requirements.txt`.
|
||||
6.) Run `pip install -r requirements.txt`.
|
||||
7.) Build Torch-MLIR.
|
||||
8.) Activate shark.venv virtual environment from the Torch-MLIR directory.
|
||||
8.) Run `export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples` in the Torch-MLIR directory.
|
||||
9.) Go to the SHARK directory.
|
||||
```
|
||||
Now the SHARK will use your locally build Torch-MLIR repo.
|
||||
|
||||
|
||||
## Benchmarking Dispatches
|
||||
|
||||
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your command line argument.
|
||||
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`
|
||||
|
||||
if you want to instead incorporate this into a python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `SharkInference`, and the benchmarks will be generated when compiled. E.G:
|
||||
|
||||
```
|
||||
shark_module = SharkInference(
|
||||
mlir_model,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
dispatch_benchmarks="all",
|
||||
dispatch_benchmarks_dir="results"
|
||||
)
|
||||
```
|
||||
|
||||
Output will include:
|
||||
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
|
||||
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
|
||||
- An .mlir file containing the dispatch benchmark
|
||||
- A compiled .vmfb file containing the dispatch benchmark
|
||||
- An .mlir file containing just the hal executable
|
||||
- A compiled .vmfb file of the hal executable
|
||||
- A .txt file containing benchmark output
|
||||
|
||||
|
||||
See tank/README.md for instructions on how to run model tests and benchmarks from the SHARK tank.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>API Reference</summary>
|
||||
|
||||
### Shark Inference API
|
||||
|
||||
```
|
||||
|
||||
from shark.shark_importer import SharkImporter
|
||||
|
||||
# SharkImporter imports mlir file from the torch, tensorflow or tf-lite module.
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
torch_module,
|
||||
(input),
|
||||
frontend="torch", #tf, #tf-lite
|
||||
)
|
||||
torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
|
||||
|
||||
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input))
|
||||
|
||||
```
|
||||
|
||||
|
||||
### Example demonstrating running MHLO IR.
|
||||
|
||||
```
|
||||
from shark.shark_inference import SharkInference
|
||||
import numpy as np
|
||||
|
||||
mhlo_ir = r"""builtin.module {
|
||||
func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
|
||||
%1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
return %1 : tensor<4x4xf32>
|
||||
}
|
||||
}"""
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((arg0, arg1))
|
||||
```
|
||||
</details>
|
||||
|
||||
## Supported and Validated Models
|
||||
|
||||
SHARK is maintained to support the latest innovations in ML Models:
|
||||
|
||||
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| GPT2 | :green_heart: | :green_heart: | :green_heart: |
|
||||
| BLOOM | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Stable Diffusion | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Vision Transformer | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ResNet50 | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
For a complete list of the models supported in SHARK, please refer to [tank/README.md](https://github.com/nod-ai/SHARK/blob/main/tank/README.md).
|
||||
|
||||
## Communication Channels
|
||||
|
||||
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
|
||||
|
||||
## Related Projects
|
||||
|
||||
<details>
|
||||
<summary>IREE Project Channels</summary>
|
||||
|
||||
* [Upstream IREE issues](https://github.com/google/iree/issues): Feature requests,
|
||||
bugs, and other work tracking
|
||||
* [Upstream IREE Discord server](https://discord.gg/26P4xW4): Daily development
|
||||
discussions with the core team and collaborators
|
||||
* [iree-discuss email list](https://groups.google.com/forum/#!forum/iree-discuss):
|
||||
Announcements, general and low-priority discussion
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>MLIR and Torch-MLIR Project Channels</summary>
|
||||
|
||||
* `#torch-mlir` channel on the LLVM [Discord](https://discord.gg/xS7Z362) - this is the most active communication channel
|
||||
* Torch-MLIR Github issues [here](https://github.com/llvm/torch-mlir/issues)
|
||||
* [`torch-mlir` section](https://llvm.discourse.group/c/projects-that-want-to-become-official-llvm-projects/torch-mlir/41) of LLVM Discourse
|
||||
* Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information.
|
||||
* [MLIR topic within LLVM Discourse](https://llvm.discourse.group/c/llvm-project/mlir/31) SHARK and IREE is enabled by and heavily relies on [MLIR](https://mlir.llvm.org).
|
||||
</details>
|
||||
|
||||
## License
|
||||
|
||||
nod.ai SHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions.
|
||||
See [LICENSE](LICENSE) for more information.
|
||||
@@ -1,22 +0,0 @@
|
||||
import torch
|
||||
from shark.parser import parser
|
||||
from benchmarks.hf_transformer import SharkHFBenchmarkRunner
|
||||
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
required=True,
|
||||
help='Specifies name of HF model to benchmark. (For exmaple "microsoft/MiniLM-L12-H384-uncased"',
|
||||
)
|
||||
load_args, unknown = parser.parse_known_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_name = load_args.model_name
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
shark_module = SharkHFBenchmarkRunner(
|
||||
model_name, (test_input,), jit_trace=True
|
||||
)
|
||||
shark_module.benchmark_c()
|
||||
shark_module.benchmark_python((test_input,))
|
||||
shark_module.benchmark_torch(test_input)
|
||||
shark_module.benchmark_onnx(test_input)
|
||||
@@ -1,181 +0,0 @@
|
||||
import torch
|
||||
from shark.shark_benchmark_runner import SharkBenchmarkRunner
|
||||
from shark.parser import shark_args
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from onnxruntime.transformers.benchmark import (
|
||||
run_pytorch,
|
||||
run_tensorflow,
|
||||
run_onnxruntime,
|
||||
)
|
||||
from onnxruntime.transformers.huggingface_models import MODELS
|
||||
from onnxruntime.transformers.benchmark_helper import ConfigModifier, Precision
|
||||
import os
|
||||
import psutil
|
||||
|
||||
|
||||
class OnnxFusionOptions(object):
|
||||
def __init__(self):
|
||||
self.disable_gelu = False
|
||||
self.disable_layer_norm = False
|
||||
self.disable_attention = False
|
||||
self.disable_skip_layer_norm = False
|
||||
self.disable_embed_layer_norm = False
|
||||
self.disable_bias_skip_layer_norm = False
|
||||
self.disable_bias_gelu = False
|
||||
self.enable_gelu_approximation = False
|
||||
self.use_mask_index = False
|
||||
self.no_attention_mask = False
|
||||
|
||||
|
||||
class HuggingFaceLanguage(torch.nn.Module):
|
||||
def __init__(self, hf_model_name):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
hf_model_name, # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
class SharkHFBenchmarkRunner(SharkBenchmarkRunner):
|
||||
# SharkRunner derived class with Benchmarking capabilities.
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
input: tuple,
|
||||
dynamic: bool = False,
|
||||
device: str = None,
|
||||
jit_trace: bool = False,
|
||||
from_aot: bool = False,
|
||||
frontend: str = "torch",
|
||||
):
|
||||
self.device = device if device is not None else shark_args.device
|
||||
if self.device == "gpu":
|
||||
raise ValueError(
|
||||
"Currently GPU Benchmarking is not supported due to OOM from ORT."
|
||||
)
|
||||
self.model_name = model_name
|
||||
model = HuggingFaceLanguage(model_name)
|
||||
SharkBenchmarkRunner.__init__(
|
||||
self,
|
||||
model,
|
||||
input,
|
||||
dynamic,
|
||||
self.device,
|
||||
jit_trace,
|
||||
from_aot,
|
||||
frontend,
|
||||
)
|
||||
|
||||
def benchmark_torch(self, inputs):
|
||||
use_gpu = self.device == "gpu"
|
||||
# Set set the model's layer number to automatic.
|
||||
config_modifier = ConfigModifier(None)
|
||||
num_threads = psutil.cpu_count(logical=False)
|
||||
batch_sizes = [inputs.shape[0]]
|
||||
sequence_lengths = [inputs.shape[-1]]
|
||||
cache_dir = os.path.join(".", "cache_models")
|
||||
verbose = False
|
||||
result = run_pytorch(
|
||||
use_gpu,
|
||||
[self.model_name],
|
||||
None,
|
||||
config_modifier,
|
||||
Precision.FLOAT32,
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
False,
|
||||
cache_dir,
|
||||
verbose,
|
||||
)
|
||||
print(
|
||||
f"ONNX Pytorch-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
|
||||
# TODO: Currently non-functional due to TF runtime error. There might be some issue with, initializing TF.
|
||||
def benchmark_tf(self, inputs):
|
||||
use_gpu = self.device == "gpu"
|
||||
# Set set the model's layer number to automatic.
|
||||
config_modifier = ConfigModifier(None)
|
||||
num_threads = psutil.cpu_count(logical=False)
|
||||
batch_sizes = [inputs.shape[0]]
|
||||
sequence_lengths = [inputs.shape[-1]]
|
||||
cache_dir = os.path.join(".", "cache_models")
|
||||
verbose = False
|
||||
result = run_tensorflow(
|
||||
use_gpu,
|
||||
[self.model_name],
|
||||
None,
|
||||
config_modifier,
|
||||
Precision.FLOAT32,
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
cache_dir,
|
||||
verbose,
|
||||
)
|
||||
print(
|
||||
f"ONNX TF-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
|
||||
def benchmark_onnx(self, inputs):
|
||||
if self.model_name not in MODELS:
|
||||
print(
|
||||
f"{self.model_name} is currently not supported in ORT's HF. Check \
|
||||
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/huggingface_models.py \
|
||||
for currently supported models. Exiting benchmark ONNX."
|
||||
)
|
||||
return
|
||||
use_gpu = self.device == "gpu"
|
||||
num_threads = psutil.cpu_count(logical=False)
|
||||
batch_sizes = [inputs.shape[0]]
|
||||
sequence_lengths = [inputs.shape[-1]]
|
||||
cache_dir = os.path.join(".", "cache_models")
|
||||
onnx_dir = os.path.join(".", "onnx_models")
|
||||
verbose = False
|
||||
input_counts = [1]
|
||||
optimize_onnx = True
|
||||
validate_onnx = False
|
||||
disable_ort_io_binding = False
|
||||
use_raw_attention_mask = True
|
||||
model_fusion_statistics = {}
|
||||
overwrite = False
|
||||
model_source = "pt" # Either "pt" or "tf"
|
||||
provider = None
|
||||
config_modifier = ConfigModifier(None)
|
||||
onnx_args = OnnxFusionOptions()
|
||||
result = run_onnxruntime(
|
||||
use_gpu,
|
||||
provider,
|
||||
[self.model_name],
|
||||
None,
|
||||
config_modifier,
|
||||
Precision.FLOAT32,
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
input_counts,
|
||||
optimize_onnx,
|
||||
validate_onnx,
|
||||
cache_dir,
|
||||
onnx_dir,
|
||||
verbose,
|
||||
overwrite,
|
||||
disable_ort_io_binding,
|
||||
use_raw_attention_mask,
|
||||
model_fusion_statistics,
|
||||
model_source,
|
||||
onnx_args,
|
||||
)
|
||||
print(
|
||||
f"ONNX ORT-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
@@ -1,231 +0,0 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
|
||||
import torch
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import torchvision.models as models
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BertTokenizer,
|
||||
TFBertModel,
|
||||
)
|
||||
import importlib
|
||||
import pytest
|
||||
import unittest
|
||||
|
||||
torch.manual_seed(0)
|
||||
gpus = tf.config.experimental.list_physical_devices("GPU")
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
|
||||
##################### Tensorflow Hugging Face LM Models ###################################
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
# Create a set of 2-dimensional inputs
|
||||
tf_bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class TFHuggingFaceLanguage(tf.Module):
|
||||
def __init__(self, hf_model_name):
|
||||
super(TFHuggingFaceLanguage, self).__init__()
|
||||
# Create a BERT trainer with the created network.
|
||||
self.m = TFBertModel.from_pretrained(hf_model_name, from_pt=True)
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m.predict = lambda x, y, z: self.m.call(
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=tf_bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
def get_TFhf_model(name):
|
||||
model = TFHuggingFaceLanguage(name)
|
||||
tokenizer = BertTokenizer.from_pretrained(name)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
test_input = (
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
)
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### Hugging Face LM Models ###################################
|
||||
|
||||
|
||||
class HuggingFaceLanguage(torch.nn.Module):
|
||||
def __init__(self, hf_model_name):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
hf_model_name, # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
def get_hf_model(name):
|
||||
model = HuggingFaceLanguage(name)
|
||||
# TODO: Currently the test input is set to (1,128)
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
actual_out = model(test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
################################################################################
|
||||
|
||||
##################### Torch Vision Models ###################################
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.train(False)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
def get_vision_model(torch_model):
|
||||
model = VisionModule(torch_model)
|
||||
# TODO: Currently the test input is set to (1,128)
|
||||
test_input = torch.randn(1, 3, 224, 224)
|
||||
actual_out = model(test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
############################# Benchmark Tests ####################################
|
||||
|
||||
pytest_benchmark_param = pytest.mark.parametrize(
|
||||
("dynamic", "device"),
|
||||
[
|
||||
pytest.param(False, "cpu"),
|
||||
# TODO: Language models are failing for dynamic case..
|
||||
pytest.param(True, "cpu", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"gpu",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
),
|
||||
),
|
||||
pytest.param(True, "gpu", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"vulkan",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
"vulkan",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("iree.tools") is None,
|
||||
reason="Cannot find tools to import TF",
|
||||
)
|
||||
@pytest_benchmark_param
|
||||
def test_bench_minilm_torch(dynamic, device):
|
||||
model, test_input, act_out = get_hf_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(test_input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
try:
|
||||
# If becnhmarking succesful, assert success/True.
|
||||
shark_module.compile()
|
||||
shark_module.benchmark_all((test_input,))
|
||||
assert True
|
||||
except Exception as e:
|
||||
# If anything happen during benchmarking, assert False/failure.
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("iree.tools") is None,
|
||||
reason="Cannot find tools to import TF",
|
||||
)
|
||||
@pytest_benchmark_param
|
||||
def test_bench_distilbert(dynamic, device):
|
||||
model, test_input, act_out = get_TFhf_model("distilbert-base-uncased")
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
test_input,
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
try:
|
||||
# If becnhmarking succesful, assert success/True.
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
shark_module.benchmark_all(test_input)
|
||||
assert True
|
||||
except Exception as e:
|
||||
# If anything happen during benchmarking, assert False/failure.
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="XLM Roberta too large to test.")
|
||||
@pytest_benchmark_param
|
||||
def test_bench_xlm_roberta(dynamic, device):
|
||||
model, test_input, act_out = get_TFhf_model("xlm-roberta-base")
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
test_input,
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
try:
|
||||
# If becnhmarking succesful, assert success/True.
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
shark_module.benchmark_all(test_input)
|
||||
assert True
|
||||
except Exception as e:
|
||||
# If anything happen during benchmarking, assert False/failure.
|
||||
assert False
|
||||
@@ -1,45 +0,0 @@
|
||||
import torch
|
||||
from benchmarks.hf_transformer import SharkHFBenchmarkRunner
|
||||
import importlib
|
||||
import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
############################# HF Benchmark Tests ####################################
|
||||
|
||||
# Test running benchmark module without failing.
|
||||
pytest_benchmark_param = pytest.mark.parametrize(
|
||||
("dynamic", "device"),
|
||||
[
|
||||
pytest.param(False, "cpu"),
|
||||
# TODO: Language models are failing for dynamic case..
|
||||
pytest.param(True, "cpu", marks=pytest.mark.skip),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("onnxruntime") is None,
|
||||
reason="Cannot find ONNXRUNTIME.",
|
||||
)
|
||||
@pytest_benchmark_param
|
||||
def test_HFbench_minilm_torch(dynamic, device):
|
||||
model_name = "bert-base-uncased"
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
try:
|
||||
shark_module = SharkHFBenchmarkRunner(
|
||||
model_name,
|
||||
(test_input,),
|
||||
jit_trace=True,
|
||||
dynamic=dynamic,
|
||||
device=device,
|
||||
)
|
||||
shark_module.benchmark_c()
|
||||
shark_module.benchmark_python((test_input,))
|
||||
shark_module.benchmark_torch(test_input)
|
||||
shark_module.benchmark_onnx(test_input)
|
||||
# If becnhmarking succesful, assert success/True.
|
||||
assert True
|
||||
except Exception as e:
|
||||
# If anything happen during benchmarking, assert False/failure.
|
||||
assert False
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
IMPORTER=1 ./setup_venv.sh
|
||||
source $GITHUB_WORKSPACE/shark.venv/bin/activate
|
||||
python generate_sharktank.py --upload=False --ci_tank_dir=True
|
||||
@@ -1,37 +0,0 @@
|
||||
"""Scrapes the github releases API to generate a static pip-install-able releases page.
|
||||
|
||||
See https://github.com/llvm/torch-mlir/issues/1374
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
# Parse arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("owner", type=str)
|
||||
parser.add_argument("repo", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get releases
|
||||
response = requests.get(
|
||||
f"https://api.github.com/repos/{args.owner}/{args.repo}/releases"
|
||||
)
|
||||
body = json.loads(response.content)
|
||||
|
||||
# Parse releases
|
||||
releases = []
|
||||
for row in body:
|
||||
for asset in row["assets"]:
|
||||
releases.append((asset["name"], asset["browser_download_url"]))
|
||||
|
||||
# Output HTML
|
||||
html = """<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
"""
|
||||
for name, url in releases:
|
||||
html += f" <a href='{url}'>{name}</a><br />\n"
|
||||
html += """ </body>
|
||||
</html>"""
|
||||
print(html)
|
||||
62
conftest.py
62
conftest.py
@@ -1,62 +0,0 @@
|
||||
def pytest_addoption(parser):
|
||||
# Attaches SHARK command-line arguments to the pytest machinery.
|
||||
parser.addoption(
|
||||
"--benchmark",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Pass option to benchmark and write results.csv",
|
||||
)
|
||||
parser.addoption(
|
||||
"--onnx_bench",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Add ONNX benchmark results to pytest benchmarks.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--tf32",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Use TensorFloat-32 calculations.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--save_repro",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Pass option to save reproduction artifacts to SHARK/shark_tmp/test_case/",
|
||||
)
|
||||
parser.addoption(
|
||||
"--save_fails",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Save reproduction artifacts for a test case only if it fails. Default is False.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--ci",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Enables uploading of reproduction artifacts upon test case failure during iree-compile or validation. Must be passed with --ci_sha option ",
|
||||
)
|
||||
parser.addoption(
|
||||
"--update_tank",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Update local shark tank with latest artifacts.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--ci_sha",
|
||||
action="store",
|
||||
default="None",
|
||||
help="Passes the github SHA of the CI workflow to include in google storage directory for reproduction artifacts.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--local_tank_cache",
|
||||
action="store",
|
||||
default="",
|
||||
help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--tank_url",
|
||||
type=str,
|
||||
default="gs://shark_tank/latest",
|
||||
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
|
||||
)
|
||||
3
cpp/.gitignore
vendored
3
cpp/.gitignore
vendored
@@ -1,3 +0,0 @@
|
||||
*.mlir
|
||||
*.vmfb
|
||||
*.ini
|
||||
@@ -1,52 +0,0 @@
|
||||
# Copyright 2022 The IREE Authors
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
cmake_minimum_required(VERSION 3.21...3.23)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Project configuration
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
project(iree-samples C CXX)
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Core project dependency
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
message(STATUS "Fetching core IREE repo (this may take a few minutes)...")
|
||||
# Note: for log output, set -DFETCHCONTENT_QUIET=OFF,
|
||||
# see https://gitlab.kitware.com/cmake/cmake/-/issues/18238#note_440475
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
iree
|
||||
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
|
||||
GIT_TAG shark
|
||||
GIT_SUBMODULES_RECURSE OFF
|
||||
GIT_SHALLOW OFF
|
||||
GIT_PROGRESS ON
|
||||
USES_TERMINAL_DOWNLOAD ON
|
||||
)
|
||||
|
||||
# Extend module path to find MLIR CMake modules.
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_BINARY_DIR}/lib/cmake/mlir")
|
||||
|
||||
# Disable core project features not needed for these out of tree samples.
|
||||
set(IREE_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(IREE_BUILD_SAMPLES OFF CACHE BOOL "" FORCE)
|
||||
|
||||
FetchContent_MakeAvailable(iree)
|
||||
FetchContent_GetProperties(iree SOURCE_DIR IREE_SOURCE_DIR)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Individual samples
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
add_subdirectory(vulkan_gui)
|
||||
@@ -1,82 +0,0 @@
|
||||
# SHARK C/C++ Samples
|
||||
|
||||
These C/C++ samples can be built using CMake. The samples depend on the main
|
||||
SHARK-Runtime project's C/C++ sources, including both the runtime and the compiler.
|
||||
|
||||
Individual samples may require additional dependencies. Watch CMake's output
|
||||
for information about which you are missing for individual samples.
|
||||
|
||||
On Windows we recommend using https://github.com/microsoft/vcpkg to download packages for
|
||||
your system. The general setup flow looks like
|
||||
|
||||
*Install and activate SHARK*
|
||||
|
||||
```bash
|
||||
source shark.venv/bin/activate #follow main repo instructions to setup your venv
|
||||
```
|
||||
|
||||
*Install Dependencies*
|
||||
|
||||
```bash
|
||||
vcpkg install [library] --triplet [your platform]
|
||||
vcpkg integrate install
|
||||
|
||||
# Then pass `-DCMAKE_TOOLCHAIN_FILE=[check logs for path]` when configuring CMake
|
||||
```
|
||||
|
||||
In Ubuntu Linux you can install
|
||||
|
||||
```bash
|
||||
sudo apt install libsdl2-dev
|
||||
```
|
||||
|
||||
*Build*
|
||||
```bash
|
||||
cd cpp
|
||||
cmake -GNinja -B build/
|
||||
cmake --build build/
|
||||
```
|
||||
|
||||
*Prepare the model*
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvm-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
```
|
||||
*Prepare the input*
|
||||
|
||||
```bash
|
||||
python save_img.py
|
||||
```
|
||||
Note that this requires tensorflow, e.g.
|
||||
```bash
|
||||
python -m pip install tensorflow
|
||||
```
|
||||
|
||||
*Run the vulkan_gui*
|
||||
```bash
|
||||
./build/vulkan_gui/iree-samples-resnet-vulkan-gui
|
||||
```
|
||||
|
||||
## Other models
|
||||
A tool for benchmarking other models is built and can be invoked with a command like the following
|
||||
```bash
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=path/to/.vmfb --function_input=...
|
||||
```
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 26 KiB |
@@ -1,18 +0,0 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
def load_and_preprocess_image(fname: str):
|
||||
image = tf.io.read_file(fname)
|
||||
image = tf.image.decode_image(image, channels=3)
|
||||
image = tf.image.resize(image, (224, 224))
|
||||
image = image[tf.newaxis, :]
|
||||
# preprocessing pipeline
|
||||
input_tensor = tf.keras.applications.resnet50.preprocess_input(image)
|
||||
return input_tensor
|
||||
|
||||
|
||||
data = load_and_preprocess_image("dog_imagenet.jpg").numpy()
|
||||
|
||||
data.tofile("dog.bin")
|
||||
@@ -1,84 +0,0 @@
|
||||
# Copyright 2022 The IREE Authors
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
if(NOT IREE_TARGET_BACKEND_LLVM_CPU OR
|
||||
NOT IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF)
|
||||
message(STATUS "Missing LLVM backend and/or embeddded elf loader, skipping vision_inference sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# vcpkg install stb
|
||||
# tested with version 2021-09-10
|
||||
find_package(Stb)
|
||||
if(NOT Stb_FOUND)
|
||||
message(STATUS "Could not find Stb, skipping vision inference sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Compile mnist.mlir to mnist.vmfb.
|
||||
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
|
||||
set(_COMPILE_ARGS)
|
||||
list(APPEND _COMPILE_ARGS "--iree-input-type=mhlo")
|
||||
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
|
||||
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
|
||||
list(APPEND _COMPILE_ARGS "-o")
|
||||
list(APPEND _COMPILE_ARGS "mnist.vmfb")
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb
|
||||
COMMAND ${_COMPILE_TOOL_EXECUTABLE} ${_COMPILE_ARGS}
|
||||
DEPENDS ${_COMPILE_TOOL_EXECUTABLE} "${IREE_SOURCE_DIR}/samples/models/mnist.mlir"
|
||||
)
|
||||
# Embed mnist.vmfb into a C file as mnist_bytecode_module_c.[h/c]
|
||||
set(_EMBED_DATA_EXECUTABLE $<TARGET_FILE:generate_embed_data>)
|
||||
set(_EMBED_ARGS)
|
||||
list(APPEND _EMBED_ARGS "--output_header=mnist_bytecode_module_c.h")
|
||||
list(APPEND _EMBED_ARGS "--output_impl=mnist_bytecode_module_c.c")
|
||||
list(APPEND _EMBED_ARGS "--identifier=iree_samples_vision_inference_mnist_bytecode_module")
|
||||
list(APPEND _EMBED_ARGS "--flatten")
|
||||
list(APPEND _EMBED_ARGS "${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb")
|
||||
add_custom_command(
|
||||
OUTPUT "mnist_bytecode_module_c.h" "mnist_bytecode_module_c.c"
|
||||
COMMAND ${_EMBED_DATA_EXECUTABLE} ${_EMBED_ARGS}
|
||||
DEPENDS ${_EMBED_DATA_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb
|
||||
)
|
||||
# Define a library target for mnist_bytecode_module_c.
|
||||
add_library(iree_samples_vision_inference_mnist_bytecode_module_c OBJECT)
|
||||
target_sources(iree_samples_vision_inference_mnist_bytecode_module_c
|
||||
PRIVATE
|
||||
mnist_bytecode_module_c.h
|
||||
mnist_bytecode_module_c.c
|
||||
)
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "iree-run-mnist-module")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
"image_util.h"
|
||||
"image_util.c"
|
||||
"iree-run-mnist-module.c"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "iree-run-mnist-module")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_include_directories(${_NAME} PRIVATE
|
||||
${Stb_INCLUDE_DIR}
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
iree_base_base
|
||||
iree_base_tracing
|
||||
iree_hal_hal
|
||||
iree_runtime_runtime
|
||||
iree_samples_vision_inference_mnist_bytecode_module_c
|
||||
)
|
||||
|
||||
# Define a target that copies the test image into the build directory.
|
||||
add_custom_target(iree_samples_vision_inference_test_image
|
||||
COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/mnist_test.png" "${CMAKE_CURRENT_BINARY_DIR}/mnist_test.png")
|
||||
add_dependencies(${_NAME} iree_samples_vision_inference_test_image)
|
||||
|
||||
message(STATUS "Configured vision_inference sample successfully")
|
||||
@@ -1,8 +0,0 @@
|
||||
# Vision Inference Sample (C code)
|
||||
|
||||
This sample demonstrates how to run a MNIST handwritten digit detection vision
|
||||
model on an image using IREE's C API.
|
||||
|
||||
A similar sample is implemented using a Python script and IREE's command line
|
||||
tools over in the primary iree repository at
|
||||
https://github.com/iree-org/iree/tree/main/samples/vision_inference
|
||||
@@ -1,224 +0,0 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
#include "image_util.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/base/tracing.h"
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
|
||||
const uint8_t* pixel_data, iree_host_size_t buffer_length,
|
||||
const float* input_range, iree_host_size_t range_length,
|
||||
float* out_buffer) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
if (range_length != 2) {
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"range defined as 2-element [min, max] array.");
|
||||
}
|
||||
float input_scale = fabsf(input_range[1] - input_range[0]) / 2.0f;
|
||||
float input_offset = (input_range[0] + input_range[1]) / 2.0f;
|
||||
const float kUint8Mean = 127.5f;
|
||||
for (int i = 0; i < buffer_length; ++i) {
|
||||
out_buffer[i] =
|
||||
(((float)(pixel_data[i])) - kUint8Mean) / kUint8Mean * input_scale +
|
||||
input_offset;
|
||||
}
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_ok_status();
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_load_pixel_data_impl(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length) {
|
||||
int img_dims[3];
|
||||
if (stbi_info(filename.data, img_dims, &(img_dims[1]), &(img_dims[2])) == 0) {
|
||||
return iree_make_status(IREE_STATUS_NOT_FOUND, "can't load image %.*s",
|
||||
(int)filename.size, filename.data);
|
||||
}
|
||||
if (!(element_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 ||
|
||||
element_type == IREE_HAL_ELEMENT_TYPE_SINT_8 ||
|
||||
element_type == IREE_HAL_ELEMENT_TYPE_UINT_8)) {
|
||||
char element_type_str[16];
|
||||
IREE_RETURN_IF_ERROR(iree_hal_format_element_type(
|
||||
element_type, sizeof(element_type_str), element_type_str, NULL));
|
||||
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
|
||||
"element type %s not supported", element_type_str);
|
||||
}
|
||||
switch (shape_rank) {
|
||||
case 2: { // Assume tensor <height x width>
|
||||
if (img_dims[2] != 1 || (shape[0] != img_dims[1]) ||
|
||||
(shape[1] != img_dims[0])) {
|
||||
return iree_make_status(
|
||||
IREE_STATUS_INVALID_ARGUMENT,
|
||||
"image size: %dx%dx%d, expected: %" PRIdim "x%" PRIdim, img_dims[0],
|
||||
img_dims[1], img_dims[2], shape[1], shape[0]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 3: { // Assume tensor <height x width x channel>
|
||||
if (shape[0] != img_dims[1] || shape[1] != img_dims[0] ||
|
||||
shape[2] != img_dims[2]) {
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"image size: %dx%dx%d, expected: %" PRIdim
|
||||
"x%" PRIdim "x%" PRIdim,
|
||||
img_dims[0], img_dims[1], img_dims[2], shape[1],
|
||||
shape[0], shape[2]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 4: { // Assume tensor <batch x height x width x channel>
|
||||
if (shape[1] != img_dims[1] || shape[2] != img_dims[0] ||
|
||||
shape[3] != img_dims[2]) {
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"image size: %dx%dx%d, expected: %" PRIdim
|
||||
"x%" PRIdim "x%" PRIdim,
|
||||
img_dims[0], img_dims[1], img_dims[2], shape[2],
|
||||
shape[1], shape[3]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return iree_make_status(
|
||||
IREE_STATUS_INVALID_ARGUMENT,
|
||||
"Input buffer shape rank %" PRIhsz " not supported", shape_rank);
|
||||
}
|
||||
// Drop the alpha channel if present.
|
||||
int req_ch = (img_dims[2] >= 3) ? 3 : 0;
|
||||
*out_pixel_data = stbi_load(filename.data, img_dims, &(img_dims[1]),
|
||||
&(img_dims[2]), req_ch);
|
||||
if (*out_pixel_data == NULL) {
|
||||
return iree_make_status(IREE_STATUS_NOT_FOUND, "can't load image %.*s",
|
||||
(int)filename.size, filename.data);
|
||||
}
|
||||
*out_buffer_length =
|
||||
img_dims[0] * img_dims[1] * (img_dims[2] > 3 ? 3 : img_dims[2]);
|
||||
return iree_ok_status();
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_load_pixel_data(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
iree_status_t result = iree_tools_utils_load_pixel_data_impl(
|
||||
filename, shape, shape_rank, element_type, out_pixel_data,
|
||||
out_buffer_length);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return result;
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, iree_hal_buffer_view_t** out_buffer_view) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
*out_buffer_view = NULL;
|
||||
if (element_type != IREE_HAL_ELEMENT_TYPE_SINT_8 &&
|
||||
element_type != IREE_HAL_ELEMENT_TYPE_UINT_8) {
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"element type should be i8 or u8");
|
||||
}
|
||||
|
||||
iree_status_t result;
|
||||
uint8_t* pixel_data = NULL;
|
||||
iree_host_size_t buffer_length;
|
||||
result = iree_tools_utils_load_pixel_data(
|
||||
filename, shape, shape_rank, element_type, &pixel_data, &buffer_length);
|
||||
if (iree_status_is_ok(result)) {
|
||||
iree_host_size_t element_byte =
|
||||
iree_hal_element_dense_byte_count(element_type);
|
||||
// SINT_8 and UINT_8 perform direct buffer wrap.
|
||||
result = iree_hal_buffer_view_allocate_buffer(
|
||||
allocator, shape_rank, shape, element_type,
|
||||
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
|
||||
(iree_hal_buffer_params_t){
|
||||
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
|
||||
.access = IREE_HAL_MEMORY_ACCESS_READ,
|
||||
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
|
||||
IREE_HAL_BUFFER_USAGE_TRANSFER,
|
||||
},
|
||||
iree_make_const_byte_span(pixel_data, element_byte * buffer_length),
|
||||
out_buffer_view);
|
||||
}
|
||||
stbi_image_free(pixel_data);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return result;
|
||||
}
|
||||
|
||||
typedef struct iree_tools_utils_buffer_view_load_params_t {
|
||||
const uint8_t* pixel_data;
|
||||
iree_host_size_t pixel_data_length;
|
||||
const float* input_range;
|
||||
iree_host_size_t input_range_length;
|
||||
} iree_tools_utils_buffer_view_load_params_t;
|
||||
static iree_status_t iree_tools_utils_buffer_view_load_image_rescaled(
|
||||
iree_hal_buffer_mapping_t* mapping, void* user_data) {
|
||||
iree_tools_utils_buffer_view_load_params_t* params =
|
||||
(iree_tools_utils_buffer_view_load_params_t*)user_data;
|
||||
return iree_tools_utils_pixel_rescaled_to_buffer(
|
||||
params->pixel_data, params->pixel_data_length, params->input_range,
|
||||
params->input_range_length, (float*)mapping->contents.data);
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image_rescaled(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, const float* input_range,
|
||||
iree_host_size_t input_range_length,
|
||||
iree_hal_buffer_view_t** out_buffer_view) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
*out_buffer_view = NULL;
|
||||
if (element_type != IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"element type should be f32");
|
||||
}
|
||||
|
||||
// Classic row-major image layout.
|
||||
iree_hal_encoding_type_t encoding_type =
|
||||
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
|
||||
|
||||
// Load pixel data from the file into a new host memory allocation (the only
|
||||
// interface stb_image provides). A real application would want to use the
|
||||
// generation callback to directly decode the image into the target mapped
|
||||
// device buffer.
|
||||
uint8_t* pixel_data = NULL;
|
||||
iree_host_size_t buffer_length = 0;
|
||||
IREE_RETURN_AND_END_ZONE_IF_ERROR(
|
||||
z0, iree_tools_utils_load_pixel_data(filename, shape, shape_rank,
|
||||
element_type, &pixel_data,
|
||||
&buffer_length));
|
||||
|
||||
iree_tools_utils_buffer_view_load_params_t params = {
|
||||
.pixel_data = pixel_data,
|
||||
.pixel_data_length = buffer_length,
|
||||
.input_range = input_range,
|
||||
.input_range_length = input_range_length,
|
||||
};
|
||||
iree_status_t status = iree_hal_buffer_view_generate_buffer(
|
||||
allocator, shape_rank, shape, element_type, encoding_type,
|
||||
(iree_hal_buffer_params_t){
|
||||
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
|
||||
IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
|
||||
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
|
||||
IREE_HAL_BUFFER_USAGE_TRANSFER |
|
||||
IREE_HAL_BUFFER_USAGE_MAPPING,
|
||||
},
|
||||
iree_tools_utils_buffer_view_load_image_rescaled, ¶ms,
|
||||
out_buffer_view);
|
||||
|
||||
stbi_image_free(pixel_data);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return status;
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
#ifndef IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
|
||||
#define IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
|
||||
|
||||
#include "iree/base/api.h"
|
||||
#include "iree/hal/api.h"
|
||||
#include "iree/hal/buffer_view.h"
|
||||
|
||||
#if __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// Loads the image at |filename| into |out_pixel_data| and sets
|
||||
// |out_buffer_length| to its length.
|
||||
//
|
||||
// The image dimension must match the width, height, and channel in|shape|,
|
||||
// while 2 <= |shape_rank| <= 4 to match the image tensor format.
|
||||
//
|
||||
// The file must be in a format supported by stb_image.h.
|
||||
// The returned |out_pixel_data| buffer must be released by the caller.
|
||||
iree_status_t iree_tools_utils_load_pixel_data(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length);
|
||||
|
||||
// Parse the content in an image file in |filename| into a HAL buffer view
|
||||
// |out_buffer_view|. |out_buffer_view| properties are defined by |shape|,
|
||||
// |shape_rank|, and |element_type|, while being allocated by |allocator|.
|
||||
//
|
||||
// The |element_type| has to be SINT_8 or UINT_8. For FLOAT_32, use
|
||||
// |iree_tools_utils_buffer_view_from_image_rescaled| instead.
|
||||
//
|
||||
// The returned |out_buffer_view| must be released by the caller.
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, iree_hal_buffer_view_t** out_buffer_view);
|
||||
|
||||
// Parse the content in an image file in |filename| into a HAL buffer view
|
||||
// |out_buffer_view|. |out_buffer_view| properties are defined by |shape|,
|
||||
// |shape_rank|, and |element_type|, while being allocated by |allocator|.
|
||||
// The value in |out_buffer_view| is rescaled with |input_range|.
|
||||
//
|
||||
// The |element_type| has to be FLOAT_32, For SINT_8 or UINT_8, use
|
||||
// |iree_tools_utils_buffer_view_from_image| instead.
|
||||
//
|
||||
// The returned |out_buffer_view| must be released by the caller.
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image_rescaled(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, const float* input_range,
|
||||
iree_host_size_t input_range_length,
|
||||
iree_hal_buffer_view_t** out_buffer_view);
|
||||
|
||||
// Normalize uint8_t |pixel_data| of the size |buffer_length| to float buffer
|
||||
// |out_buffer| with the range |input_range|.
|
||||
//
|
||||
// float32_x = (uint8_x - 127.5) / 127.5 * input_scale + input_offset, where
|
||||
// input_scale = abs(|input_range[0]| - |input_range[1]| / 2
|
||||
// input_offset = |input_range[0]| + |input_range[1]| / 2
|
||||
//
|
||||
// |out_buffer| needs to be allocated before the call.
|
||||
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
|
||||
const uint8_t* pixel_data, iree_host_size_t pixel_count,
|
||||
const float* input_range, iree_host_size_t input_range_length,
|
||||
float* out_buffer);
|
||||
|
||||
#if __cplusplus
|
||||
}
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
|
||||
@@ -1,121 +0,0 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
// This sample uses image_util to load a hand-written image as an
|
||||
// iree_hal_buffer_view_t then passes it to the bytecode module built from
|
||||
// mnist.mlir on the CPU backend with the local-task driver.
|
||||
|
||||
#include <float.h>
|
||||
|
||||
#include "image_util.h"
|
||||
#include "iree/runtime/api.h"
|
||||
#include "mnist_bytecode_module_c.h"
|
||||
|
||||
iree_status_t Run(const iree_string_view_t image_path) {
|
||||
iree_runtime_instance_options_t instance_options;
|
||||
iree_runtime_instance_options_initialize(IREE_API_VERSION_LATEST,
|
||||
&instance_options);
|
||||
iree_runtime_instance_options_use_all_available_drivers(&instance_options);
|
||||
iree_runtime_instance_t* instance = NULL;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_instance_create(
|
||||
&instance_options, iree_allocator_system(), &instance));
|
||||
|
||||
// TODO(#5724): move device selection into the compiled modules.
|
||||
iree_hal_device_t* device = NULL;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device(
|
||||
instance, iree_make_cstring_view("local-task"), &device));
|
||||
|
||||
// Create one session per loaded module to hold the module state.
|
||||
iree_runtime_session_options_t session_options;
|
||||
iree_runtime_session_options_initialize(&session_options);
|
||||
iree_runtime_session_t* session = NULL;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_session_create_with_device(
|
||||
instance, &session_options, device,
|
||||
iree_runtime_instance_host_allocator(instance), &session));
|
||||
iree_hal_device_release(device);
|
||||
|
||||
const struct iree_file_toc_t* module_file =
|
||||
iree_samples_vision_inference_mnist_bytecode_module_create();
|
||||
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_session_append_bytecode_module_from_memory(
|
||||
session, iree_make_const_byte_span(module_file->data, module_file->size),
|
||||
iree_allocator_null()));
|
||||
|
||||
iree_runtime_call_t call;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(
|
||||
session, iree_make_cstring_view("module.predict"), &call));
|
||||
|
||||
// Prepare the input hal buffer view with image_util library.
|
||||
// The input of the mmist model is single 28x28 pixel image as a
|
||||
// tensor<1x28x28x1xf32>, with pixels in [0.0, 1.0].
|
||||
iree_hal_buffer_view_t* buffer_view = NULL;
|
||||
iree_hal_dim_t buffer_shape[] = {1, 28, 28, 1};
|
||||
iree_hal_element_type_t hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32;
|
||||
float input_range[2] = {0.0f, 1.0f};
|
||||
IREE_RETURN_IF_ERROR(
|
||||
iree_tools_utils_buffer_view_from_image_rescaled(
|
||||
image_path, buffer_shape, IREE_ARRAYSIZE(buffer_shape),
|
||||
hal_element_type, iree_hal_device_allocator(device), input_range,
|
||||
IREE_ARRAYSIZE(input_range), &buffer_view),
|
||||
"load image");
|
||||
IREE_RETURN_IF_ERROR(
|
||||
iree_runtime_call_inputs_push_back_buffer_view(&call, buffer_view));
|
||||
iree_hal_buffer_view_release(buffer_view);
|
||||
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_call_invoke(&call, /*flags=*/0));
|
||||
|
||||
// Get the result buffers from the invocation.
|
||||
iree_hal_buffer_view_t* ret_buffer_view = NULL;
|
||||
IREE_RETURN_IF_ERROR(
|
||||
iree_runtime_call_outputs_pop_front_buffer_view(&call, &ret_buffer_view));
|
||||
|
||||
// Read back the results. The output of the mnist model is a 1x10 prediction
|
||||
// confidence values for each digit in [0, 9].
|
||||
float predictions[1 * 10] = {0.0f};
|
||||
IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
|
||||
iree_runtime_session_device(session),
|
||||
iree_hal_buffer_view_buffer(ret_buffer_view), 0, predictions,
|
||||
sizeof(predictions), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
|
||||
iree_infinite_timeout()));
|
||||
iree_hal_buffer_view_release(ret_buffer_view);
|
||||
|
||||
// Get the highest index from the output.
|
||||
float result_val = FLT_MIN;
|
||||
int result_idx = 0;
|
||||
for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(predictions); ++i) {
|
||||
if (predictions[i] > result_val) {
|
||||
result_val = predictions[i];
|
||||
result_idx = i;
|
||||
}
|
||||
}
|
||||
fprintf(stdout, "Detected number: %d\n", result_idx);
|
||||
|
||||
iree_runtime_call_deinitialize(&call);
|
||||
iree_runtime_session_release(session);
|
||||
iree_runtime_instance_release(instance);
|
||||
return iree_ok_status();
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc > 2) {
|
||||
fprintf(stderr, "Usage: iree-run-mnist-module <image file>\n");
|
||||
return -1;
|
||||
}
|
||||
iree_string_view_t image_path;
|
||||
if (argc == 1) {
|
||||
image_path = iree_make_cstring_view("mnist_test.png");
|
||||
} else {
|
||||
image_path = iree_make_cstring_view(argv[1]);
|
||||
}
|
||||
iree_status_t result = Run(image_path);
|
||||
if (!iree_status_is_ok(result)) {
|
||||
iree_status_fprint(stderr, result);
|
||||
iree_status_ignore(result);
|
||||
return -1;
|
||||
}
|
||||
iree_status_ignore(result);
|
||||
return 0;
|
||||
}
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 261 B |
@@ -1,116 +0,0 @@
|
||||
# Copyright 2022 The IREE Authors
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
if(NOT IREE_TARGET_BACKEND_VULKAN_SPIRV OR
|
||||
NOT IREE_HAL_DRIVER_VULKAN)
|
||||
message(STATUS "Missing Vulkan backend and/or driver, skipping vulkan_gui sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# This target statically links against Vulkan.
|
||||
# One way to achieve this is by installing the Vulkan SDK from
|
||||
# https://vulkan.lunarg.com/.
|
||||
include(FindVulkan)
|
||||
if(NOT Vulkan_FOUND)
|
||||
message(STATUS "Could not find Vulkan, skipping vulkan_gui sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# vcpkg install sdl2[vulkan]
|
||||
# tested with versions 2.0.14#4 - 2.0.22#1
|
||||
find_package(SDL2)
|
||||
if(NOT SDL2_FOUND)
|
||||
message(STATUS "Could not find SDL2, skipping vulkan_gui sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
imgui
|
||||
GIT_REPOSITORY https://github.com/ocornut/imgui
|
||||
GIT_TAG master
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(imgui)
|
||||
|
||||
# Dear ImGui
|
||||
set(IMGUI_DIR ${CMAKE_BINARY_DIR}/_deps/imgui-src)
|
||||
message("Looking for Imgui in ${IMGUI_DIR}")
|
||||
include_directories(${IMGUI_DIR} ${IMGUI_DIR}/backends ..)
|
||||
|
||||
|
||||
function(iree_vulkan_sample)
|
||||
|
||||
cmake_parse_arguments(
|
||||
_RULE
|
||||
""
|
||||
"NAME"
|
||||
"SRCS"
|
||||
${ARGN}
|
||||
)
|
||||
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "${_RULE_NAME}")
|
||||
set(SRCS "${_RULE_SRCS}")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
${SRCS}
|
||||
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
|
||||
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
|
||||
"${IMGUI_DIR}/imgui.cpp"
|
||||
"${IMGUI_DIR}/imgui_draw.cpp"
|
||||
"${IMGUI_DIR}/imgui_demo.cpp"
|
||||
"${IMGUI_DIR}/imgui_tables.cpp"
|
||||
"${IMGUI_DIR}/imgui_widgets.cpp"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_NAME}")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
SDL2::SDL2
|
||||
Vulkan::Vulkan
|
||||
iree_runtime_runtime
|
||||
iree_base_internal_main
|
||||
iree_hal_drivers_vulkan_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_vm_cc
|
||||
iree_tooling_vm_util_cc
|
||||
iree_tooling_context_util
|
||||
)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
|
||||
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
|
||||
else()
|
||||
set(_GUI_LINKOPTS "")
|
||||
endif()
|
||||
|
||||
target_link_options(${_NAME}
|
||||
PRIVATE
|
||||
${_GUI_LINKOPTS}
|
||||
)
|
||||
endfunction()
|
||||
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-samples-resnet-vulkan-gui
|
||||
|
||||
SRCS
|
||||
vulkan_resnet_inference_gui.cc
|
||||
)
|
||||
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-vulkan-gui
|
||||
|
||||
SRCS
|
||||
vulkan_inference_gui.cc
|
||||
)
|
||||
|
||||
message(STATUS "Configured vulkan_gui sample successfully")
|
||||
@@ -1,4 +0,0 @@
|
||||
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
%0 = "arith.mulf"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 14 KiB |
File diff suppressed because it is too large
Load Diff
@@ -1,957 +0,0 @@
|
||||
// Copyright 2019 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
// Vulkan Graphics + IREE API Integration Sample.
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_vulkan.h>
|
||||
#include <imgui.h>
|
||||
#include <imgui_impl_sdl.h>
|
||||
#include <imgui_impl_vulkan.h>
|
||||
#include <vulkan/vulkan.h>
|
||||
|
||||
|
||||
#include <cstring>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <array>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "iree/hal/drivers/vulkan/api.h"
|
||||
|
||||
// IREE's C API:
|
||||
#include "iree/base/api.h"
|
||||
#include "iree/hal/api.h"
|
||||
#include "iree/hal/drivers/vulkan/registration/driver_module.h"
|
||||
#include "iree/modules/hal/module.h"
|
||||
#include "iree/vm/api.h"
|
||||
#include "iree/vm/bytecode_module.h"
|
||||
#include "iree/vm/ref_cc.h"
|
||||
|
||||
// iree-run-module
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/base/status_cc.h"
|
||||
#include "iree/base/tracing.h"
|
||||
#include "iree/modules/hal/types.h"
|
||||
#include "iree/tooling/comparison.h"
|
||||
#include "iree/tooling/context_util.h"
|
||||
#include "iree/tooling/vm_util_cc.h"
|
||||
|
||||
// Other dependencies (helpers, etc.)
|
||||
#include "iree/base/internal/main.h"
|
||||
|
||||
#define IMGUI_UNLIMITED_FRAME_RATE
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
IREE_FLAG(string, entry_function, "",
|
||||
"Name of a function contained in the module specified by module_file "
|
||||
"to run.");
|
||||
|
||||
// TODO(benvanik): move --function_input= flag into a util.
|
||||
static iree_status_t parse_function_io(iree_string_view_t flag_name,
|
||||
void* storage,
|
||||
iree_string_view_t value) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
list->push_back(std::string(value.data, value.size));
|
||||
return iree_ok_status();
|
||||
}
|
||||
static void print_function_io(iree_string_view_t flag_name, void* storage,
|
||||
FILE* file) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
if (list->empty()) {
|
||||
fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
|
||||
} else {
|
||||
for (size_t i = 0; i < list->size(); ++i) {
|
||||
fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
|
||||
list->at(i).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
static std::vector<std::string> FLAG_function_inputs;
|
||||
IREE_FLAG_CALLBACK(
|
||||
parse_function_io, print_function_io, &FLAG_function_inputs, function_input,
|
||||
"An input (a) value or (b) buffer of the format:\n"
|
||||
" (a) scalar value\n"
|
||||
" value\n"
|
||||
" e.g.: --function_input=\"3.14\"\n"
|
||||
" (b) buffer:\n"
|
||||
" [shape]xtype=[value]\n"
|
||||
" e.g.: --function_input=\"2x2xi32=1 2 3 4\"\n"
|
||||
"Optionally, brackets may be used to separate the element values:\n"
|
||||
" 2x2xi32=[[1 2][3 4]]\n"
|
||||
"Raw binary files can be read to provide buffer contents:\n"
|
||||
" 2x2xi32=@some/file.bin\n"
|
||||
"numpy npy files (from numpy.save) can be read to provide 1+ values:\n"
|
||||
" @some.npy\n"
|
||||
"Each occurrence of the flag indicates an input in the order they were\n"
|
||||
"specified on the command line.");
|
||||
|
||||
typedef struct iree_file_toc_t {
|
||||
const char* name; // the file's original name
|
||||
char* data; // beginning of the file
|
||||
size_t size; // length of the file
|
||||
} iree_file_toc_t;
|
||||
|
||||
bool load_file(const char* filename, char** pOut, size_t* pSize)
|
||||
{
|
||||
FILE* f = fopen(filename, "rb");
|
||||
if (f == NULL)
|
||||
{
|
||||
fprintf(stderr, "Can't open %s\n", filename);
|
||||
return false;
|
||||
}
|
||||
|
||||
fseek(f, 0L, SEEK_END);
|
||||
*pSize = ftell(f);
|
||||
fseek(f, 0L, SEEK_SET);
|
||||
|
||||
*pOut = (char*)malloc(*pSize);
|
||||
|
||||
size_t size = fread(*pOut, *pSize, 1, f);
|
||||
|
||||
fclose(f);
|
||||
|
||||
return size != 0;
|
||||
}
|
||||
|
||||
static VkAllocationCallbacks* g_Allocator = NULL;
|
||||
static VkInstance g_Instance = VK_NULL_HANDLE;
|
||||
static VkPhysicalDevice g_PhysicalDevice = VK_NULL_HANDLE;
|
||||
static VkDevice g_Device = VK_NULL_HANDLE;
|
||||
static uint32_t g_QueueFamily = (uint32_t)-1;
|
||||
static VkQueue g_Queue = VK_NULL_HANDLE;
|
||||
static VkPipelineCache g_PipelineCache = VK_NULL_HANDLE;
|
||||
static VkDescriptorPool g_DescriptorPool = VK_NULL_HANDLE;
|
||||
|
||||
static ImGui_ImplVulkanH_Window g_MainWindowData;
|
||||
static uint32_t g_MinImageCount = 2;
|
||||
static bool g_SwapChainRebuild = false;
|
||||
static int g_SwapChainResizeWidth = 0;
|
||||
static int g_SwapChainResizeHeight = 0;
|
||||
|
||||
static void check_vk_result(VkResult err) {
|
||||
if (err == 0) return;
|
||||
fprintf(stderr, "VkResult: %d\n", err);
|
||||
abort();
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan layers used for the given IREE
|
||||
// |extensibility_set| and |features|.
|
||||
std::vector<const char*> GetIreeLayers(
|
||||
iree_hal_vulkan_extensibility_set_t extensibility_set,
|
||||
iree_hal_vulkan_features_t features) {
|
||||
iree_host_size_t required_count;
|
||||
iree_hal_vulkan_query_extensibility_set(
|
||||
features, extensibility_set, /*string_capacity=*/0, &required_count,
|
||||
/*out_string_values=*/NULL);
|
||||
std::vector<const char*> layers(required_count);
|
||||
iree_hal_vulkan_query_extensibility_set(features, extensibility_set,
|
||||
layers.size(), &required_count,
|
||||
layers.data());
|
||||
return layers;
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan extensions used for the given IREE
|
||||
// |extensibility_set| and |features|.
|
||||
std::vector<const char*> GetIreeExtensions(
|
||||
iree_hal_vulkan_extensibility_set_t extensibility_set,
|
||||
iree_hal_vulkan_features_t features) {
|
||||
iree_host_size_t required_count;
|
||||
iree_hal_vulkan_query_extensibility_set(
|
||||
features, extensibility_set, /*string_capacity=*/0, &required_count,
|
||||
/*out_string_values=*/NULL);
|
||||
std::vector<const char*> extensions(required_count);
|
||||
iree_hal_vulkan_query_extensibility_set(features, extensibility_set,
|
||||
extensions.size(), &required_count,
|
||||
extensions.data());
|
||||
return extensions;
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan extensions used for the given IREE
|
||||
// |vulkan_features|.
|
||||
std::vector<const char*> GetDeviceExtensions(
|
||||
VkPhysicalDevice physical_device,
|
||||
iree_hal_vulkan_features_t vulkan_features) {
|
||||
std::vector<const char*> iree_required_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED,
|
||||
vulkan_features);
|
||||
std::vector<const char*> iree_optional_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
|
||||
vulkan_features);
|
||||
|
||||
uint32_t extension_count = 0;
|
||||
check_vk_result(vkEnumerateDeviceExtensionProperties(
|
||||
physical_device, nullptr, &extension_count, nullptr));
|
||||
std::vector<VkExtensionProperties> extension_properties(extension_count);
|
||||
check_vk_result(vkEnumerateDeviceExtensionProperties(
|
||||
physical_device, nullptr, &extension_count, extension_properties.data()));
|
||||
|
||||
// Merge extensions lists, including optional and required for simplicity.
|
||||
std::set<const char*> ext_set;
|
||||
ext_set.insert("VK_KHR_swapchain");
|
||||
ext_set.insert(iree_required_extensions.begin(),
|
||||
iree_required_extensions.end());
|
||||
for (int i = 0; i < iree_optional_extensions.size(); ++i) {
|
||||
const char* optional_extension = iree_optional_extensions[i];
|
||||
for (int j = 0; j < extension_count; ++j) {
|
||||
if (strcmp(optional_extension, extension_properties[j].extensionName) ==
|
||||
0) {
|
||||
ext_set.insert(optional_extension);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<const char*> extensions(ext_set.begin(), ext_set.end());
|
||||
return extensions;
|
||||
}
|
||||
|
||||
std::vector<const char*> GetInstanceLayers(
|
||||
iree_hal_vulkan_features_t vulkan_features) {
|
||||
// Query the layers that IREE wants / needs.
|
||||
std::vector<const char*> required_layers = GetIreeLayers(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_REQUIRED, vulkan_features);
|
||||
std::vector<const char*> optional_layers = GetIreeLayers(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, vulkan_features);
|
||||
|
||||
// Query the layers that are available on the Vulkan ICD.
|
||||
uint32_t layer_property_count = 0;
|
||||
check_vk_result(
|
||||
vkEnumerateInstanceLayerProperties(&layer_property_count, NULL));
|
||||
std::vector<VkLayerProperties> layer_properties(layer_property_count);
|
||||
check_vk_result(vkEnumerateInstanceLayerProperties(&layer_property_count,
|
||||
layer_properties.data()));
|
||||
|
||||
// Match between optional/required and available layers.
|
||||
std::vector<const char*> layers;
|
||||
for (const char* layer_name : required_layers) {
|
||||
bool found = false;
|
||||
for (const auto& layer_property : layer_properties) {
|
||||
if (std::strcmp(layer_name, layer_property.layerName) == 0) {
|
||||
found = true;
|
||||
layers.push_back(layer_name);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
fprintf(stderr, "Required layer %s not available\n", layer_name);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
for (const char* layer_name : optional_layers) {
|
||||
for (const auto& layer_property : layer_properties) {
|
||||
if (std::strcmp(layer_name, layer_property.layerName) == 0) {
|
||||
layers.push_back(layer_name);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return layers;
|
||||
}
|
||||
|
||||
std::vector<const char*> GetInstanceExtensions(
|
||||
SDL_Window* window, iree_hal_vulkan_features_t vulkan_features) {
|
||||
// Ask SDL for its list of required instance extensions.
|
||||
uint32_t sdl_extensions_count = 0;
|
||||
SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count, NULL);
|
||||
std::vector<const char*> sdl_extensions(sdl_extensions_count);
|
||||
SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count,
|
||||
sdl_extensions.data());
|
||||
|
||||
std::vector<const char*> iree_required_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED,
|
||||
vulkan_features);
|
||||
std::vector<const char*> iree_optional_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL,
|
||||
vulkan_features);
|
||||
|
||||
// Merge extensions lists, including optional and required for simplicity.
|
||||
std::set<const char*> ext_set;
|
||||
ext_set.insert(sdl_extensions.begin(), sdl_extensions.end());
|
||||
ext_set.insert(iree_required_extensions.begin(),
|
||||
iree_required_extensions.end());
|
||||
ext_set.insert(iree_optional_extensions.begin(),
|
||||
iree_optional_extensions.end());
|
||||
std::vector<const char*> extensions(ext_set.begin(), ext_set.end());
|
||||
return extensions;
|
||||
}
|
||||
|
||||
void SetupVulkan(iree_hal_vulkan_features_t vulkan_features,
|
||||
const char** instance_layers, uint32_t instance_layers_count,
|
||||
const char** instance_extensions,
|
||||
uint32_t instance_extensions_count,
|
||||
const VkAllocationCallbacks* allocator, VkInstance* instance,
|
||||
uint32_t* queue_family_index,
|
||||
VkPhysicalDevice* physical_device, VkQueue* queue,
|
||||
VkDevice* device, VkDescriptorPool* descriptor_pool) {
|
||||
VkResult err;
|
||||
|
||||
// Create Vulkan Instance
|
||||
{
|
||||
VkInstanceCreateInfo create_info = {};
|
||||
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
|
||||
create_info.enabledLayerCount = instance_layers_count;
|
||||
create_info.ppEnabledLayerNames = instance_layers;
|
||||
create_info.enabledExtensionCount = instance_extensions_count;
|
||||
create_info.ppEnabledExtensionNames = instance_extensions;
|
||||
err = vkCreateInstance(&create_info, allocator, instance);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Select GPU
|
||||
{
|
||||
uint32_t gpu_count;
|
||||
err = vkEnumeratePhysicalDevices(*instance, &gpu_count, NULL);
|
||||
check_vk_result(err);
|
||||
IM_ASSERT(gpu_count > 0);
|
||||
|
||||
VkPhysicalDevice* gpus =
|
||||
(VkPhysicalDevice*)malloc(sizeof(VkPhysicalDevice) * gpu_count);
|
||||
err = vkEnumeratePhysicalDevices(*instance, &gpu_count, gpus);
|
||||
check_vk_result(err);
|
||||
|
||||
// Use the first reported GPU for simplicity.
|
||||
*physical_device = gpus[0];
|
||||
|
||||
VkPhysicalDeviceProperties properties;
|
||||
vkGetPhysicalDeviceProperties(*physical_device, &properties);
|
||||
fprintf(stdout, "Selected Vulkan device: '%s'\n", properties.deviceName);
|
||||
free(gpus);
|
||||
}
|
||||
|
||||
// Select queue family. We want a single queue with graphics and compute for
|
||||
// simplicity, but we could also discover and use separate queues for each.
|
||||
{
|
||||
uint32_t count;
|
||||
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, NULL);
|
||||
VkQueueFamilyProperties* queues = (VkQueueFamilyProperties*)malloc(
|
||||
sizeof(VkQueueFamilyProperties) * count);
|
||||
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, queues);
|
||||
for (uint32_t i = 0; i < count; i++) {
|
||||
if (queues[i].queueFlags &
|
||||
(VK_QUEUE_GRAPHICS_BIT | VK_QUEUE_COMPUTE_BIT)) {
|
||||
*queue_family_index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
free(queues);
|
||||
IM_ASSERT(*queue_family_index != (uint32_t)-1);
|
||||
}
|
||||
|
||||
// Create Logical Device (with 1 queue)
|
||||
{
|
||||
std::vector<const char*> device_extensions =
|
||||
GetDeviceExtensions(*physical_device, vulkan_features);
|
||||
const float queue_priority[] = {1.0f};
|
||||
VkDeviceQueueCreateInfo queue_info = {};
|
||||
queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
|
||||
queue_info.queueFamilyIndex = *queue_family_index;
|
||||
queue_info.queueCount = 1;
|
||||
queue_info.pQueuePriorities = queue_priority;
|
||||
VkDeviceCreateInfo create_info = {};
|
||||
create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
|
||||
create_info.queueCreateInfoCount = 1;
|
||||
create_info.pQueueCreateInfos = &queue_info;
|
||||
create_info.enabledExtensionCount =
|
||||
static_cast<uint32_t>(device_extensions.size());
|
||||
create_info.ppEnabledExtensionNames = device_extensions.data();
|
||||
|
||||
// Enable timeline semaphores.
|
||||
VkPhysicalDeviceFeatures2 features2;
|
||||
memset(&features2, 0, sizeof(features2));
|
||||
features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
|
||||
create_info.pNext = &features2;
|
||||
VkPhysicalDeviceTimelineSemaphoreFeatures semaphore_features;
|
||||
memset(&semaphore_features, 0, sizeof(semaphore_features));
|
||||
semaphore_features.sType =
|
||||
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_TIMELINE_SEMAPHORE_FEATURES;
|
||||
semaphore_features.pNext = features2.pNext;
|
||||
features2.pNext = &semaphore_features;
|
||||
semaphore_features.timelineSemaphore = VK_TRUE;
|
||||
|
||||
err = vkCreateDevice(*physical_device, &create_info, allocator, device);
|
||||
check_vk_result(err);
|
||||
vkGetDeviceQueue(*device, *queue_family_index, 0, queue);
|
||||
}
|
||||
|
||||
// Create Descriptor Pool
|
||||
{
|
||||
VkDescriptorPoolSize pool_sizes[] = {
|
||||
{VK_DESCRIPTOR_TYPE_SAMPLER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT, 1000}};
|
||||
VkDescriptorPoolCreateInfo pool_info = {};
|
||||
pool_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
|
||||
pool_info.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
|
||||
pool_info.maxSets = 1000 * IREE_ARRAYSIZE(pool_sizes);
|
||||
pool_info.poolSizeCount = (uint32_t)IREE_ARRAYSIZE(pool_sizes);
|
||||
pool_info.pPoolSizes = pool_sizes;
|
||||
err =
|
||||
vkCreateDescriptorPool(*device, &pool_info, allocator, descriptor_pool);
|
||||
check_vk_result(err);
|
||||
}
|
||||
}
|
||||
|
||||
void SetupVulkanWindow(ImGui_ImplVulkanH_Window* wd,
|
||||
const VkAllocationCallbacks* allocator,
|
||||
VkInstance instance, uint32_t queue_family_index,
|
||||
VkPhysicalDevice physical_device, VkDevice device,
|
||||
VkSurfaceKHR surface, int width, int height,
|
||||
uint32_t min_image_count) {
|
||||
wd->Surface = surface;
|
||||
|
||||
// Check for WSI support
|
||||
VkBool32 res;
|
||||
vkGetPhysicalDeviceSurfaceSupportKHR(physical_device, queue_family_index,
|
||||
wd->Surface, &res);
|
||||
if (res != VK_TRUE) {
|
||||
fprintf(stderr, "Error no WSI support on physical device 0\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Select Surface Format
|
||||
const VkFormat requestSurfaceImageFormat[] = {
|
||||
VK_FORMAT_B8G8R8A8_UNORM, VK_FORMAT_R8G8B8A8_UNORM,
|
||||
VK_FORMAT_B8G8R8_UNORM, VK_FORMAT_R8G8B8_UNORM};
|
||||
const VkColorSpaceKHR requestSurfaceColorSpace =
|
||||
VK_COLORSPACE_SRGB_NONLINEAR_KHR;
|
||||
wd->SurfaceFormat = ImGui_ImplVulkanH_SelectSurfaceFormat(
|
||||
physical_device, wd->Surface, requestSurfaceImageFormat,
|
||||
(size_t)IREE_ARRAYSIZE(requestSurfaceImageFormat),
|
||||
requestSurfaceColorSpace);
|
||||
|
||||
// Select Present Mode
|
||||
#ifdef IMGUI_UNLIMITED_FRAME_RATE
|
||||
VkPresentModeKHR present_modes[] = {VK_PRESENT_MODE_MAILBOX_KHR,
|
||||
VK_PRESENT_MODE_IMMEDIATE_KHR,
|
||||
VK_PRESENT_MODE_FIFO_KHR};
|
||||
#else
|
||||
VkPresentModeKHR present_modes[] = {VK_PRESENT_MODE_FIFO_KHR};
|
||||
#endif
|
||||
wd->PresentMode = ImGui_ImplVulkanH_SelectPresentMode(
|
||||
physical_device, wd->Surface, &present_modes[0],
|
||||
IREE_ARRAYSIZE(present_modes));
|
||||
|
||||
// Create SwapChain, RenderPass, Framebuffer, etc.
|
||||
IM_ASSERT(min_image_count >= 2);
|
||||
ImGui_ImplVulkanH_CreateOrResizeWindow(instance, physical_device, device, wd,
|
||||
queue_family_index, allocator, width,
|
||||
height, min_image_count);
|
||||
|
||||
// Set clear color.
|
||||
ImVec4 clear_color = ImVec4(0.45f, 0.55f, 0.60f, 1.00f);
|
||||
memcpy(&wd->ClearValue.color.float32[0], &clear_color, 4 * sizeof(float));
|
||||
}
|
||||
|
||||
void RenderFrame(ImGui_ImplVulkanH_Window* wd, VkDevice device, VkQueue queue) {
|
||||
VkResult err;
|
||||
|
||||
VkSemaphore image_acquired_semaphore =
|
||||
wd->FrameSemaphores[wd->SemaphoreIndex].ImageAcquiredSemaphore;
|
||||
VkSemaphore render_complete_semaphore =
|
||||
wd->FrameSemaphores[wd->SemaphoreIndex].RenderCompleteSemaphore;
|
||||
err = vkAcquireNextImageKHR(device, wd->Swapchain, UINT64_MAX,
|
||||
image_acquired_semaphore, VK_NULL_HANDLE,
|
||||
&wd->FrameIndex);
|
||||
check_vk_result(err);
|
||||
|
||||
ImGui_ImplVulkanH_Frame* fd = &wd->Frames[wd->FrameIndex];
|
||||
{
|
||||
err = vkWaitForFences(
|
||||
device, 1, &fd->Fence, VK_TRUE,
|
||||
UINT64_MAX); // wait indefinitely instead of periodically checking
|
||||
check_vk_result(err);
|
||||
|
||||
err = vkResetFences(device, 1, &fd->Fence);
|
||||
check_vk_result(err);
|
||||
}
|
||||
{
|
||||
err = vkResetCommandPool(device, fd->CommandPool, 0);
|
||||
check_vk_result(err);
|
||||
VkCommandBufferBeginInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
err = vkBeginCommandBuffer(fd->CommandBuffer, &info);
|
||||
check_vk_result(err);
|
||||
}
|
||||
{
|
||||
VkRenderPassBeginInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_RENDER_PASS_BEGIN_INFO;
|
||||
info.renderPass = wd->RenderPass;
|
||||
info.framebuffer = fd->Framebuffer;
|
||||
info.renderArea.extent.width = wd->Width;
|
||||
info.renderArea.extent.height = wd->Height;
|
||||
info.clearValueCount = 1;
|
||||
info.pClearValues = &wd->ClearValue;
|
||||
vkCmdBeginRenderPass(fd->CommandBuffer, &info, VK_SUBPASS_CONTENTS_INLINE);
|
||||
}
|
||||
|
||||
// Record Imgui Draw Data and draw funcs into command buffer
|
||||
ImGui_ImplVulkan_RenderDrawData(ImGui::GetDrawData(), fd->CommandBuffer);
|
||||
|
||||
// Submit command buffer
|
||||
vkCmdEndRenderPass(fd->CommandBuffer);
|
||||
{
|
||||
VkPipelineStageFlags wait_stage =
|
||||
VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT;
|
||||
VkSubmitInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
info.waitSemaphoreCount = 1;
|
||||
info.pWaitSemaphores = &image_acquired_semaphore;
|
||||
info.pWaitDstStageMask = &wait_stage;
|
||||
info.commandBufferCount = 1;
|
||||
info.pCommandBuffers = &fd->CommandBuffer;
|
||||
info.signalSemaphoreCount = 1;
|
||||
info.pSignalSemaphores = &render_complete_semaphore;
|
||||
|
||||
err = vkEndCommandBuffer(fd->CommandBuffer);
|
||||
check_vk_result(err);
|
||||
err = vkQueueSubmit(queue, 1, &info, fd->Fence);
|
||||
check_vk_result(err);
|
||||
}
|
||||
}
|
||||
|
||||
void PresentFrame(ImGui_ImplVulkanH_Window* wd, VkQueue queue) {
|
||||
VkSemaphore render_complete_semaphore =
|
||||
wd->FrameSemaphores[wd->SemaphoreIndex].RenderCompleteSemaphore;
|
||||
VkPresentInfoKHR info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_PRESENT_INFO_KHR;
|
||||
info.waitSemaphoreCount = 1;
|
||||
info.pWaitSemaphores = &render_complete_semaphore;
|
||||
info.swapchainCount = 1;
|
||||
info.pSwapchains = &wd->Swapchain;
|
||||
info.pImageIndices = &wd->FrameIndex;
|
||||
VkResult err = vkQueuePresentKHR(queue, &info);
|
||||
check_vk_result(err);
|
||||
wd->SemaphoreIndex =
|
||||
(wd->SemaphoreIndex + 1) %
|
||||
wd->ImageCount; // Now we can use the next set of semaphores
|
||||
}
|
||||
|
||||
static void CleanupVulkan() {
|
||||
vkDestroyDescriptorPool(g_Device, g_DescriptorPool, g_Allocator);
|
||||
|
||||
vkDestroyDevice(g_Device, g_Allocator);
|
||||
vkDestroyInstance(g_Instance, g_Allocator);
|
||||
}
|
||||
|
||||
static void CleanupVulkanWindow() {
|
||||
ImGui_ImplVulkanH_DestroyWindow(g_Instance, g_Device, &g_MainWindowData,
|
||||
g_Allocator);
|
||||
}
|
||||
|
||||
namespace iree {
|
||||
|
||||
extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
|
||||
if (argc > 1) {
|
||||
// Avoid iree-run-module spinning endlessly on stdin if the user uses single
|
||||
// dashes for flags.
|
||||
printf(
|
||||
"[ERROR] unexpected positional argument (expected none)."
|
||||
" Did you use pass a flag with a single dash ('-')?"
|
||||
" Use '--' instead.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Create a window.
|
||||
if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
|
||||
fprintf(stderr, "Failed to initialize SDL\n");
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Setup window
|
||||
// clang-format off
|
||||
SDL_WindowFlags window_flags = (SDL_WindowFlags)(
|
||||
SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
|
||||
// clang-format on
|
||||
SDL_Window* window = SDL_CreateWindow(
|
||||
"IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED,
|
||||
SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
|
||||
if (window == nullptr)
|
||||
{
|
||||
const char* sdl_err = SDL_GetError();
|
||||
fprintf(stderr, "Error, SDL_CreateWindow returned: %s\n", sdl_err);
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Setup Vulkan
|
||||
iree_hal_vulkan_features_t iree_vulkan_features =
|
||||
static_cast<iree_hal_vulkan_features_t>(
|
||||
IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS |
|
||||
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
|
||||
std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features);
|
||||
std::vector<const char*> extensions =
|
||||
GetInstanceExtensions(window, iree_vulkan_features);
|
||||
SetupVulkan(iree_vulkan_features, layers.data(),
|
||||
static_cast<uint32_t>(layers.size()), extensions.data(),
|
||||
static_cast<uint32_t>(extensions.size()), g_Allocator,
|
||||
&g_Instance, &g_QueueFamily, &g_PhysicalDevice, &g_Queue,
|
||||
&g_Device, &g_DescriptorPool);
|
||||
|
||||
// Create Window Surface
|
||||
VkSurfaceKHR surface;
|
||||
VkResult err;
|
||||
if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) {
|
||||
fprintf(stderr, "Failed to create Vulkan surface.\n");
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Create Framebuffers
|
||||
int w, h;
|
||||
SDL_GetWindowSize(window, &w, &h);
|
||||
ImGui_ImplVulkanH_Window* wd = &g_MainWindowData;
|
||||
SetupVulkanWindow(wd, g_Allocator, g_Instance, g_QueueFamily,
|
||||
g_PhysicalDevice, g_Device, surface, w, h, g_MinImageCount);
|
||||
|
||||
// Setup Dear ImGui context
|
||||
IMGUI_CHECKVERSION();
|
||||
ImGui::CreateContext();
|
||||
ImGuiIO& io = ImGui::GetIO();
|
||||
(void)io;
|
||||
|
||||
ImGui::StyleColorsDark();
|
||||
|
||||
// Setup Platform/Renderer bindings
|
||||
ImGui_ImplSDL2_InitForVulkan(window);
|
||||
ImGui_ImplVulkan_InitInfo init_info = {};
|
||||
init_info.Instance = g_Instance;
|
||||
init_info.PhysicalDevice = g_PhysicalDevice;
|
||||
init_info.Device = g_Device;
|
||||
init_info.QueueFamily = g_QueueFamily;
|
||||
init_info.Queue = g_Queue;
|
||||
init_info.PipelineCache = g_PipelineCache;
|
||||
init_info.DescriptorPool = g_DescriptorPool;
|
||||
init_info.Allocator = g_Allocator;
|
||||
init_info.MinImageCount = g_MinImageCount;
|
||||
init_info.ImageCount = wd->ImageCount;
|
||||
init_info.CheckVkResultFn = check_vk_result;
|
||||
ImGui_ImplVulkan_Init(&init_info, wd->RenderPass);
|
||||
|
||||
// Upload Fonts
|
||||
{
|
||||
// Use any command queue
|
||||
VkCommandPool command_pool = wd->Frames[wd->FrameIndex].CommandPool;
|
||||
VkCommandBuffer command_buffer = wd->Frames[wd->FrameIndex].CommandBuffer;
|
||||
|
||||
err = vkResetCommandPool(g_Device, command_pool, 0);
|
||||
check_vk_result(err);
|
||||
VkCommandBufferBeginInfo begin_info = {};
|
||||
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
err = vkBeginCommandBuffer(command_buffer, &begin_info);
|
||||
check_vk_result(err);
|
||||
|
||||
ImGui_ImplVulkan_CreateFontsTexture(command_buffer);
|
||||
|
||||
VkSubmitInfo end_info = {};
|
||||
end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
end_info.commandBufferCount = 1;
|
||||
end_info.pCommandBuffers = &command_buffer;
|
||||
err = vkEndCommandBuffer(command_buffer);
|
||||
check_vk_result(err);
|
||||
err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
|
||||
check_vk_result(err);
|
||||
|
||||
err = vkDeviceWaitIdle(g_Device);
|
||||
check_vk_result(err);
|
||||
ImGui_ImplVulkan_DestroyFontUploadObjects();
|
||||
}
|
||||
|
||||
// Demo state.
|
||||
bool show_iree_window = true;
|
||||
// --------------------------------------------------------------------------
|
||||
// Setup IREE.
|
||||
|
||||
// Check API version.
|
||||
iree_api_version_t actual_version;
|
||||
iree_status_t status =
|
||||
iree_api_version_check(IREE_API_VERSION_LATEST, &actual_version);
|
||||
if (iree_status_is_ok(status)) {
|
||||
fprintf(stdout, "IREE runtime API version: %d\n", actual_version);
|
||||
} else {
|
||||
fprintf(stderr, "Unsupported runtime API version: %d\n", actual_version);
|
||||
abort();
|
||||
}
|
||||
|
||||
// Create a runtime Instance.
|
||||
iree_vm_instance_t* iree_instance = nullptr;
|
||||
IREE_CHECK_OK(
|
||||
iree_vm_instance_create(iree_allocator_system(), &iree_instance));
|
||||
|
||||
// Register HAL drivers and VM module types.
|
||||
IREE_CHECK_OK(iree_hal_vulkan_driver_module_register(
|
||||
iree_hal_driver_registry_default()));
|
||||
IREE_CHECK_OK(iree_hal_module_register_all_types(iree_instance));
|
||||
|
||||
// Create IREE Vulkan Driver and Device, sharing our VkInstance/VkDevice.
|
||||
fprintf(stdout, "Creating Vulkan driver/device\n");
|
||||
// Load symbols from our static `vkGetInstanceProcAddr` for IREE to use.
|
||||
iree_hal_vulkan_syms_t* iree_vk_syms = nullptr;
|
||||
IREE_CHECK_OK(iree_hal_vulkan_syms_create(
|
||||
reinterpret_cast<void*>(&vkGetInstanceProcAddr), iree_allocator_system(),
|
||||
&iree_vk_syms));
|
||||
// Create the driver sharing our VkInstance.
|
||||
iree_hal_driver_t* iree_vk_driver = nullptr;
|
||||
iree_string_view_t driver_identifier = iree_make_cstring_view("vulkan");
|
||||
iree_hal_vulkan_driver_options_t driver_options;
|
||||
driver_options.api_version = VK_API_VERSION_1_0;
|
||||
driver_options.requested_features = static_cast<iree_hal_vulkan_features_t>(
|
||||
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
|
||||
IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance(
|
||||
driver_identifier, &driver_options, iree_vk_syms, g_Instance,
|
||||
iree_allocator_system(), &iree_vk_driver));
|
||||
// Create a device sharing our VkDevice and queue.
|
||||
// We could also create a separate (possibly low priority) compute queue for
|
||||
// IREE, and/or provide a dedicated transfer queue.
|
||||
iree_string_view_t device_identifier = iree_make_cstring_view("vulkan");
|
||||
iree_hal_vulkan_queue_set_t compute_queue_set;
|
||||
compute_queue_set.queue_family_index = g_QueueFamily;
|
||||
compute_queue_set.queue_indices = 1 << 0;
|
||||
iree_hal_vulkan_queue_set_t transfer_queue_set;
|
||||
transfer_queue_set.queue_indices = 0;
|
||||
iree_hal_device_t* iree_vk_device = nullptr;
|
||||
IREE_CHECK_OK(iree_hal_vulkan_wrap_device(
|
||||
device_identifier, &driver_options.device_options, iree_vk_syms,
|
||||
g_Instance, g_PhysicalDevice, g_Device, &compute_queue_set,
|
||||
&transfer_queue_set, iree_allocator_system(), &iree_vk_device));
|
||||
// Create a HAL module using the HAL device.
|
||||
iree_vm_module_t* hal_module = nullptr;
|
||||
IREE_CHECK_OK(iree_hal_module_create(iree_instance, iree_vk_device,
|
||||
IREE_HAL_MODULE_FLAG_NONE,
|
||||
iree_allocator_system(), &hal_module));
|
||||
|
||||
|
||||
// Load bytecode module
|
||||
//iree_file_toc_t module_file_toc;
|
||||
//const char network_model[] = "resnet50_tf.vmfb";
|
||||
//fprintf(stdout, "Loading: %s\n", network_model);
|
||||
//if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
|
||||
//{
|
||||
// abort();
|
||||
// return 1;
|
||||
//}
|
||||
//fprintf(stdout, "module size: %zu\n", module_file_toc.size);
|
||||
|
||||
iree_vm_module_t* bytecode_module = nullptr;
|
||||
iree_status_t module_status = iree_tooling_load_module_from_flags(
|
||||
iree_instance, iree_allocator_system(), &bytecode_module);
|
||||
if (!iree_status_is_ok(module_status))
|
||||
return -1;
|
||||
//IREE_CHECK_OK(iree_vm_bytecode_module_create(
|
||||
// iree_instance,
|
||||
// iree_const_byte_span_t{
|
||||
// reinterpret_cast<const uint8_t*>(module_file_toc.data),
|
||||
// module_file_toc.size},
|
||||
// iree_allocator_null(), iree_allocator_system(), &bytecode_module));
|
||||
//// Query for details about what is in the loaded module.
|
||||
//iree_vm_module_signature_t bytecode_module_signature =
|
||||
// iree_vm_module_signature(bytecode_module);
|
||||
//fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
|
||||
// bytecode_module_signature.export_function_count);
|
||||
//for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
|
||||
// iree_vm_function_t function;
|
||||
// IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
|
||||
// bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
|
||||
// auto function_name = iree_vm_function_name(&function);
|
||||
// auto function_signature = iree_vm_function_signature(&function);
|
||||
|
||||
// fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
|
||||
// (int)function_name.size, function_name.data,
|
||||
// (int)function_signature.calling_convention.size,
|
||||
// function_signature.calling_convention.data);
|
||||
//}
|
||||
|
||||
// Allocate a context that will hold the module state across invocations.
|
||||
iree_vm_context_t* iree_context = nullptr;
|
||||
std::vector<iree_vm_module_t*> modules = {hal_module, bytecode_module};
|
||||
IREE_CHECK_OK(iree_vm_context_create_with_modules(
|
||||
iree_instance, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(),
|
||||
iree_allocator_system(), &iree_context));
|
||||
fprintf(stdout, "Context with modules is ready for use\n");
|
||||
|
||||
// Lookup the entry point function.
|
||||
iree_vm_function_t main_function;
|
||||
const char kMainFunctionName[] = "module.forward";
|
||||
IREE_CHECK_OK(iree_vm_context_resolve_function(
|
||||
iree_context,
|
||||
iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1},
|
||||
&main_function));
|
||||
iree_string_view_t main_function_name = iree_vm_function_name(&main_function);
|
||||
fprintf(stdout, "Resolved main function named '%.*s'\n",
|
||||
(int)main_function_name.size, main_function_name.data);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Write inputs into mappable buffers.
|
||||
iree_hal_allocator_t* allocator =
|
||||
iree_hal_device_allocator(iree_vk_device);
|
||||
//iree_hal_memory_type_t input_memory_type =
|
||||
// static_cast<iree_hal_memory_type_t>(
|
||||
// IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
|
||||
// IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
|
||||
//iree_hal_buffer_usage_t input_buffer_usage =
|
||||
// static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
|
||||
//iree_hal_buffer_params_t buffer_params;
|
||||
//buffer_params.type = input_memory_type;
|
||||
//buffer_params.usage = input_buffer_usage;
|
||||
//buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
|
||||
|
||||
// Wrap input buffers in buffer views.
|
||||
|
||||
vm::ref<iree_vm_list_t> inputs;
|
||||
iree_status_t input_status = ParseToVariantList(
|
||||
allocator,
|
||||
iree::span<const std::string>{FLAG_function_inputs.data(),
|
||||
FLAG_function_inputs.size()},
|
||||
iree_allocator_system(), &inputs);
|
||||
if (!iree_status_is_ok(input_status))
|
||||
return -1;
|
||||
//vm::ref<iree_vm_list_t> inputs;
|
||||
//IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
|
||||
|
||||
//iree_hal_buffer_view_t* input0_buffer_view = nullptr;
|
||||
//constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
|
||||
//IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
|
||||
// allocator,
|
||||
// /*shape_rank=*/4, /*shape=*/input_buffer_shape,
|
||||
// IREE_HAL_ELEMENT_TYPE_FLOAT_32,
|
||||
// IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
|
||||
// iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
|
||||
// &input0_buffer_view));
|
||||
|
||||
//auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
|
||||
//IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
|
||||
|
||||
// Prepare outputs list to accept results from the invocation.
|
||||
|
||||
vm::ref<iree_vm_list_t> outputs;
|
||||
constexpr iree_hal_dim_t kOutputCount = 1000;
|
||||
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, kOutputCount * sizeof(float), iree_allocator_system(), &outputs));
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Main loop.
|
||||
bool done = false;
|
||||
while (!done) {
|
||||
SDL_Event event;
|
||||
|
||||
while (SDL_PollEvent(&event)) {
|
||||
if (event.type == SDL_QUIT) {
|
||||
done = true;
|
||||
}
|
||||
|
||||
ImGui_ImplSDL2_ProcessEvent(&event);
|
||||
if (event.type == SDL_QUIT) done = true;
|
||||
if (event.type == SDL_WINDOWEVENT &&
|
||||
event.window.event == SDL_WINDOWEVENT_RESIZED &&
|
||||
event.window.windowID == SDL_GetWindowID(window)) {
|
||||
g_SwapChainResizeWidth = (int)event.window.data1;
|
||||
g_SwapChainResizeHeight = (int)event.window.data2;
|
||||
g_SwapChainRebuild = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (g_SwapChainRebuild) {
|
||||
g_SwapChainRebuild = false;
|
||||
ImGui_ImplVulkan_SetMinImageCount(g_MinImageCount);
|
||||
ImGui_ImplVulkanH_CreateOrResizeWindow(
|
||||
g_Instance, g_PhysicalDevice, g_Device, &g_MainWindowData,
|
||||
g_QueueFamily, g_Allocator, g_SwapChainResizeWidth,
|
||||
g_SwapChainResizeHeight, g_MinImageCount);
|
||||
g_MainWindowData.FrameIndex = 0;
|
||||
}
|
||||
|
||||
// Start the Dear ImGui frame
|
||||
ImGui_ImplVulkan_NewFrame();
|
||||
ImGui_ImplSDL2_NewFrame(window);
|
||||
ImGui::NewFrame();
|
||||
|
||||
// Custom window.
|
||||
{
|
||||
ImGui::Begin("IREE Vulkan Integration Demo", &show_iree_window);
|
||||
|
||||
ImGui::Separator();
|
||||
|
||||
// ImGui Inputs for two input tensors.
|
||||
// Run computation whenever any of the values changes.
|
||||
static bool dirty = true;
|
||||
if (dirty) {
|
||||
|
||||
// Synchronously invoke the function.
|
||||
IREE_CHECK_OK(iree_vm_invoke(iree_context, main_function,
|
||||
IREE_VM_INVOCATION_FLAG_NONE,
|
||||
/*policy=*/nullptr, inputs.get(),
|
||||
outputs.get(), iree_allocator_system()));
|
||||
|
||||
|
||||
// we want to run continuously so we can use tools like RenderDoc, RGP, etc...
|
||||
dirty = true;
|
||||
}
|
||||
|
||||
// Framerate counter.
|
||||
ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
|
||||
1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
|
||||
|
||||
ImGui::End();
|
||||
}
|
||||
|
||||
// Rendering
|
||||
ImGui::Render();
|
||||
RenderFrame(wd, g_Device, g_Queue);
|
||||
|
||||
PresentFrame(wd, g_Queue);
|
||||
}
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Cleanup
|
||||
iree_vm_module_release(hal_module);
|
||||
iree_vm_module_release(bytecode_module);
|
||||
iree_vm_context_release(iree_context);
|
||||
iree_hal_device_release(iree_vk_device);
|
||||
iree_hal_allocator_release(allocator);
|
||||
iree_hal_driver_release(iree_vk_driver);
|
||||
iree_hal_vulkan_syms_release(iree_vk_syms);
|
||||
iree_vm_instance_release(iree_instance);
|
||||
|
||||
err = vkDeviceWaitIdle(g_Device);
|
||||
check_vk_result(err);
|
||||
ImGui_ImplVulkan_Shutdown();
|
||||
ImGui_ImplSDL2_Shutdown();
|
||||
ImGui::DestroyContext();
|
||||
|
||||
CleanupVulkanWindow();
|
||||
CleanupVulkan();
|
||||
|
||||
SDL_DestroyWindow(window);
|
||||
SDL_Quit();
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace iree
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,251 +0,0 @@
|
||||
# Lint as: python3
|
||||
"""SHARK Tank"""
|
||||
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
|
||||
# will generate local shark tank folder like this:
|
||||
# HOME
|
||||
# /.local
|
||||
# /shark_tank
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
#
|
||||
|
||||
import os
|
||||
import csv
|
||||
import argparse
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
import tensorflow as tf
|
||||
import subprocess as sp
|
||||
import hashlib
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
visible_default = tf.config.list_physical_devices("GPU")
|
||||
try:
|
||||
tf.config.set_visible_devices([], "GPU")
|
||||
visible_devices = tf.config.get_visible_devices()
|
||||
for device in visible_devices:
|
||||
assert device.device_type != "GPU"
|
||||
except:
|
||||
# Invalid device or cannot modify virtual devices once initialized.
|
||||
pass
|
||||
|
||||
|
||||
def create_hash(file_name):
|
||||
with open(file_name, "rb") as f:
|
||||
file_hash = hashlib.blake2b()
|
||||
while chunk := f.read(2**20):
|
||||
file_hash.update(chunk)
|
||||
|
||||
return file_hash.hexdigest()
|
||||
|
||||
|
||||
def save_torch_model(torch_model_list):
|
||||
from tank.model_utils import get_hf_model
|
||||
from tank.model_utils import get_vision_model
|
||||
from tank.model_utils import get_hf_img_cls_model
|
||||
|
||||
with open(torch_model_list) as csvfile:
|
||||
torch_reader = csv.reader(csvfile, delimiter=",")
|
||||
fields = next(torch_reader)
|
||||
for row in torch_reader:
|
||||
torch_model_name = row[0]
|
||||
tracing_required = row[1]
|
||||
model_type = row[2]
|
||||
is_dynamic = row[3]
|
||||
|
||||
tracing_required = False if tracing_required == "False" else True
|
||||
is_dynamic = False if is_dynamic == "False" else True
|
||||
|
||||
model = None
|
||||
input = None
|
||||
if model_type == "vision":
|
||||
model, input, _ = get_vision_model(torch_model_name)
|
||||
elif model_type == "hf":
|
||||
model, input, _ = get_hf_model(torch_model_name)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(torch_model_name)
|
||||
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
WORKDIR, str(torch_model_name) + "_torch"
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
torch_model_dir, torch_model_name + "_torch" + ".mlir"
|
||||
)
|
||||
)
|
||||
np.save(os.path.join(torch_model_dir, "hash"), np.array(mlir_hash))
|
||||
# Generate torch dynamic models.
|
||||
if is_dynamic:
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=True,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name + "_dynamic",
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list):
|
||||
from tank.model_utils_tf import (
|
||||
get_causal_image_model,
|
||||
get_causal_lm_model,
|
||||
get_keras_model,
|
||||
get_TFhf_model,
|
||||
)
|
||||
|
||||
with open(tf_model_list) as csvfile:
|
||||
tf_reader = csv.reader(csvfile, delimiter=",")
|
||||
fields = next(tf_reader)
|
||||
for row in tf_reader:
|
||||
tf_model_name = row[0]
|
||||
model_type = row[1]
|
||||
|
||||
model = None
|
||||
input = None
|
||||
print(f"Generating artifacts for model {tf_model_name}")
|
||||
if model_type == "hf":
|
||||
model, input, _ = get_causal_lm_model(tf_model_name)
|
||||
if model_type == "img":
|
||||
model, input, _ = get_causal_image_model(tf_model_name)
|
||||
if model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name)
|
||||
if model_type == "TFhf":
|
||||
model, input, _ = get_TFhf_model(tf_model_name)
|
||||
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
input,
|
||||
frontend="tf",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
|
||||
)
|
||||
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list):
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
|
||||
with open(tflite_model_list) as csvfile:
|
||||
tflite_reader = csv.reader(csvfile, delimiter=",")
|
||||
for row in tflite_reader:
|
||||
print("\n")
|
||||
tflite_model_name = row[0]
|
||||
tflite_model_link = row[1]
|
||||
print("tflite_model_name", tflite_model_name)
|
||||
print("tflite_model_link", tflite_model_link)
|
||||
tflite_model_name_dir = os.path.join(
|
||||
WORKDIR, str(tflite_model_name) + "_tflite"
|
||||
)
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
|
||||
# Preprocess to get SharkImporter input args
|
||||
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
|
||||
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
|
||||
inputs = tflite_preprocessor.get_inputs()
|
||||
tflite_interpreter = tflite_preprocessor.get_interpreter()
|
||||
|
||||
# Use SharkImporter to get SharkInference input args
|
||||
my_shark_importer = SharkImporter(
|
||||
module=tflite_interpreter,
|
||||
inputs=inputs,
|
||||
frontend="tflite",
|
||||
raw_model_file=raw_model_file_path,
|
||||
)
|
||||
my_shark_importer.import_debug(
|
||||
dir=tflite_model_name_dir,
|
||||
model_name=tflite_model_name,
|
||||
func_name="main",
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
tflite_model_name_dir,
|
||||
tflite_model_name + "_tflite" + ".mlir",
|
||||
)
|
||||
)
|
||||
np.save(
|
||||
os.path.join(tflite_model_name_dir, "hash"),
|
||||
np.array(mlir_hash),
|
||||
)
|
||||
|
||||
|
||||
# Validates whether the file is present or not.
|
||||
def is_valid_file(arg):
|
||||
if not os.path.exists(arg):
|
||||
return None
|
||||
else:
|
||||
return arg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--torch_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/torch_model_list.csv",
|
||||
help="""Contains the file with torch_model name and args.
|
||||
Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tf_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/tf_model_list.csv",
|
||||
help="Contains the file with tf model name and args.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tflite_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/tflite/tflite_model_list.csv",
|
||||
help="Contains the file with tf model name and args.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ci_tank_dir",
|
||||
type=bool,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument("--upload", type=bool, default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
home = str(Path.home())
|
||||
if args.ci_tank_dir == True:
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
else:
|
||||
WORKDIR = os.path.join(home, ".local/shark_tank/")
|
||||
|
||||
if args.torch_model_csv:
|
||||
save_torch_model(args.torch_model_csv)
|
||||
|
||||
if args.tf_model_csv:
|
||||
save_tf_model(args.tf_model_csv)
|
||||
|
||||
if args.tflite_model_csv:
|
||||
save_tflite_model(args.tflite_model_csv)
|
||||
|
||||
if args.upload:
|
||||
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
|
||||
print("uploading files to gs://shark_tank/" + git_hash)
|
||||
os.system(f"gsutil cp -r {WORKDIR}* gs://shark_tank/" + git_hash)
|
||||
@@ -1,192 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
project(sharkbackend LANGUAGES C CXX)
|
||||
|
||||
#
|
||||
# Options
|
||||
#
|
||||
|
||||
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
|
||||
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
|
||||
|
||||
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
|
||||
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
|
||||
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Dependencies
|
||||
#
|
||||
# FetchContent requires us to include the transitive closure of all
|
||||
# repos that we depend on so that we can override the tags.
|
||||
#
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-common
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
|
||||
GIT_TAG ${TRITON_COMMON_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-core
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
|
||||
GIT_TAG ${TRITON_CORE_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-backend
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
|
||||
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
|
||||
|
||||
#
|
||||
# The backend must be built into a shared library. Use an ldscript to
|
||||
# hide all symbols except for the TRITONBACKEND API.
|
||||
#
|
||||
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
|
||||
|
||||
add_library(
|
||||
triton-dshark-backend SHARED
|
||||
src/dshark.cc
|
||||
#src/dshark_driver_module.c
|
||||
)
|
||||
|
||||
add_library(
|
||||
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
|
||||
|
||||
add_subdirectory(thirdparty/shark-runtime EXCLUDE_FROM_ALL)
|
||||
|
||||
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
|
||||
iree_hal_hal
|
||||
iree_hal_cuda_cuda
|
||||
iree_hal_cuda_registration_registration
|
||||
iree_hal_vmvx_registration_registration
|
||||
iree_hal_dylib_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_hal_local_loaders_system_library_loader
|
||||
iree_hal_local_loaders_vmvx_module_loader
|
||||
)
|
||||
|
||||
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
|
||||
|
||||
|
||||
target_link_libraries(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
triton-core-serverapi # from repo-core
|
||||
triton-core-backendapi # from repo-core
|
||||
triton-core-serverstub # from repo-core
|
||||
triton-backend-utils # from repo-backend
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
)
|
||||
else()
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
|
||||
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
#
|
||||
# Install
|
||||
#
|
||||
include(GNUInstallDirs)
|
||||
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
|
||||
|
||||
install(
|
||||
TARGETS
|
||||
triton-dshark-backend
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
)
|
||||
|
||||
install(
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
FILE
|
||||
SharkBackendTargets.cmake
|
||||
NAMESPACE
|
||||
SharkBackend::
|
||||
DESTINATION
|
||||
${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
include(CMakePackageConfigHelpers)
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
#
|
||||
# Export from build tree
|
||||
#
|
||||
export(
|
||||
EXPORT triton-dshark-backend-targets
|
||||
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
|
||||
NAMESPACE SharkBackend::
|
||||
)
|
||||
|
||||
export(PACKAGE SharkBackend)
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
# SHARK Triton Backend
|
||||
|
||||
The triton backend for shark.
|
||||
|
||||
# Build
|
||||
|
||||
Install SHARK
|
||||
|
||||
```
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
# skip above step if dshark is already installed
|
||||
cd SHARK/inference
|
||||
```
|
||||
|
||||
install dependancies
|
||||
|
||||
```
|
||||
apt-get install patchelf rapidjson-dev python3-dev
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
update the submodules of iree
|
||||
|
||||
```
|
||||
cd thirdparty/shark-runtime
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
Next, make the backend and install it
|
||||
|
||||
```
|
||||
cd ../..
|
||||
mkdir build && cd build
|
||||
cmake -DTRITON_ENABLE_GPU=ON \
|
||||
-DIREE_HAL_DRIVER_CUDA=ON \
|
||||
-DIREE_TARGET_BACKEND_CUDA=ON \
|
||||
-DMLIR_ENABLE_CUDA_RUNNER=ON \
|
||||
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
|
||||
-DTRITON_BACKEND_REPO_TAG=r22.02 \
|
||||
-DTRITON_CORE_REPO_TAG=r22.02 \
|
||||
-DTRITON_COMMON_REPO_TAG=r22.02 ..
|
||||
make install
|
||||
```
|
||||
|
||||
# Incorporating into Triton
|
||||
|
||||
There are much more in depth explenations for the following steps in triton's documentation:
|
||||
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
|
||||
|
||||
There should be a file at /build/install/backends/dshark/libtriton_dshark.so. You will need to copy it into your triton server image.
|
||||
More documentation is in the link above, but to create the docker image, you need to run the compose.py command in the triton-backend server repo
|
||||
|
||||
|
||||
To first build your image, clone the tritonserver repo.
|
||||
|
||||
```
|
||||
git clone https://github.com/triton-inference-server/server.git
|
||||
```
|
||||
|
||||
then run `compose.py` to build a docker compose file
|
||||
```
|
||||
cd server
|
||||
python3 compose.py --repoagent checksum --dry-run
|
||||
```
|
||||
|
||||
Because dshark is a third party backend, you will need to manually modify the `Dockerfile.compose` to include the dshark backend. To do this, in the Dockerfile.compose file produced, copy this line.
|
||||
the dshark backend will be located in the build folder from earlier under `/build/install/backends`
|
||||
|
||||
```
|
||||
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
|
||||
```
|
||||
|
||||
Next run
|
||||
```
|
||||
docker build -t tritonserver_custom -f Dockerfile.compose .
|
||||
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
where `path/to/model_repos` is where you are storing the models you want to run
|
||||
|
||||
if your not using gpus, omit `--gpus=1`
|
||||
|
||||
```
|
||||
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
# Setting up a model
|
||||
|
||||
to include a model in your backend, add a directory with your model name to your model repository directory. examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
|
||||
|
||||
make sure to adjust the input correctly in the config.pbtxt file, and save a vmfb file under 1/model.vmfb
|
||||
|
||||
# CUDA
|
||||
|
||||
if you're having issues with cuda, make sure your correct drivers are installed, and that `nvidia-smi` works, and also make sure that the nvcc compiler is on the path.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
include(CMakeFindDependencyMacro)
|
||||
|
||||
get_filename_component(
|
||||
SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})
|
||||
|
||||
if(NOT TARGET SharkBackend::triton-dshark-backend)
|
||||
include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
|
||||
endif()
|
||||
|
||||
set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,30 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
{
|
||||
global:
|
||||
TRITONBACKEND_*;
|
||||
local: *;
|
||||
};
|
||||
1
inference/thirdparty/shark-runtime
vendored
1
inference/thirdparty/shark-runtime
vendored
Submodule inference/thirdparty/shark-runtime deleted from 7b82d90c72
45
package-index/index.html
Normal file
45
package-index/index.html
Normal file
@@ -0,0 +1,45 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230130.481/shark_sd_20230130_481.exe'>shark_sd_20230130_481.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230130.481/shark_sd_cli_20230130_481.exe'>shark_sd_cli_20230130_481.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230129.479/shark_sd_20230129_479.exe'>shark_sd_20230129_479.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230129.479/shark_sd_cli_20230129_479.exe'>shark_sd_cli_20230129_479.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230129.480/shark_sd_20230129_480.exe'>shark_sd_20230129_480.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230129.480/shark_sd_cli_20230129_480.exe'>shark_sd_cli_20230129_480.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230129.478/shark_sd_20230129_478.exe'>shark_sd_20230129_478.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230129.478/shark_sd_cli_20230129_478.exe'>shark_sd_cli_20230129_478.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230128.477/shark_sd_20230128_477.exe'>shark_sd_20230128_477.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230128.477/shark_sd_cli_20230128_477.exe'>shark_sd_cli_20230128_477.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230127.476/shark_sd_20230127_476.exe'>shark_sd_20230127_476.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230127.476/shark_sd_cli_20230127_476.exe'>shark_sd_cli_20230127_476.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230126.475/shark_sd_20230126_475.exe'>shark_sd_20230126_475.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230126.475/shark_sd_cli_20230126_475.exe'>shark_sd_cli_20230126_475.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230125.474/shark_sd_20230125_474.exe'>shark_sd_20230125_474.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230125.474/shark_sd_cli_20230125_474.exe'>shark_sd_cli_20230125_474.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230125.473/shark_sd_20230125_473.exe'>shark_sd_20230125_473.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230125.473/shark_sd_cli_20230125_473.exe'>shark_sd_cli_20230125_473.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230125.472/shark_sd_20230125_472.exe'>shark_sd_20230125_472.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230125.471/shark_sd_20230125_471.exe'>shark_sd_20230125_471.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230125.468/shark_sd_20230125_468.exe'>shark_sd_20230125_468.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230124.470/shark_sd_20230124_470.exe'>shark_sd_20230124_470.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230124.470/shark_sd_cli_20230124_470.exe'>shark_sd_cli_20230124_470.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230124.469/shark_sd_20230124_469.exe'>shark_sd_20230124_469.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230124.467/shark_sd_20230124_467.exe'>shark_sd_20230124_467.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230124.466/shark_sd_20230124_466.exe'>shark_sd_20230124_466.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230124.462/shark_sd_20230124_462.exe'>shark_sd_20230124_462.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230123.461/shark_sd_20230123_461.exe'>shark_sd_20230123_461.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230123.460/shark_sd_20230123_460.exe'>shark_sd_20230123_460.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230122.459/shark_sd_20230122_459.exe'>shark_sd_20230122_459.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230122.458/shark_sd_20230122_458.exe'>shark_sd_20230122_458.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230122.457/shark_sd_20230122_457.exe'>shark_sd_20230122_457.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230121.456/shark_sd_20230121_456.exe'>shark_sd_20230121_456.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230120.455/shark_sd_20230120_455.exe'>shark_sd_20230120_455.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230119.454/shark_sd_20230119_454.exe'>shark_sd_20230119_454.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230118.453/shark_sd_20230118_453.exe'>shark_sd_20230118_453.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230117.452/shark_sd_20230117_452.exe'>shark_sd_20230117_452.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230116.451/shark_sd_20230116_451.exe'>shark_sd_20230116_451.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230115.450/shark_sd_20230115_450.exe'>shark_sd_20230115_450.exe</a><br />
|
||||
<a href='https://github.com/nod-ai/SHARK/releases/download/20230114.449/shark_sd_20230114_449.exe'>shark_sd_20230114_449.exe</a><br />
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,12 +0,0 @@
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"wheel",
|
||||
"packaging",
|
||||
|
||||
"numpy>=1.22.4",
|
||||
"torch-mlir>=20221021.633",
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
@@ -1,3 +0,0 @@
|
||||
[pytest]
|
||||
addopts = --verbose -p no:warnings
|
||||
norecursedirs = inference tank/tflite examples benchmarks shark
|
||||
@@ -1,45 +0,0 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/
|
||||
--pre
|
||||
|
||||
numpy
|
||||
torch
|
||||
torchvision
|
||||
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
#these dont work ok osx
|
||||
#iree-tools-tflite
|
||||
#iree-tools-xla
|
||||
#iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow-macos
|
||||
tensorflow-metal
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
# tflitehub dependencies.
|
||||
Pillow
|
||||
|
||||
# web dependecies.
|
||||
gradio
|
||||
altair
|
||||
|
||||
# Testing and support.
|
||||
#lit
|
||||
#pyyaml
|
||||
|
||||
#ONNX and ORT for benchmarking
|
||||
#--extra-index-url https://test.pypi.org/simple/
|
||||
#protobuf
|
||||
#coloredlogs
|
||||
#flatbuffers
|
||||
#sympy
|
||||
#psutil
|
||||
#onnx-weekly
|
||||
#ort-nightly
|
||||
@@ -1,48 +0,0 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
|
||||
numpy==1.22.4
|
||||
torchvision
|
||||
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
iree-tools-tflite
|
||||
iree-tools-xla
|
||||
iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow==2.10
|
||||
keras==2.10
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
diffusers
|
||||
#tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
|
||||
# tflitehub dependencies.
|
||||
Pillow
|
||||
|
||||
# Testing and support.
|
||||
lit
|
||||
pyyaml
|
||||
python-dateutil
|
||||
sacremoses
|
||||
|
||||
# web dependecies.
|
||||
gradio
|
||||
altair
|
||||
scipy
|
||||
|
||||
#ONNX and ORT for benchmarking
|
||||
#--extra-index-url https://test.pypi.org/simple/
|
||||
#protobuf
|
||||
#coloredlogs
|
||||
#flatbuffers
|
||||
#sympy
|
||||
#psutil
|
||||
#onnx-weekly
|
||||
#ort-nightly
|
||||
@@ -1,25 +0,0 @@
|
||||
setuptools
|
||||
wheel
|
||||
|
||||
# SHARK Runner
|
||||
tqdm
|
||||
|
||||
# SHARK Downloader
|
||||
google-cloud-storage
|
||||
|
||||
# Testing
|
||||
pytest
|
||||
pytest-xdist
|
||||
Pillow
|
||||
parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
transformers
|
||||
diffusers
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
altair
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pyinstaller
|
||||
43
setup.py
43
setup.py
@@ -1,43 +0,0 @@
|
||||
from setuptools import find_packages
|
||||
from setuptools import setup
|
||||
|
||||
import os
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.4"
|
||||
backend_deps = []
|
||||
if "NO_BACKEND" in os.environ.keys():
|
||||
backend_deps = [
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
|
||||
setup(
|
||||
name="nodai-SHARK",
|
||||
version=f"{PACKAGE_VERSION}",
|
||||
description="SHARK provides a High Performance Machine Learning Framework",
|
||||
author="nod.ai",
|
||||
author_email="stdin@nod.ai",
|
||||
url="https://nod.ai",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
project_urls={
|
||||
"Code": "https://github.com/nod-ai/SHARK",
|
||||
"Bug Tracker": "https://github.com/nod-ai/SHARK/issues",
|
||||
},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
packages=find_packages(exclude=("examples")),
|
||||
python_requires=">=3.9",
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"PyYAML",
|
||||
"torch-mlir>=20221021.633",
|
||||
]
|
||||
+ backend_deps,
|
||||
)
|
||||
@@ -1,39 +0,0 @@
|
||||
#Write-Host "Installing python"
|
||||
|
||||
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow
|
||||
|
||||
#Write-Host "python installation completed successfully"
|
||||
|
||||
#Write-Host "Reload environment variables"
|
||||
#$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||
#Write-Host "Reloaded environment variables"
|
||||
|
||||
|
||||
# redirect stderr into stdout
|
||||
$p = &{python -V} 2>&1
|
||||
# check if an ErrorRecord was returned
|
||||
$version = if($p -is [System.Management.Automation.ErrorRecord])
|
||||
{
|
||||
# grab the version string from the error message
|
||||
$p.Exception.Message
|
||||
}
|
||||
else
|
||||
{
|
||||
# otherwise return as is
|
||||
$p
|
||||
}
|
||||
|
||||
Write-Host "Python version found is"
|
||||
Write-Host $p
|
||||
|
||||
|
||||
Write-Host "Installing Build Dependencies"
|
||||
python -m venv .\shark.venv\
|
||||
.\shark.venv\Scripts\activate
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
Write-Host "Build and installation completed successfully"
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
149
setup_venv.sh
149
setup_venv.sh
@@ -1,149 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Sets up a venv suitable for running samples.
|
||||
# e.g:
|
||||
# ./setup_venv.sh #setup a default $PYTHON3 shark.venv
|
||||
# Environment Variables by the script.
|
||||
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
|
||||
# VENV_DIR=myshark.venv #create a venv called myshark.venv
|
||||
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
|
||||
# IMPORTER=1 #Install importer deps
|
||||
# BENCHMARK=1 #Install benchmark deps
|
||||
# NO_BACKEND=1 #Don't install iree or shark backend
|
||||
# if you run the script from a conda env it will install in your conda env
|
||||
|
||||
TD="$(cd $(dirname $0) && pwd)"
|
||||
if [ -z "$PYTHON" ]; then
|
||||
PYTHON="$(which python3)"
|
||||
fi
|
||||
|
||||
function die() {
|
||||
echo "Error executing command: $*"
|
||||
exit 1
|
||||
}
|
||||
|
||||
PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; print("{0}.{1}".format(*version))'`
|
||||
|
||||
echo "Python: $PYTHON"
|
||||
echo "Python version: $PYTHON_VERSION_X_Y"
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
# Not a conda env. So create a new VENV dir
|
||||
VENV_DIR=${VENV_DIR:-shark.venv}
|
||||
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
|
||||
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
|
||||
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
|
||||
PYTHON="$(which python3)"
|
||||
else
|
||||
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
|
||||
fi
|
||||
|
||||
Red=`tput setaf 1`
|
||||
Green=`tput setaf 2`
|
||||
Yellow=`tput setaf 3`
|
||||
|
||||
# Assume no binary torch-mlir.
|
||||
# Currently available for macOS m1&intel (3.10) and Linux(3.7,3.8,3.9,3.10)
|
||||
torch_mlir_bin=false
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "${Yellow}Apple macOS detected"
|
||||
if [[ $(uname -m) == 'arm64' ]]; then
|
||||
echo "${Yellow}Apple M1 Detected"
|
||||
hash rustc 2>/dev/null
|
||||
if [ $? -eq 0 ];then
|
||||
echo "${Green}rustc found to compile HF tokenizers"
|
||||
else
|
||||
echo "${Red}Could not find rustc" >&2
|
||||
echo "${Red}Please run:"
|
||||
echo "${Red}curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
echo "${Yellow}Run the following commands to setup your SSL certs for your Python version if you see SSL errors with tests"
|
||||
echo "${Yellow}/Applications/Python\ 3.XX/Install\ Certificates.command"
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.10" ]; then
|
||||
torch_mlir_bin=true
|
||||
fi
|
||||
elif [[ $(uname -s) = 'Linux' ]]; then
|
||||
echo "${Yellow}Linux detected"
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.7" ] || [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.9" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] ; then
|
||||
torch_mlir_bin=true
|
||||
fi
|
||||
else
|
||||
echo "${Red}OS not detected. Pray and Play"
|
||||
fi
|
||||
|
||||
# Upgrade pip and install requirements.
|
||||
$PYTHON -m pip install --upgrade pip || die "Could not upgrade pip"
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
|
||||
if [ "$torch_mlir_bin" = true ]; then
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch-mlir"
|
||||
else
|
||||
echo "Could not install torch-mlir" >&2
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
|
||||
echo "${Yello}Python 3.10 supported on macOS and 3.7,3.8,3.9 and 3.10 on Linux"
|
||||
echo "${Red}Please build torch-mlir from source in your environment"
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "${USE_IREE}" ]]; then
|
||||
rm .use-iree
|
||||
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
|
||||
else
|
||||
touch ./.use-iree
|
||||
RUNTIME="https://iree-org.github.io/iree/pip-release-links.html"
|
||||
fi
|
||||
if [[ -z "${NO_BACKEND}" ]]; then
|
||||
echo "Installing ${RUNTIME}..."
|
||||
$PYTHON -m pip install --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
|
||||
else
|
||||
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
|
||||
fi
|
||||
|
||||
if [[ ! -z "${IMPORTER}" ]]; then
|
||||
echo "${Yellow}Installing importer tools.."
|
||||
if [[ $(uname -s) = 'Linux' ]]; then
|
||||
echo "${Yellow}Linux detected.. installing Linux importer tools"
|
||||
#Always get the importer tools from upstream IREE
|
||||
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://iree-org.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
elif [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "${Yellow}macOS detected.. installing macOS importer tools"
|
||||
#Conda seems to have some problems installing these packages and hope they get resolved upstream.
|
||||
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer-macos.txt" -f ${RUNTIME} --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
fi
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu117
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu117."
|
||||
else
|
||||
echo "Could not install torch + cu117." >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ ! -z "${ONNX}" ]]; then
|
||||
echo "${Yellow}Installing ONNX and onnxruntime for benchmarks..."
|
||||
$PYTHON -m pip install onnx onnxruntime psutil
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully installed ONNX and ONNX runtime."
|
||||
else
|
||||
echo "Could not install ONNX." >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
echo "${Green}Before running examples activate venv with:"
|
||||
echo " ${Green}source $VENV_DIR/bin/activate"
|
||||
fi
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch._decomp import get_decompositions
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.nn.utils import _stateless
|
||||
|
||||
from torch import fx
|
||||
import tempfile
|
||||
|
||||
|
||||
class MakeFxModule:
|
||||
def __init__(self, model, inputs, labels=None, custom_inference_fn=None):
|
||||
self.model = model
|
||||
self.inputs = inputs
|
||||
self.custom_inference_fn = custom_inference_fn
|
||||
self.training_graph = None
|
||||
|
||||
# Doesn't replace the None type.
|
||||
def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
# output nodes always have one argument
|
||||
node_arg = node.args[0]
|
||||
out_nodes = []
|
||||
if isinstance(node_arg, list):
|
||||
# Don't return NoneType elements.
|
||||
for out_node in node_arg:
|
||||
if not isinstance(out_node, type(None)):
|
||||
out_nodes.append(out_node)
|
||||
# If there is a single tensor/element to be returned don't
|
||||
# a tuple for it.
|
||||
if len(out_nodes) == 1:
|
||||
node.args = out_nodes
|
||||
else:
|
||||
node.args = (tuple(out_nodes),)
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return fx_g
|
||||
|
||||
def generate_graph(self):
|
||||
fx_g = make_fx(
|
||||
self.custom_inference_fn,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
]
|
||||
),
|
||||
)(
|
||||
dict(self.model.named_parameters()),
|
||||
dict(self.model.named_buffers()),
|
||||
self.inputs,
|
||||
)
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
fx_g = self.change_fx_graph_return_to_tuple(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
temp = tempfile.NamedTemporaryFile(
|
||||
suffix="_shark_ts", prefix="temp_ts_"
|
||||
)
|
||||
ts_g.save(temp.name)
|
||||
new_ts = torch.jit.load(temp.name)
|
||||
self.training_graph = new_ts
|
||||
@@ -1,70 +0,0 @@
|
||||
import torchdynamo
|
||||
import torch
|
||||
import torch_mlir
|
||||
from shark.sharkdynamo.utils import make_shark_compiler
|
||||
|
||||
|
||||
import warnings, logging
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
torchdynamo.config.log_level = logging.ERROR
|
||||
|
||||
|
||||
torchdynamo.reset()
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
|
||||
)
|
||||
def foo(t):
|
||||
return 2 * t
|
||||
|
||||
|
||||
example_input = torch.rand((2, 3))
|
||||
x = foo(example_input)
|
||||
print(x)
|
||||
|
||||
|
||||
torchdynamo.reset()
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
|
||||
)
|
||||
def foo(a, b):
|
||||
x = a / (a + 1)
|
||||
if b.sum() < 0:
|
||||
b = b * -1
|
||||
return x * b
|
||||
|
||||
|
||||
print(foo(torch.rand((2, 3)), -torch.rand((2, 3))))
|
||||
|
||||
|
||||
torchdynamo.reset()
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
|
||||
)
|
||||
def foo(a):
|
||||
for i in range(10):
|
||||
a += 1.0
|
||||
return a
|
||||
|
||||
|
||||
print(foo(torch.rand((1, 2))))
|
||||
|
||||
torchdynamo.reset()
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
|
||||
)
|
||||
def test_unsupported_types(t, y):
|
||||
return t, 2 * y
|
||||
|
||||
|
||||
str_input = "hello"
|
||||
tensor_input = torch.randn(2)
|
||||
print(test_unsupported_types(str_input, tensor_input))
|
||||
@@ -1,309 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# standard imports\n",
|
||||
"import torch\n",
|
||||
"from shark.iree_utils import get_iree_compiled_module"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# torch dynamo related imports\n",
|
||||
"try:\n",
|
||||
" import torchdynamo\n",
|
||||
" from torchdynamo.optimizations.backends import create_backend\n",
|
||||
" from torchdynamo.optimizations.subgraph import SubGraph\n",
|
||||
"except ModuleNotFoundError:\n",
|
||||
" print(\n",
|
||||
" \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
|
||||
" )\n",
|
||||
" exit()\n",
|
||||
"\n",
|
||||
"# torch-mlir imports for compiling\n",
|
||||
"from torch_mlir import compile, OutputType"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"[TorchDynamo](https://github.com/pytorch/torchdynamo) is a compiler for PyTorch programs that uses the [frame evaluation API](https://www.python.org/dev/peps/pep-0523/) in CPython to dynamically modify Python bytecode right before it is executed. It creates this FX Graph through bytecode analysis and is designed to mix Python execution with compiled backends."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def toy_example(*args):\n",
|
||||
" a, b = args\n",
|
||||
"\n",
|
||||
" x = a / (torch.abs(a) + 1)\n",
|
||||
" if b.sum() < 0:\n",
|
||||
" b = b * -1\n",
|
||||
" return x * b"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# compiler that lowers fx_graph to through MLIR\n",
|
||||
"def __torch_mlir(fx_graph, *args, **kwargs):\n",
|
||||
" assert isinstance(\n",
|
||||
" fx_graph, torch.fx.GraphModule\n",
|
||||
" ), \"Model must be an FX GraphModule.\"\n",
|
||||
"\n",
|
||||
" def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule):\n",
|
||||
" \"\"\"Replace tuple with tuple element in functions that return one-element tuples.\"\"\"\n",
|
||||
"\n",
|
||||
" for node in fx_g.graph.nodes:\n",
|
||||
" if node.op == \"output\":\n",
|
||||
" assert (\n",
|
||||
" len(node.args) == 1\n",
|
||||
" ), \"Output node must have a single argument\"\n",
|
||||
" node_arg = node.args[0]\n",
|
||||
" if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
|
||||
" node.args = (node_arg[0],)\n",
|
||||
" fx_g.graph.lint()\n",
|
||||
" fx_g.recompile()\n",
|
||||
" return fx_g\n",
|
||||
"\n",
|
||||
" fx_graph = _unwrap_single_tuple_return(fx_graph)\n",
|
||||
" ts_graph = torch.jit.script(fx_graph)\n",
|
||||
"\n",
|
||||
" # torchdynamo does munges the args differently depending on whether you use\n",
|
||||
" # the @torchdynamo.optimize decorator or the context manager\n",
|
||||
" if isinstance(args, tuple):\n",
|
||||
" args = list(args)\n",
|
||||
" assert isinstance(args, list)\n",
|
||||
" if len(args) == 1 and isinstance(args[0], list):\n",
|
||||
" args = args[0]\n",
|
||||
"\n",
|
||||
" linalg_module = compile(\n",
|
||||
" ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
|
||||
" )\n",
|
||||
" callable, _ = get_iree_compiled_module(\n",
|
||||
" linalg_module, \"cuda\", func_name=\"forward\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(*inputs):\n",
|
||||
" return callable(*inputs)\n",
|
||||
"\n",
|
||||
" return forward"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Simplest way to use TorchDynamo with the `torchdynamo.optimize` context manager:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 1 device(s).\n",
|
||||
"Device: 0\n",
|
||||
" Name: NVIDIA GeForce RTX 3080\n",
|
||||
" Compute Capability: 8.6\n",
|
||||
"[-0.40066046 -0.4210303 0.03225489 -0.44849953 0.10370405 -0.04422468\n",
|
||||
" 0.33262825 -0.20109026 0.02102537 -0.24882983]\n",
|
||||
"[-0.07824923 -0.17004533 0.06439921 -0.06163602 0.26633525 -1.1560082\n",
|
||||
" -0.06660341 0.24227881 0.1462235 -0.32055548]\n",
|
||||
"[-0.01464001 0.442209 -0.0607936 -0.5477967 -0.25226554 -0.08588809\n",
|
||||
" -0.30497575 0.00061084 -0.50069696 0.2317973 ]\n",
|
||||
"[ 0.25726247 0.39388427 -0.24093066 0.12316308 -0.01981307 0.5661146\n",
|
||||
" 0.26199922 0.8123446 -0.01576749 0.30846444]\n",
|
||||
"[ 0.7878203 -0.45975062 -0.29956317 -0.07032048 -0.55817443 -0.62506855\n",
|
||||
" -1.6837492 -0.38442805 0.28220773 -1.5325156 ]\n",
|
||||
"[ 0.07975311 0.67754704 -0.30927914 0.00347631 -0.07326564 0.01893554\n",
|
||||
" -0.7518105 -0.03078967 -0.07623022 0.38865626]\n",
|
||||
"[-0.7751679 -0.5841397 -0.6622711 0.18574935 -0.6049372 0.02844244\n",
|
||||
" -0.20471913 0.3337415 -0.3619432 -0.35087156]\n",
|
||||
"[-0.08569919 -0.10775139 -0.02338934 0.21933547 -0.46712473 0.00062137\n",
|
||||
" -0.58207744 0.06457533 0.18276742 0.03866556]\n",
|
||||
"[-0.2311981 -0.43036282 0.20561649 -0.10363232 -0.13248594 0.02885137\n",
|
||||
" -0.31241602 -0.36907142 0.08861586 0.2331427 ]\n",
|
||||
"[-0.07273526 -0.31246194 -0.24218291 -0.24145737 0.0364486 0.14382267\n",
|
||||
" -0.00531162 0.15447603 -0.5220248 -0.09016377]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with torchdynamo.optimize(__torch_mlir):\n",
|
||||
" for _ in range(10):\n",
|
||||
" print(toy_example(torch.randn(10), torch.randn(10)))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"It can also be used through a decorator:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@create_backend\n",
|
||||
"def torch_mlir(subgraph, *args, **kwargs):\n",
|
||||
" assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
|
||||
" return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@torchdynamo.optimize(\"torch_mlir\")\n",
|
||||
"def toy_example2(*args):\n",
|
||||
" a, b = args\n",
|
||||
"\n",
|
||||
" x = a / (torch.abs(a) + 1)\n",
|
||||
" if b.sum() < 0:\n",
|
||||
" b = b * -1\n",
|
||||
" return x * b"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 1 device(s).\n",
|
||||
"Device: 0\n",
|
||||
" Name: NVIDIA GeForce RTX 3080\n",
|
||||
" Compute Capability: 8.6\n",
|
||||
"[-0.35494277 0.03409214 -0.02271946 0.7335942 0.03122527 -0.41881397\n",
|
||||
" -0.6609761 -0.6418614 0.29336175 -0.01973678]\n",
|
||||
"[-2.7246824e-01 -3.5543957e-01 6.0087401e-01 -7.4570496e-03\n",
|
||||
" -4.2481605e-02 -5.0296803e-04 7.2928613e-01 -1.4673788e-03\n",
|
||||
" -2.7621329e-01 -6.0995776e-02]\n",
|
||||
"[-0.03165906 0.3889693 0.24052973 0.27279532 -0.02773128 -0.12602475\n",
|
||||
" -1.0124422 0.5720256 -0.35437614 -0.20992722]\n",
|
||||
"[-0.41831446 0.5525326 -0.29749998 -0.17044766 0.11804754 -0.05210691\n",
|
||||
" -0.46145165 -0.8776549 0.10090438 0.17463352]\n",
|
||||
"[ 0.02194221 0.20959911 0.26973712 0.12551276 -0.0020404 0.1490246\n",
|
||||
" -0.04456685 1.1100804 0.8105744 0.6676846 ]\n",
|
||||
"[ 0.06528181 -0.13591261 0.5370964 -0.4398162 -0.03372452 0.9691372\n",
|
||||
" -0.01120087 0.2947028 0.4804801 -0.3324341 ]\n",
|
||||
"[ 0.33549032 -0.23001772 -0.08681437 0.16490957 -0.11223086 0.09168988\n",
|
||||
" 0.02403045 0.17344482 0.46406478 -0.00129451]\n",
|
||||
"[-0.27475086 0.42384806 1.9090122 -0.41147137 -0.6888369 0.08435658\n",
|
||||
" -0.26628923 -0.17436793 -0.8058869 -0.02582378]\n",
|
||||
"[-0.10109414 0.08681287 -0.10055986 0.6858881 0.29267687 -0.02797117\n",
|
||||
" -0.01425194 0.4882803 0.3551982 -0.858935 ]\n",
|
||||
"[-0.22086617 0.524994 0.17721705 -0.03813264 -0.54570735 -0.4421502\n",
|
||||
" 0.11938014 -0.01122053 0.39294165 -0.61770755]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for _ in range(10):\n",
|
||||
" print(toy_example2(torch.randn(10), torch.randn(10)))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
import torch
|
||||
from torch_mlir import compile, OutputType
|
||||
|
||||
from shark.iree_utils import get_iree_compiled_module
|
||||
|
||||
try:
|
||||
import torchdynamo
|
||||
from torchdynamo.optimizations.backends import create_backend
|
||||
from torchdynamo.optimizations.subgraph import SubGraph
|
||||
except ModuleNotFoundError:
|
||||
print(
|
||||
"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo"
|
||||
)
|
||||
exit()
|
||||
|
||||
NUM_ITERS = 10
|
||||
|
||||
|
||||
def __torch_mlir(fx_graph, *args, **kwargs):
|
||||
assert isinstance(
|
||||
fx_graph, torch.fx.GraphModule
|
||||
), "Model must be an FX GraphModule."
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule):
|
||||
"""Replace tuple with tuple element in functions that return one-element tuples."""
|
||||
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple) and len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return fx_g
|
||||
|
||||
fx_graph = _unwrap_single_tuple_return(fx_graph)
|
||||
ts_graph = torch.jit.script(fx_graph)
|
||||
|
||||
if isinstance(args, tuple):
|
||||
args = list(args)
|
||||
assert isinstance(args, list)
|
||||
if len(args) == 1 and isinstance(args[0], list):
|
||||
args = args[0]
|
||||
|
||||
linalg_module = compile(
|
||||
ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS
|
||||
)
|
||||
callable, _ = get_iree_compiled_module(
|
||||
linalg_module, "cuda", func_name="forward"
|
||||
)
|
||||
|
||||
def forward(*inputs):
|
||||
return callable(*inputs)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def toy_example(*args):
|
||||
a, b = args
|
||||
|
||||
x = a / (torch.abs(a) + 1)
|
||||
if b.sum() < 0:
|
||||
b = b * -1
|
||||
return x * b
|
||||
|
||||
|
||||
with torchdynamo.optimize(__torch_mlir):
|
||||
for _ in range(10):
|
||||
print(toy_example(torch.randn(10), torch.randn(10)))
|
||||
|
||||
|
||||
@create_backend
|
||||
def torch_mlir(subgraph, *args, **kwargs):
|
||||
assert isinstance(subgraph, SubGraph), "Model must be a dynamo SubGraph."
|
||||
return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))
|
||||
|
||||
|
||||
@torchdynamo.optimize("torch_mlir")
|
||||
def toy_example2(*args):
|
||||
a, b = args
|
||||
|
||||
x = a / (torch.abs(a) + 1)
|
||||
if b.sum() < 0:
|
||||
b = b * -1
|
||||
return x * b
|
||||
|
||||
|
||||
for _ in range(10):
|
||||
print(toy_example2(torch.randn(10), torch.randn(10)))
|
||||
@@ -1,805 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# standard imports\n",
|
||||
"import torch\n",
|
||||
"from torch_mlir.eager_mode import torch_mlir_tensor"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# eager mode imports\n",
|
||||
"from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor\n",
|
||||
"from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"The simplest way of using Eager Mode (through IREE) requires setting a \"backend\":"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend(\"cpu\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"and wrapping all your `torch.Tensor`s:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"NUM_ITERS = 10\n",
|
||||
"\n",
|
||||
"t = torch.ones((10, 10))\n",
|
||||
"u = 2 * torch.ones((10, 10))\n",
|
||||
"\n",
|
||||
"tt = TorchMLIRTensor(t)\n",
|
||||
"print(tt)\n",
|
||||
"uu = TorchMLIRTensor(u)\n",
|
||||
"print(uu)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"`TorchMLIRTensor` is a \"tensor wrapper subclass\" (more info [here](https://github.com/albanD/subclass_zoo)) that keeps the IREE `DeviceArray` in a field `elem`:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for i in range(NUM_ITERS):\n",
|
||||
" yy = tt + uu\n",
|
||||
" print(type(yy))\n",
|
||||
" print(yy.elem.to_host())\n",
|
||||
" yy = tt * uu\n",
|
||||
" print(type(yy))\n",
|
||||
" print(yy.elem.to_host())"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"If you have a GPU (and CUDA installed) that works too (you can verify by having `watch -n1 nvidia-smi` up in a terminal while running the next cell):"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend(\"gpu\")\n",
|
||||
"\n",
|
||||
"t = torch.ones((10, 10))\n",
|
||||
"u = 2 * torch.ones((10, 10))\n",
|
||||
"\n",
|
||||
"tt = TorchMLIRTensor(t)\n",
|
||||
"print(tt)\n",
|
||||
"uu = TorchMLIRTensor(u)\n",
|
||||
"print(uu)\n",
|
||||
"\n",
|
||||
"yy = tt + uu\n",
|
||||
"print(yy.elem.to_host())\n",
|
||||
"yy = tt * uu\n",
|
||||
"print(yy.elem.to_host())"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"There is a convenience class `SharkEagerMode` that will handle both the installation of the backend and the wrapping of `torch.Tensor`s:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# eager mode RAII\n",
|
||||
"from shark.shark_runner import SharkEagerMode\n",
|
||||
"\n",
|
||||
"shark_eager_mode = SharkEagerMode(\"cpu\")\n",
|
||||
"\n",
|
||||
"t = torch.ones((10, 10))\n",
|
||||
"u = torch.ones((10, 10))\n",
|
||||
"\n",
|
||||
"print(t)\n",
|
||||
"print(u)\n",
|
||||
"\n",
|
||||
"for i in range(NUM_ITERS):\n",
|
||||
" yy = t + u\n",
|
||||
" print(type(yy))\n",
|
||||
" print(yy.elem.to_host())\n",
|
||||
" yy = t * u\n",
|
||||
" print(type(yy))\n",
|
||||
" print(yy.elem.to_host())"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"The `SharkEagerMode` class is a hacky take on [RAII](https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization) that defines a \"deleter\" that runs when an instantiation (of `SharkEagerMode`) is garbage collected. Takeaway is that if you want to turn off `SharkEagerMode`, or switch backends, you need to `del` the instance:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"del shark_eager_mode\n",
|
||||
"shark_eager_mode = SharkEagerMode(\"cuda\")\n",
|
||||
"\n",
|
||||
"t = torch.ones((10, 10))\n",
|
||||
"u = torch.ones((10, 10))\n",
|
||||
"\n",
|
||||
"print(t)\n",
|
||||
"print(u)\n",
|
||||
"\n",
|
||||
"yy = t + u\n",
|
||||
"print(type(yy))\n",
|
||||
"print(yy.elem.to_host())\n",
|
||||
"yy = t * u\n",
|
||||
"print(type(yy))\n",
|
||||
"print(yy.elem.to_host())"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load_inline, include_paths
|
||||
from torch_mlir.eager_mode import torch_mlir_tensor
|
||||
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
|
||||
|
||||
from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
|
||||
from shark.shark_runner import SharkEagerMode
|
||||
|
||||
|
||||
def test_cpu():
|
||||
torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend("cpu")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = 2 * torch.ones((10, 10), device="cpu")
|
||||
|
||||
tt = TorchMLIRTensor(t)
|
||||
print(tt)
|
||||
uu = TorchMLIRTensor(u)
|
||||
print(uu)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
yy = tt + uu
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
yy = tt * uu
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
|
||||
|
||||
def test_gpu():
|
||||
source = """
|
||||
#include <iostream>
|
||||
#include "cuda.h"
|
||||
#include "cuda_runtime_api.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
void print_free_mem() {
|
||||
int num_gpus;
|
||||
size_t free, total;
|
||||
cudaSetDevice(0);
|
||||
int id;
|
||||
cudaGetDevice(&id);
|
||||
cudaMemGetInfo(&free, &total);
|
||||
cout << "GPU " << id << " memory: used=" << (total-free)/(1<<20) << endl;
|
||||
}
|
||||
"""
|
||||
gpu_stats = load_inline(
|
||||
name="inline_extension",
|
||||
cpp_sources=[source],
|
||||
extra_include_paths=include_paths(cuda=True),
|
||||
functions=["print_free_mem"],
|
||||
)
|
||||
torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend("gpu")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = 2 * torch.ones((10, 10), device="cpu")
|
||||
|
||||
tt = TorchMLIRTensor(t)
|
||||
print(tt)
|
||||
uu = TorchMLIRTensor(u)
|
||||
print(uu)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
yy = tt + uu
|
||||
print(yy.elem.to_host())
|
||||
yy = tt * uu
|
||||
print(yy.elem.to_host())
|
||||
gpu_stats.print_free_mem()
|
||||
|
||||
|
||||
def test_python_mode_ref_backend():
|
||||
# hide this wherever you want?
|
||||
_ = SharkEagerMode("refbackend")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = torch.ones((10, 10), device="cpu")
|
||||
|
||||
print(t)
|
||||
print(u)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
print(i)
|
||||
yy = t + u
|
||||
print(yy.elem)
|
||||
yy = t * u
|
||||
print(yy.elem)
|
||||
|
||||
|
||||
def test_python_mode_iree_cpu():
|
||||
# hide this wherever you want?
|
||||
_ = SharkEagerMode("cpu")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = torch.ones((10, 10), device="cpu")
|
||||
|
||||
print(t)
|
||||
print(u)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
yy = t + u
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
yy = t * u
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
|
||||
|
||||
def test_python_mode_iree_gpu():
|
||||
_ = SharkEagerMode("gpu")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = torch.ones((10, 10), device="cpu")
|
||||
|
||||
print(t)
|
||||
print(u)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
yy = t + u
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
yy = t * u
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
NUM_ITERS = 10
|
||||
test_cpu()
|
||||
if torch.cuda.is_available():
|
||||
test_gpu()
|
||||
test_python_mode_ref_backend()
|
||||
test_python_mode_iree_cpu()
|
||||
test_python_mode_iree_gpu()
|
||||
@@ -1,73 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
model = torch.hub.load(
|
||||
"pytorch/vision:v0.10.0", "squeezenet1_0", pretrained=True
|
||||
)
|
||||
model.eval()
|
||||
|
||||
# from PIL import Image
|
||||
# from torchvision import transforms
|
||||
# import urllib
|
||||
#
|
||||
# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
|
||||
# try: urllib.URLopener().retrieve(url, filename)
|
||||
# except: urllib.request.urlretrieve(url, filename)
|
||||
#
|
||||
#
|
||||
# input_image = Image.open(filename)
|
||||
# preprocess = transforms.Compose([
|
||||
# transforms.Resize(256),
|
||||
# transforms.CenterCrop(224),
|
||||
# transforms.ToTensor(),
|
||||
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
# ])
|
||||
# input_tensor = preprocess(input_image)
|
||||
# input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
|
||||
# print(input_batch.shape) # size = [1, 3, 224, 224]
|
||||
|
||||
# The above is code for generating sample inputs from an image. We can just use
|
||||
# random values for accuracy testing though
|
||||
input_batch = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
# Focus on CPU for now
|
||||
if False and torch.cuda.is_available():
|
||||
input_batch = input_batch.to("cuda")
|
||||
model.to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_batch)
|
||||
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
|
||||
golden_confidences = output[0]
|
||||
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
|
||||
golden_probabilities = torch.nn.functional.softmax(
|
||||
golden_confidences, dim=0
|
||||
).numpy()
|
||||
|
||||
golden_confidences = golden_confidences.numpy()
|
||||
|
||||
from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor
|
||||
|
||||
input_detached_clone = input_batch.clone()
|
||||
eager_input_batch = TorchMLIRLockstepTensor(input_detached_clone)
|
||||
|
||||
print("getting torch-mlir result")
|
||||
|
||||
output = model(eager_input_batch)
|
||||
|
||||
static_output = output.elem
|
||||
confidences = static_output[0]
|
||||
probabilities = torch.nn.functional.softmax(
|
||||
torch.from_numpy(confidences), dim=0
|
||||
).numpy()
|
||||
|
||||
print("The obtained result via shark is: ", confidences)
|
||||
print("The golden result is:", golden_confidences)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
golden_confidences, confidences, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
golden_probabilities, probabilities, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
@@ -1,65 +0,0 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import CLIPProcessor, TFCLIPModel
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
# Create a set of inputs
|
||||
clip_vit_inputs = [
|
||||
tf.TensorSpec(shape=[2, 7], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[2, 7], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
|
||||
]
|
||||
|
||||
|
||||
class CLIPModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(CLIPModule, self).__init__()
|
||||
self.m = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
self.m.predict = lambda x, y, z: self.m(
|
||||
input_ids=x, attention_mask=y, pixel_values=z
|
||||
)
|
||||
|
||||
@tf.function(input_signature=clip_vit_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, pixel_values):
|
||||
return self.m.predict(
|
||||
input_ids, attention_mask, pixel_values
|
||||
).logits_per_image
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
text=["a photo of a cat", "a photo of a dog"],
|
||||
images=image,
|
||||
return_tensors="tf",
|
||||
padding=True,
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
CLIPModule(),
|
||||
(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["pixel_values"],
|
||||
),
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
|
||||
print(
|
||||
shark_module.forward(
|
||||
(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["pixel_values"],
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -1,15 +0,0 @@
|
||||
## Running ESRGAN
|
||||
|
||||
```
|
||||
1. pip install numpy opencv-python
|
||||
2. mkdir InputImages
|
||||
(this is where all the input images will reside in)
|
||||
3. mkdir OutputImages
|
||||
(this is where the model will generate all the images)
|
||||
4. mkdir models
|
||||
(save the .pth checkpoint file here)
|
||||
5. python esrgan.py
|
||||
```
|
||||
|
||||
- Download [RRDB_ESRGAN_x4.pth](https://drive.google.com/drive/u/0/folders/17VYV_SoZZesU6mbxz2dMAIccSSlqLecY) and place it in the `models` directory as mentioned above in step 4.
|
||||
- Credits : [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
@@ -1,240 +0,0 @@
|
||||
from ast import arg
|
||||
import os.path as osp
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from shark.shark_inference import SharkInference
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
"""Residual in Residual Dense Block"""
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(
|
||||
self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
fea = self.lrelu(
|
||||
self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument(
|
||||
"--mlir_loc",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location of the model's mlir file",
|
||||
)
|
||||
args = p.parse_args()
|
||||
###################################################
|
||||
|
||||
|
||||
def inference(input_m):
|
||||
return model(input_m)
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc == None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
|
||||
module = load_mlir(mlir_loc)
|
||||
if module == None:
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
print("Torchscript graph generated successfully")
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
inputs,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mlir_model = str(module)
|
||||
func_name = "forward"
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
model_path = "models/RRDB_ESRGAN_x4.pth" # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
|
||||
# device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu
|
||||
device = torch.device("cpu")
|
||||
|
||||
test_img_folder = "InputImages/*"
|
||||
|
||||
model = RRDBNet(3, 3, 64, 23, gc=32)
|
||||
model.load_state_dict(torch.load(model_path), strict=True)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
|
||||
print("Model path {:s}. \nTesting...".format(model_path))
|
||||
|
||||
if __name__ == "__main__":
|
||||
idx = 0
|
||||
for path in glob.glob(test_img_folder):
|
||||
idx += 1
|
||||
base = osp.splitext(osp.basename(path))[0]
|
||||
print(idx, base)
|
||||
# read images
|
||||
img = cv2.imread(path, cv2.IMREAD_COLOR)
|
||||
img = img * 1.0 / 255
|
||||
img = torch.from_numpy(
|
||||
np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))
|
||||
).float()
|
||||
img_LR = img.unsqueeze(0)
|
||||
img_LR = img_LR.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
shark_module = compile_through_fx(inference, img_LR)
|
||||
shark_output = shark_module.forward((img_LR,))
|
||||
shark_output = torch.from_numpy(shark_output)
|
||||
shark_output = (
|
||||
shark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
)
|
||||
esrgan_output = (
|
||||
model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
)
|
||||
# SHARK OUTPUT
|
||||
shark_output = np.transpose(shark_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
shark_output = (shark_output * 255.0).round()
|
||||
cv2.imwrite(
|
||||
"OutputImages/{:s}_rlt_shark_output.png".format(base), shark_output
|
||||
)
|
||||
print("Generated SHARK's output")
|
||||
# ESRGAN OUTPUT
|
||||
esrgan_output = np.transpose(esrgan_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
esrgan_output = (esrgan_output * 255.0).round()
|
||||
cv2.imwrite(
|
||||
"OutputImages/{:s}_rlt_esrgan_output.png".format(base),
|
||||
esrgan_output,
|
||||
)
|
||||
print("Generated ESRGAN's output")
|
||||
@@ -1,88 +0,0 @@
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from iree.compiler import compile_str
|
||||
from iree import runtime as ireert
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
|
||||
class AlbertModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForMaskedLM.from_pretrained("albert-base-v2")
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.model(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).logits
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
|
||||
text = "This [MASK] is very tasty."
|
||||
encoded_inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
|
||||
mlir_importer = SharkImporter(
|
||||
AlbertModule(),
|
||||
inputs,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
encoded_inputs["input_ids"] == tokenizer.mask_token_id
|
||||
)[1]
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
|
||||
for token in top_5_tokens:
|
||||
print(
|
||||
f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence with [MASK] to fill: ")
|
||||
encoded_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
encoded_inputs["input_ids"] == tokenizer.mask_token_id
|
||||
)[1]
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
top_5_tokens = (
|
||||
torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
|
||||
)
|
||||
for token in top_5_tokens:
|
||||
print(
|
||||
f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
break
|
||||
@@ -1,100 +0,0 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from iree.compiler import tf as tfc
|
||||
from iree.compiler import compile_str
|
||||
from iree import runtime as ireert
|
||||
import os
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
# Create a set of inputs
|
||||
t5_inputs = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class AlbertModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(AlbertModule, self).__init__()
|
||||
self.m = TFAutoModelForMaskedLM.from_pretrained("albert-base-v2")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
|
||||
# text = "This is a great [MASK]."
|
||||
text = "This [MASK] is very tasty."
|
||||
encoded_inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="tf",
|
||||
)
|
||||
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
|
||||
mlir_importer = SharkImporter(
|
||||
AlbertModule(),
|
||||
inputs,
|
||||
frontend="tf",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
output_idx = 0
|
||||
data_idx = 1
|
||||
token_logits = shark_module.forward(inputs)[output_idx][data_idx]
|
||||
mask_id = np.where(
|
||||
tf.squeeze(encoded_inputs["input_ids"]) == tokenizer.mask_token_id
|
||||
)
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[0:5]
|
||||
for token in top_5_tokens:
|
||||
print(
|
||||
f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence with [MASK] to fill: ")
|
||||
encoded_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="tf",
|
||||
)
|
||||
inputs = (
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
token_logits = shark_module.forward(inputs)[output_idx][data_idx]
|
||||
mask_id = np.where(
|
||||
tf.squeeze(encoded_inputs["input_ids"])
|
||||
== tokenizer.mask_token_id
|
||||
)
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[
|
||||
0:5
|
||||
]
|
||||
for token in top_5_tokens:
|
||||
print(
|
||||
f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
sys.exit()
|
||||
@@ -1,14 +0,0 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"bloom", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
print("The obtained result via shark is: ", result)
|
||||
print("The golden result is:", golden_out)
|
||||
@@ -1,40 +0,0 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import GPT2Tokenizer, TFGPT2Model
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
# Create a set of inputs
|
||||
gpt2_inputs = [
|
||||
tf.TensorSpec(shape=[1, 8], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[1, 8], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class GPT2Module(tf.Module):
|
||||
def __init__(self):
|
||||
super(GPT2Module, self).__init__()
|
||||
self.m = TFGPT2Model.from_pretrained("distilgpt2")
|
||||
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
|
||||
|
||||
@tf.function(input_signature=gpt2_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
||||
text = "I love the distilled version of models."
|
||||
|
||||
inputs = tokenizer(text, return_tensors="tf")
|
||||
shark_module = SharkInference(
|
||||
GPT2Module(), (inputs["input_ids"], inputs["attention_mask"])
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
print(
|
||||
shark_module.forward((inputs["input_ids"], inputs["attention_mask"]))
|
||||
)
|
||||
@@ -1,37 +0,0 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
import numpy as np
|
||||
|
||||
mhlo_ir = r"""builtin.module {
|
||||
func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
|
||||
%1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
return %1 : tensor<4x4xf32>
|
||||
}
|
||||
}"""
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
|
||||
print("Running shark on cpu backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
|
||||
)
|
||||
|
||||
# Generate the random inputs and feed into the graph.
|
||||
x = shark_module.generate_random_inputs()
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on cuda backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on vulkan backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
@@ -1,35 +0,0 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
|
||||
shark_module = SharkInference(
|
||||
MiniLMSequenceClassification(),
|
||||
(test_input,),
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
|
||||
shark_module.compile()
|
||||
shark_module.forward((test_input,))
|
||||
shark_module.benchmark_all((test_input,))
|
||||
@@ -1,61 +0,0 @@
|
||||
import tensorflow as tf
|
||||
from transformers import BertModel, BertTokenizer, TFBertModel
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
# Create a set of 2-dimensional inputs
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
# Create a BERT trainer with the created network.
|
||||
self.m = TFBertModel.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", from_pt=True
|
||||
)
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m.predict = lambda x, y, z: self.m.call(
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
|
||||
test_input = (
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
BertModule(), test_input, benchmark_mode=True
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
shark_module.benchmark_all(test_input)
|
||||
@@ -1,25 +0,0 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased",
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
print("The obtained result via shark is: ", result)
|
||||
print("The golden result is:", golden_out)
|
||||
|
||||
|
||||
# Let's generate random inputs, currently supported
|
||||
# for static models.
|
||||
rand_inputs = shark_module.generate_random_inputs()
|
||||
rand_results = shark_module.forward(rand_inputs)
|
||||
|
||||
print("Running shark_module with random_inputs is: ", rand_results)
|
||||
@@ -1,70 +0,0 @@
|
||||
import tensorflow as tf
|
||||
from transformers import BertModel, BertTokenizer, TFBertModel
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
# Create a set of 2-dimensional inputs
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
# Create a BERT trainer with the created network.
|
||||
self.m = TFBertModel.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", from_pt=True
|
||||
)
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m.predict = lambda x, y, z: self.m.call(
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
BertModule(),
|
||||
(
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
),
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
|
||||
print(
|
||||
shark_module.forward(
|
||||
(
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
)
|
||||
)
|
||||
)
|
||||
File diff suppressed because one or more lines are too long
@@ -1,39 +0,0 @@
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
|
||||
torch.hub.list("zhanghang1989/ResNeSt", force_reload=True)
|
||||
|
||||
|
||||
class ResnestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = torch.hub.load(
|
||||
"zhanghang1989/ResNeSt", "resnest50", pretrained=True
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, input):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
ResnestModule(),
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
print(golden_out)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
print("Obtained result", result)
|
||||
@@ -1,76 +0,0 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import sys
|
||||
import torchvision.models as models
|
||||
import torch_mlir
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = models.resnet50(pretrained=True)
|
||||
self.train(False)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
model = VisionModule()
|
||||
test_input = torch.randn(1, 3, 224, 224)
|
||||
actual_out = model(test_input)
|
||||
|
||||
test_input_fp16 = test_input.to(device=torch.device("cuda"), dtype=torch.half)
|
||||
model_fp16 = model.half()
|
||||
model_fp16.eval()
|
||||
model_fp16.to("cuda")
|
||||
actual_out_fp16 = model_fp16(test_input_fp16)
|
||||
|
||||
ts_g = torch.jit.trace(model_fp16, [test_input_fp16])
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(test_input_fp16),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=True,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# from contextlib import redirect_stdout
|
||||
|
||||
# with open('resnet50_fp16_linalg_ir.mlir', 'w') as f:
|
||||
# with redirect_stdout(f):
|
||||
# print(module.operation.get_asm())
|
||||
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cuda", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
|
||||
def shark_result(x):
|
||||
x_ny = x.cpu().detach().numpy()
|
||||
inputs = (x_ny,)
|
||||
result = shark_module.forward(inputs)
|
||||
return torch.from_numpy(result)
|
||||
|
||||
|
||||
observed_out = shark_result(test_input_fp16)
|
||||
|
||||
print("Golden result:", actual_out_fp16)
|
||||
print("SHARK result:", observed_out)
|
||||
|
||||
actual_out_fp16 = actual_out_fp16.to(device=torch.device("cpu"))
|
||||
|
||||
print(
|
||||
torch.testing.assert_allclose(
|
||||
actual_out_fp16, observed_out, rtol=1e-2, atol=1e-2
|
||||
)
|
||||
)
|
||||
@@ -1,85 +0,0 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
def load_and_preprocess_image(url: str):
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
|
||||
}
|
||||
img = Image.open(
|
||||
requests.get(url, headers=headers, stream=True).raw
|
||||
).convert("RGB")
|
||||
# preprocessing pipeline
|
||||
preprocess = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
),
|
||||
]
|
||||
)
|
||||
img_preprocessed = preprocess(img)
|
||||
return torch.unsqueeze(img_preprocessed, 0)
|
||||
|
||||
|
||||
def load_labels():
|
||||
classes_text = requests.get(
|
||||
"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
|
||||
stream=True,
|
||||
).text
|
||||
labels = [line.strip() for line in classes_text.splitlines()]
|
||||
return labels
|
||||
|
||||
|
||||
def top3_possibilities(res):
|
||||
_, indexes = torch.sort(res, descending=True)
|
||||
percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
|
||||
top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
|
||||
return top3
|
||||
|
||||
|
||||
class Resnet50Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.resnet = models.resnet50(pretrained=True)
|
||||
self.train(False)
|
||||
|
||||
def forward(self, img):
|
||||
return self.resnet.forward(img)
|
||||
|
||||
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
||||
print("load image from " + image_url, file=sys.stderr)
|
||||
img = load_and_preprocess_image(image_url)
|
||||
labels = load_labels()
|
||||
|
||||
##############################################################################
|
||||
|
||||
|
||||
## Can pass any img or input to the forward module.
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"resnet50", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
path = shark_module.save_module()
|
||||
shark_module.load_module(path)
|
||||
result = shark_module.forward((img.detach().numpy(),))
|
||||
|
||||
print("The top 3 results obtained via shark_runner is:")
|
||||
print(top3_possibilities(torch.from_numpy(result)))
|
||||
|
||||
print()
|
||||
|
||||
print("The top 3 results obtained via torch is:")
|
||||
print(top3_possibilities(Resnet50Module()(img)))
|
||||
@@ -1,392 +0,0 @@
|
||||
# Description: an implementation of a deep learning recommendation model (DLRM)
|
||||
# The model input consists of dense and sparse features. The former is a vector
|
||||
# of floating point values. The latter is a list of sparse indices into
|
||||
# embedding tables, which consist of vectors of floating point values.
|
||||
# The selected vectors are passed to mlp networks denoted by triangles,
|
||||
# in some cases the vectors are interacted through operators (Ops).
|
||||
#
|
||||
# output:
|
||||
# vector of values
|
||||
# model: |
|
||||
# /\
|
||||
# /__\
|
||||
# |
|
||||
# _____________________> Op <___________________
|
||||
# / | \
|
||||
# /\ /\ /\
|
||||
# /__\ /__\ ... /__\
|
||||
# | | |
|
||||
# | Op Op
|
||||
# | ____/__\_____ ____/__\____
|
||||
# | |_Emb_|____|__| ... |_Emb_|__|___|
|
||||
# input:
|
||||
# [ dense features ] [sparse indices] , ..., [sparse indices]
|
||||
#
|
||||
# More precise definition of model layers:
|
||||
# 1) fully connected layers of an mlp
|
||||
# z = f(y)
|
||||
# y = Wx + b
|
||||
#
|
||||
# 2) embedding lookup (for a list of sparse indices p=[p1,...,pk])
|
||||
# z = Op(e1,...,ek)
|
||||
# obtain vectors e1=E[:,p1], ..., ek=E[:,pk]
|
||||
#
|
||||
# 3) Operator Op can be one of the following
|
||||
# Sum(e1,...,ek) = e1 + ... + ek
|
||||
# Dot(e1,...,ek) = [e1'e1, ..., e1'ek, ..., ek'e1, ..., ek'ek]
|
||||
# Cat(e1,...,ek) = [e1', ..., ek']'
|
||||
# where ' denotes transpose operation
|
||||
#
|
||||
# References:
|
||||
# [1] Maxim Naumov, Dheevatsa Mudigere, Hao-Jun Michael Shi, Jianyu Huang,
|
||||
# Narayanan Sundaram, Jongsoo Park, Xiaodong Wang, Udit Gupta, Carole-Jean Wu,
|
||||
# Alisson G. Azzolini, Dmytro Dzhulgakov, Andrey Mallevich, Ilia Cherniavskii,
|
||||
# Yinghai Lu, Raghuraman Krishnamoorthi, Ansha Yu, Volodymyr Kondratenko,
|
||||
# Stephanie Pereira, Xianjie Chen, Wenlin Chen, Vijay Rao, Bill Jia, Liang Xiong,
|
||||
# Misha Smelyanskiy, "Deep Learning Recommendation Model for Personalization and
|
||||
# Recommendation Systems", CoRR, arXiv:1906.00091, 2019
|
||||
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
### define dlrm in PyTorch ###
|
||||
class DLRM_Net(nn.Module):
|
||||
def create_mlp(self, ln, sigmoid_layer):
|
||||
# build MLP layer by layer
|
||||
layers = nn.ModuleList()
|
||||
for i in range(0, ln.size - 1):
|
||||
n = ln[i]
|
||||
m = ln[i + 1]
|
||||
|
||||
# construct fully connected operator
|
||||
LL = nn.Linear(int(n), int(m), bias=True)
|
||||
|
||||
# initialize the weights
|
||||
# with torch.no_grad():
|
||||
# custom Xavier input, output or two-sided fill
|
||||
|
||||
mean = 0.0 # std_dev = np.sqrt(variance)
|
||||
std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n)
|
||||
W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
|
||||
std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1))
|
||||
bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
|
||||
LL.weight.data = torch.tensor(W, requires_grad=True)
|
||||
LL.bias.data = torch.tensor(bt, requires_grad=True)
|
||||
|
||||
# approach 2
|
||||
# LL.weight.data.copy_(torch.tensor(W))
|
||||
# LL.bias.data.copy_(torch.tensor(bt))
|
||||
# approach 3
|
||||
# LL.weight = Parameter(torch.tensor(W),requires_grad=True)
|
||||
# LL.bias = Parameter(torch.tensor(bt),requires_grad=True)
|
||||
layers.append(LL)
|
||||
|
||||
# construct sigmoid or relu operator
|
||||
if i == sigmoid_layer:
|
||||
layers.append(nn.Sigmoid())
|
||||
else:
|
||||
layers.append(nn.ReLU())
|
||||
|
||||
# approach 1: use ModuleList
|
||||
# return layers
|
||||
# approach 2: use Sequential container to wrap all layers
|
||||
return torch.nn.Sequential(*layers)
|
||||
|
||||
def create_emb(self, m, ln, weighted_pooling=None):
|
||||
emb_l = nn.ModuleList()
|
||||
v_W_l = []
|
||||
for i in range(0, ln.size):
|
||||
n = ln[i]
|
||||
|
||||
# construct embedding operator
|
||||
EE = nn.EmbeddingBag(n, m, mode="sum")
|
||||
# initialize embeddings
|
||||
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
|
||||
W = np.random.uniform(
|
||||
low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m)
|
||||
).astype(np.float32)
|
||||
# approach 1
|
||||
print(W)
|
||||
EE.weight.data = torch.tensor(W, requires_grad=True)
|
||||
# approach 2
|
||||
# EE.weight.data.copy_(torch.tensor(W))
|
||||
# approach 3
|
||||
# EE.weight = Parameter(torch.tensor(W),requires_grad=True)
|
||||
if weighted_pooling is None:
|
||||
v_W_l.append(None)
|
||||
else:
|
||||
v_W_l.append(torch.ones(n, dtype=torch.float32))
|
||||
emb_l.append(EE)
|
||||
return emb_l, v_W_l
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
m_spa=None,
|
||||
ln_emb=None,
|
||||
ln_bot=None,
|
||||
ln_top=None,
|
||||
arch_interaction_op=None,
|
||||
arch_interaction_itself=False,
|
||||
sigmoid_bot=-1,
|
||||
sigmoid_top=-1,
|
||||
weighted_pooling=None,
|
||||
):
|
||||
super(DLRM_Net, self).__init__()
|
||||
|
||||
if (
|
||||
(m_spa is not None)
|
||||
and (ln_emb is not None)
|
||||
and (ln_bot is not None)
|
||||
and (ln_top is not None)
|
||||
and (arch_interaction_op is not None)
|
||||
):
|
||||
|
||||
# save arguments
|
||||
self.output_d = 0
|
||||
self.arch_interaction_op = arch_interaction_op
|
||||
self.arch_interaction_itself = arch_interaction_itself
|
||||
if weighted_pooling is not None and weighted_pooling != "fixed":
|
||||
self.weighted_pooling = "learned"
|
||||
else:
|
||||
self.weighted_pooling = weighted_pooling
|
||||
|
||||
# create operators
|
||||
self.emb_l, w_list = self.create_emb(
|
||||
m_spa, ln_emb, weighted_pooling
|
||||
)
|
||||
if self.weighted_pooling == "learned":
|
||||
self.v_W_l = nn.ParameterList()
|
||||
for w in w_list:
|
||||
self.v_W_l.append(nn.Parameter(w))
|
||||
else:
|
||||
self.v_W_l = w_list
|
||||
self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
|
||||
self.top_l = self.create_mlp(ln_top, sigmoid_top)
|
||||
|
||||
def apply_mlp(self, x, layers):
|
||||
return layers(x)
|
||||
|
||||
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
|
||||
# WARNING: notice that we are processing the batch at once. We implicitly
|
||||
# assume that the data is laid out such that:
|
||||
# 1. each embedding is indexed with a group of sparse indices,
|
||||
# corresponding to a single lookup
|
||||
# 2. for each embedding the lookups are further organized into a batch
|
||||
# 3. for a list of embedding tables there is a list of batched lookups
|
||||
# TORCH-MLIR
|
||||
# We are passing all the embeddings as arguments for easy parsing.
|
||||
|
||||
ly = []
|
||||
for k, sparse_index_group_batch in enumerate(lS_i):
|
||||
sparse_offset_group_batch = lS_o[k]
|
||||
|
||||
# embedding lookup
|
||||
# We are using EmbeddingBag, which implicitly uses sum operator.
|
||||
# The embeddings are represented as tall matrices, with sum
|
||||
# happening vertically across 0 axis, resulting in a row vector
|
||||
# E = emb_l[k]
|
||||
|
||||
if v_W_l[k] is not None:
|
||||
per_sample_weights = v_W_l[k].gather(
|
||||
0, sparse_index_group_batch
|
||||
)
|
||||
else:
|
||||
per_sample_weights = None
|
||||
|
||||
E = emb_l[k]
|
||||
V = E(
|
||||
sparse_index_group_batch,
|
||||
sparse_offset_group_batch,
|
||||
per_sample_weights=per_sample_weights,
|
||||
)
|
||||
|
||||
ly.append(V)
|
||||
|
||||
return ly
|
||||
|
||||
def interact_features(self, x, ly):
|
||||
|
||||
if self.arch_interaction_op == "dot":
|
||||
# concatenate dense and sparse features
|
||||
(batch_size, d) = x.shape
|
||||
T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
|
||||
# perform a dot product
|
||||
Z = torch.bmm(T, torch.transpose(T, 1, 2))
|
||||
# append dense feature with the interactions (into a row vector)
|
||||
# approach 1: all
|
||||
# Zflat = Z.view((batch_size, -1))
|
||||
# approach 2: unique
|
||||
_, ni, nj = Z.shape
|
||||
# approach 1: tril_indices
|
||||
# offset = 0 if self.arch_interaction_itself else -1
|
||||
# li, lj = torch.tril_indices(ni, nj, offset=offset)
|
||||
# approach 2: custom
|
||||
offset = 1 if self.arch_interaction_itself else 0
|
||||
li = torch.tensor(
|
||||
[i for i in range(ni) for j in range(i + offset)]
|
||||
)
|
||||
lj = torch.tensor(
|
||||
[j for i in range(nj) for j in range(i + offset)]
|
||||
)
|
||||
Zflat = Z[:, li, lj]
|
||||
# concatenate dense features and interactions
|
||||
R = torch.cat([x] + [Zflat], dim=1)
|
||||
elif self.arch_interaction_op == "cat":
|
||||
# concatenation features (into a row vector)
|
||||
R = torch.cat([x] + ly, dim=1)
|
||||
else:
|
||||
sys.exit(
|
||||
"ERROR: --arch-interaction-op="
|
||||
+ self.arch_interaction_op
|
||||
+ " is not supported"
|
||||
)
|
||||
|
||||
return R
|
||||
|
||||
def forward(self, dense_x, lS_o, *lS_i):
|
||||
return self.sequential_forward(dense_x, lS_o, lS_i)
|
||||
|
||||
def sequential_forward(self, dense_x, lS_o, lS_i):
|
||||
# process dense features (using bottom mlp), resulting in a row vector
|
||||
x = self.apply_mlp(dense_x, self.bot_l)
|
||||
# debug prints
|
||||
# print("intermediate")
|
||||
# print(x.detach().cpu().numpy())
|
||||
|
||||
# process sparse features(using embeddings), resulting in a list of row vectors
|
||||
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
|
||||
# for y in ly:
|
||||
# print(y.detach().cpu().numpy())
|
||||
|
||||
# interact features (dense and sparse)
|
||||
z = self.interact_features(x, ly)
|
||||
# print(z.detach().cpu().numpy())
|
||||
|
||||
# obtain probability of a click (using top mlp)
|
||||
p = self.apply_mlp(z, self.top_l)
|
||||
|
||||
# # clamp output if needed
|
||||
# if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
|
||||
# z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
|
||||
# else:
|
||||
# z = p
|
||||
|
||||
return p
|
||||
|
||||
|
||||
def dash_separated_ints(value):
|
||||
vals = value.split("-")
|
||||
for val in vals:
|
||||
try:
|
||||
int(val)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"%s is not a valid dash separated list of ints" % value
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# model related parameters
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train Deep Learning Recommendation Model (DLRM)"
|
||||
)
|
||||
parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
|
||||
parser.add_argument(
|
||||
"--arch-embedding-size", type=dash_separated_ints, default="4-3-2"
|
||||
)
|
||||
# j will be replaced with the table number
|
||||
parser.add_argument(
|
||||
"--arch-mlp-bot", type=dash_separated_ints, default="4-3-2"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-mlp-top", type=dash_separated_ints, default="8-2-1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-interaction-itself", action="store_true", default=False
|
||||
)
|
||||
parser.add_argument("--weighted-pooling", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
|
||||
ln_top = np.fromstring(args.arch_mlp_top, dtype=int, sep="-")
|
||||
m_den = ln_bot[0]
|
||||
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
|
||||
m_spa = args.arch_sparse_feature_size
|
||||
ln_emb = np.asarray(ln_emb)
|
||||
num_fea = ln_emb.size + 1 # num sparse + num dense features
|
||||
|
||||
|
||||
# Initialize the model.
|
||||
dlrm_model = DLRM_Net(
|
||||
m_spa=m_spa,
|
||||
ln_emb=ln_emb,
|
||||
ln_bot=ln_bot,
|
||||
ln_top=ln_top,
|
||||
arch_interaction_op=args.arch_interaction_op,
|
||||
)
|
||||
|
||||
|
||||
# Inputs to the model.
|
||||
dense_inp = torch.tensor([[0.6965, 0.2861, 0.2269, 0.5513]])
|
||||
vs0 = torch.tensor([[0], [0], [0]], dtype=torch.int64)
|
||||
vsi = torch.tensor([1, 2, 3]), torch.tensor([1]), torch.tensor([1])
|
||||
|
||||
input_dlrm = (dense_inp, vs0, *vsi)
|
||||
|
||||
golden_output = dlrm_model(dense_inp, vs0, *vsi)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
dlrm_model,
|
||||
input_dlrm,
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(input_dlrm)
|
||||
np.testing.assert_allclose(
|
||||
golden_output.detach().numpy(), result, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
|
||||
|
||||
# Verified via torch-mlir.
|
||||
# import torch_mlir
|
||||
# from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
|
||||
# module = torch_mlir.compile(
|
||||
# dlrm_model, inputs, use_tracing=True, output_type="linalg-on-tensors"
|
||||
# )
|
||||
# backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
# compiled = backend.compile(module)
|
||||
# jit_module = backend.load(compiled)
|
||||
|
||||
# dense_numpy = dense_inp.numpy()
|
||||
# vs0_numpy = vs0.numpy()
|
||||
# vsi_numpy = [inp.numpy() for inp in vsi]
|
||||
|
||||
# numpy_inp = (dense_numpy, vs0_numpy, *vsi_numpy)
|
||||
|
||||
# print(jit_module.forward(*numpy_inp))
|
||||
@@ -1,314 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchrec.datasets.utils import Batch
|
||||
from torchrec.modules.crossnet import LowRankCrossNet
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from torchrec.models.dlrm import (
|
||||
choose,
|
||||
DenseArch,
|
||||
DLRM,
|
||||
InteractionArch,
|
||||
SparseArch,
|
||||
OverArch,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
def calculate_offsets(tensor_list, prev_values, prev_offsets):
|
||||
offset_init = 0
|
||||
offset_list = []
|
||||
values_list = []
|
||||
|
||||
if prev_offsets != None:
|
||||
offset_init = prev_values.shape[-1]
|
||||
for tensor in tensor_list:
|
||||
offset_list.append(offset_init)
|
||||
offset_init += tensor.shape[0]
|
||||
|
||||
concatendated_tensor_list = torch.cat(tensor_list)
|
||||
|
||||
if prev_values != None:
|
||||
concatendated_tensor_list = torch.cat(
|
||||
[prev_values, concatendated_tensor_list]
|
||||
)
|
||||
|
||||
concatenated_offsets = torch.tensor(offset_list)
|
||||
|
||||
if prev_offsets != None:
|
||||
concatenated_offsets = torch.cat([prev_offsets, concatenated_offsets])
|
||||
|
||||
return concatendated_tensor_list, concatenated_offsets
|
||||
|
||||
|
||||
# Have to make combined_keys as dict as to which embedding bags they
|
||||
# point to. {f1: 0, f3: 0, f2: 1}
|
||||
# The result will be a triple containing values, indices and pointer tensor.
|
||||
def to_list(key_jagged, combined_keys):
|
||||
key_jagged_dict = key_jagged.to_dict()
|
||||
combined_list = []
|
||||
|
||||
for key in combined_keys:
|
||||
prev_values, prev_offsets = calculate_offsets(
|
||||
key_jagged_dict[key].to_dense(), None, None
|
||||
)
|
||||
print(prev_values)
|
||||
print(prev_offsets)
|
||||
combined_list.append(prev_values)
|
||||
combined_list.append(prev_offsets)
|
||||
combined_list.append(torch.tensor(combined_keys[key]))
|
||||
|
||||
return combined_list
|
||||
|
||||
|
||||
class SparseArchShark(nn.Module):
|
||||
def create_emb(self, embedding_dim, num_embeddings_list):
|
||||
embedding_list = nn.ModuleList()
|
||||
for i in range(0, num_embeddings_list.size):
|
||||
num_embeddings = num_embeddings_list[i]
|
||||
EE = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
|
||||
W = np.random.uniform(
|
||||
low=-np.sqrt(1 / num_embeddings),
|
||||
high=np.sqrt(1 / num_embeddings),
|
||||
size=(num_embeddings, embedding_dim),
|
||||
).astype(np.float32)
|
||||
EE.weight.data = torch.tensor(W, requires_grad=True)
|
||||
embedding_list.append(EE)
|
||||
return embedding_list
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
total_features,
|
||||
num_embeddings_list,
|
||||
):
|
||||
super(SparseArchShark, self).__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_features = total_features
|
||||
self.embedding_list = self.create_emb(
|
||||
embedding_dim, num_embeddings_list
|
||||
)
|
||||
|
||||
def forward(self, *batched_inputs):
|
||||
|
||||
concatenated_list = []
|
||||
input_enum, embedding_enum = 0, 0
|
||||
|
||||
for k in range(len(batched_inputs) // 3):
|
||||
values = batched_inputs[input_enum]
|
||||
input_enum += 1
|
||||
offsets = batched_inputs[input_enum]
|
||||
input_enum += 1
|
||||
embedding_pointer = int(batched_inputs[input_enum])
|
||||
input_enum += 1
|
||||
|
||||
E = self.embedding_list[embedding_pointer]
|
||||
V = E(values, offsets)
|
||||
concatenated_list.append(V)
|
||||
|
||||
return torch.cat(concatenated_list, dim=1).reshape(
|
||||
-1, self.num_features, self.embedding_dim
|
||||
)
|
||||
|
||||
|
||||
def test_sparse_arch() -> None:
|
||||
|
||||
D = 3
|
||||
eb1_config = EmbeddingBagConfig(
|
||||
name="t1",
|
||||
embedding_dim=D,
|
||||
num_embeddings=10,
|
||||
feature_names=["f1", "f3"],
|
||||
)
|
||||
eb2_config = EmbeddingBagConfig(
|
||||
name="t2",
|
||||
embedding_dim=D,
|
||||
num_embeddings=10,
|
||||
feature_names=["f2"],
|
||||
)
|
||||
|
||||
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||
|
||||
w1 = ebc.embedding_bags["t1"].weight
|
||||
w2 = ebc.embedding_bags["t2"].weight
|
||||
|
||||
sparse_arch = SparseArch(ebc)
|
||||
|
||||
keys = ["f1", "f2", "f3", "f4", "f5"]
|
||||
offsets = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19])
|
||||
features = KeyedJaggedTensor.from_offsets_sync(
|
||||
keys=keys,
|
||||
values=torch.tensor(
|
||||
[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]
|
||||
),
|
||||
offsets=offsets,
|
||||
)
|
||||
sparse_archi = SparseArchShark(D, 3, np.array([10, 10]))
|
||||
sparse_archi.embedding_list[0].weight = w1
|
||||
sparse_archi.embedding_list[1].weight = w2
|
||||
inputs = to_list(features, {"f1": 0, "f3": 0, "f2": 1})
|
||||
|
||||
test_results = sparse_archi(*inputs)
|
||||
sparse_features = sparse_arch(features)
|
||||
|
||||
torch.allclose(
|
||||
sparse_features,
|
||||
test_results,
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
|
||||
test_sparse_arch()
|
||||
|
||||
|
||||
class DLRMShark(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
total_features,
|
||||
num_embeddings_list,
|
||||
dense_in_features: int,
|
||||
dense_arch_layer_sizes: List[int],
|
||||
over_arch_layer_sizes: List[int],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.sparse_arch: SparseArchShark = SparseArchShark(
|
||||
embedding_dim, total_features, num_embeddings_list
|
||||
)
|
||||
num_sparse_features: int = total_features
|
||||
|
||||
self.dense_arch = DenseArch(
|
||||
in_features=dense_in_features,
|
||||
layer_sizes=dense_arch_layer_sizes,
|
||||
)
|
||||
|
||||
self.inter_arch = InteractionArch(
|
||||
num_sparse_features=num_sparse_features,
|
||||
)
|
||||
|
||||
over_in_features: int = (
|
||||
embedding_dim
|
||||
+ choose(num_sparse_features, 2)
|
||||
+ num_sparse_features
|
||||
)
|
||||
|
||||
self.over_arch = OverArch(
|
||||
in_features=over_in_features,
|
||||
layer_sizes=over_arch_layer_sizes,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, dense_features: torch.Tensor, *sparse_features
|
||||
) -> torch.Tensor:
|
||||
|
||||
embedded_dense = self.dense_arch(dense_features)
|
||||
embedded_sparse = self.sparse_arch(*sparse_features)
|
||||
concatenated_dense = self.inter_arch(
|
||||
dense_features=embedded_dense, sparse_features=embedded_sparse
|
||||
)
|
||||
logits = self.over_arch(concatenated_dense)
|
||||
return logits
|
||||
|
||||
|
||||
def test_dlrm() -> None:
|
||||
B = 2
|
||||
D = 8
|
||||
dense_in_features = 100
|
||||
|
||||
eb1_config = EmbeddingBagConfig(
|
||||
name="t1",
|
||||
embedding_dim=D,
|
||||
num_embeddings=100,
|
||||
feature_names=["f1", "f3"],
|
||||
)
|
||||
eb2_config = EmbeddingBagConfig(
|
||||
name="t2",
|
||||
embedding_dim=D,
|
||||
num_embeddings=100,
|
||||
feature_names=["f2"],
|
||||
)
|
||||
|
||||
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||
|
||||
sparse_features = KeyedJaggedTensor.from_offsets_sync(
|
||||
keys=["f1", "f3", "f2"],
|
||||
values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
|
||||
offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
|
||||
)
|
||||
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||
sparse_nn = DLRM(
|
||||
embedding_bag_collection=ebc,
|
||||
dense_in_features=dense_in_features,
|
||||
dense_arch_layer_sizes=[20, D],
|
||||
over_arch_layer_sizes=[5, 1],
|
||||
)
|
||||
sparse_nn_nod = DLRMShark(
|
||||
embedding_dim=8,
|
||||
total_features=3,
|
||||
num_embeddings_list=np.array([100, 100]),
|
||||
dense_in_features=dense_in_features,
|
||||
dense_arch_layer_sizes=[20, D],
|
||||
over_arch_layer_sizes=[5, 1],
|
||||
)
|
||||
|
||||
dense_features = torch.rand((B, dense_in_features))
|
||||
|
||||
x = to_list(sparse_features, {"f1": 0, "f3": 0, "f2": 1})
|
||||
|
||||
w1 = ebc.embedding_bags["t1"].weight
|
||||
w2 = ebc.embedding_bags["t2"].weight
|
||||
|
||||
sparse_nn_nod.sparse_arch.embedding_list[0].weight = w1
|
||||
sparse_nn_nod.sparse_arch.embedding_list[1].weight = w2
|
||||
|
||||
sparse_nn_nod.dense_arch.load_state_dict(sparse_nn.dense_arch.state_dict())
|
||||
sparse_nn_nod.inter_arch.load_state_dict(sparse_nn.inter_arch.state_dict())
|
||||
sparse_nn_nod.over_arch.load_state_dict(sparse_nn.over_arch.state_dict())
|
||||
|
||||
logits = sparse_nn(
|
||||
dense_features=dense_features,
|
||||
sparse_features=sparse_features,
|
||||
)
|
||||
logits_nod = sparse_nn_nod(dense_features, *x)
|
||||
|
||||
# print(logits)
|
||||
# print(logits_nod)
|
||||
|
||||
# Import the module and print.
|
||||
mlir_importer = SharkImporter(
|
||||
sparse_nn_nod,
|
||||
(dense_features, *x),
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
|
||||
torch.allclose(
|
||||
logits,
|
||||
logits_nod,
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
|
||||
test_dlrm()
|
||||
@@ -1,272 +0,0 @@
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
from tqdm.auto import tqdm
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import numpy as np
|
||||
|
||||
# pip install diffusers
|
||||
# pip install scipy
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a photograph of an astronaut riding a horse",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument("--steps", type=int, default=10, help="the device to use")
|
||||
p.add_argument("--mlir_loc", type=str, default=None, help="the device to use")
|
||||
p.add_argument("--vae_loc", type=str, default=None, help="the device to use")
|
||||
args = p.parse_args()
|
||||
|
||||
#####################################################
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc == None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None, extra_args=[]):
|
||||
|
||||
module = load_mlir(mlir_loc)
|
||||
if mlir_loc == None:
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(*inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
inputs,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile(extra_args)
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
# 1. Load the autoencoder model which will be used to decode the latents into image space.
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.vae.decode(input, return_dict=False)[0]
|
||||
|
||||
vae = VaeModel()
|
||||
vae_input = torch.rand(1, 4, 64, 64)
|
||||
shark_vae = compile_through_fx(vae, (vae_input,), args.vae_loc)
|
||||
|
||||
# Wrap the unet model to return tuples.
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
# 3. The UNet model for generating the latents.
|
||||
unet = UnetModel()
|
||||
latent_model_input = torch.rand([2, 4, 64, 64])
|
||||
text_embeddings = torch.rand([2, 77, 768])
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings),
|
||||
args.mlir_loc,
|
||||
["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
|
||||
)
|
||||
|
||||
# torch.jit.script(unet)
|
||||
|
||||
scheduler = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
prompt = [args.prompt]
|
||||
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
|
||||
num_inference_steps = args.steps # Number of denoising steps
|
||||
|
||||
guidance_scale = 7.5 # Scale for classifier-free guidance
|
||||
|
||||
generator = torch.manual_seed(
|
||||
42
|
||||
) # Seed generator to create the inital latent noise
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_embeddings = text_encoder(text_input.input_ids)[0]
|
||||
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
[""] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = text_encoder(uncond_input.input_ids)[0]
|
||||
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, unet.in_channels, height // 8, width // 8),
|
||||
generator=generator,
|
||||
)
|
||||
# latents = latents.to(torch_device)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
latents = latents * scheduler.sigmas[0]
|
||||
# print(latents, latents.shape)
|
||||
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
|
||||
print(f"i = {i} t = {t}")
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
sigma = scheduler.sigmas[i]
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
# with torch.no_grad():
|
||||
# noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
|
||||
|
||||
latent_model_input_numpy = latent_model_input.detach().numpy()
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
|
||||
noise_pred = shark_unet.forward(
|
||||
(
|
||||
latent_model_input_numpy,
|
||||
np.array([t]).astype(np.float32),
|
||||
text_embeddings_numpy,
|
||||
)
|
||||
)
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
|
||||
# print("Latents shape : ", latents.shape)
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents.detach().numpy()
|
||||
image = shark_vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
pil_images[0].save("astro.jpg")
|
||||
@@ -1,280 +0,0 @@
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
from tqdm.auto import tqdm
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import numpy as np
|
||||
|
||||
# pip install diffusers
|
||||
# pip install scipy
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a photograph of an astronaut riding a horse",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument("--steps", type=int, default=50, help="the device to use")
|
||||
p.add_argument("--mlir_loc", type=str, default=None, help="the device to use")
|
||||
p.add_argument("--vae_loc", type=str, default=None, help="the device to use")
|
||||
args = p.parse_args()
|
||||
|
||||
#####################################################
|
||||
|
||||
|
||||
def fp16_unet():
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"stable_diff_f16_18_OCT",
|
||||
tank_url="gs://shark_tank/prashant_nod",
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
return shark_module
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc == None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
|
||||
module = load_mlir(mlir_loc)
|
||||
if mlir_loc == None:
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(*inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
inputs,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
# 1. Load the autoencoder model which will be used to decode the latents into image space.
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.vae.decode(input, return_dict=False)[0]
|
||||
|
||||
vae = VaeModel()
|
||||
vae_input = torch.rand(1, 4, 64, 64)
|
||||
shark_vae = compile_through_fx(vae, (vae_input,), args.vae_loc)
|
||||
|
||||
# Wrap the unet model to return tuples.
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
# # 3. The UNet model for generating the latents.
|
||||
unet = UnetModel()
|
||||
|
||||
shark_unet = fp16_unet()
|
||||
|
||||
scheduler = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
prompt = [args.prompt]
|
||||
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
|
||||
num_inference_steps = args.steps # Number of denoising steps
|
||||
|
||||
guidance_scale = 7.5 # Scale for classifier-free guidance
|
||||
|
||||
generator = torch.manual_seed(
|
||||
42
|
||||
) # Seed generator to create the inital latent noise
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_embeddings = text_encoder(text_input.input_ids)[0]
|
||||
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
[""] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = text_encoder(uncond_input.input_ids)[0]
|
||||
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, unet.in_channels, height // 8, width // 8),
|
||||
generator=generator,
|
||||
)
|
||||
# latents = latents.to(torch_device)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
latents = latents * scheduler.sigmas[0]
|
||||
# print(latents, latents.shape)
|
||||
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
|
||||
print(f"i = {i} t = {t}")
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
sigma = scheduler.sigmas[i]
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
# with torch.no_grad():
|
||||
# noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
|
||||
|
||||
latent_model_input_numpy = (
|
||||
latent_model_input.detach().numpy().astype(np.half)
|
||||
)
|
||||
text_embeddings_numpy = (
|
||||
text_embeddings.detach().numpy().astype(np.half)
|
||||
)
|
||||
|
||||
noise_pred = shark_unet.forward(
|
||||
(
|
||||
latent_model_input_numpy,
|
||||
np.array([t]).astype(np.half),
|
||||
text_embeddings_numpy,
|
||||
)
|
||||
)
|
||||
noise_pred = torch.from_numpy(noise_pred).to(torch.float32)
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
|
||||
# print("Latents shape : ", latents.shape)
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents.detach().numpy()
|
||||
image = shark_vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
pil_images[0].save("astro.jpg")
|
||||
@@ -1,313 +0,0 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from keras_cv.models.generative.stable_diffusion.clip_tokenizer import (
|
||||
SimpleTokenizer,
|
||||
)
|
||||
from keras_cv.models.generative.stable_diffusion.constants import (
|
||||
_ALPHAS_CUMPROD,
|
||||
)
|
||||
from keras_cv.models.generative.stable_diffusion.constants import (
|
||||
_UNCONDITIONAL_TOKENS,
|
||||
)
|
||||
from keras_cv.models.generative.stable_diffusion.decoder import Decoder
|
||||
from keras_cv.models.generative.stable_diffusion.text_encoder import (
|
||||
TextEncoder,
|
||||
)
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
from PIL import Image
|
||||
|
||||
# pip install "git+https://github.com/keras-team/keras-cv.git"
|
||||
# pip install tensorflow_dataset
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a photograph of an astronaut riding a horse",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument(
|
||||
"--steps", type=int, default=10, help="the number of steps to use"
|
||||
)
|
||||
p.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="the file to save the resulting image to. (default to <input prompt>.jpg)",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
#####################################################
|
||||
|
||||
MAX_PROMPT_LENGTH = 77
|
||||
|
||||
|
||||
class SharkStableDiffusion:
|
||||
"""Shark implementation of Stable Diffusion based on model from keras_cv.
|
||||
Stable Diffusion is a powerful image generation model that can be used,
|
||||
among other things, to generate pictures according to a short text description
|
||||
(called a "prompt").
|
||||
Arguments:
|
||||
device: Device to use with SHARK. Default: cpu
|
||||
jit_compile: Whether to compile the underlying models to XLA.
|
||||
This can lead to a significant speedup on some systems. Default: False.
|
||||
References:
|
||||
- [About Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement)
|
||||
- [Original implementation](https://github.com/CompVis/stable-diffusion)
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", jit_compile=True):
|
||||
self.img_height = 512
|
||||
self.img_width = 512
|
||||
self.tokenizer = SimpleTokenizer()
|
||||
|
||||
# Create models
|
||||
self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"stable_diff", tank_url="gs://shark_tank/quinn", frontend="tf"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
self.diffusion_model = shark_module
|
||||
self.decoder = Decoder(self.img_height, self.img_width)
|
||||
if jit_compile:
|
||||
self.text_encoder.compile(jit_compile=True)
|
||||
self.decoder.compile(jit_compile=True)
|
||||
|
||||
print(
|
||||
"By using this model checkpoint, you acknowledge that its usage is "
|
||||
"subject to the terms of the CreativeML Open RAIL-M license at "
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE"
|
||||
)
|
||||
# Load weights
|
||||
text_encoder_weights_fpath = keras.utils.get_file(
|
||||
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
|
||||
file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
|
||||
)
|
||||
decoder_weights_fpath = keras.utils.get_file(
|
||||
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
|
||||
file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
|
||||
)
|
||||
self.text_encoder.load_weights(text_encoder_weights_fpath)
|
||||
self.decoder.load_weights(decoder_weights_fpath)
|
||||
|
||||
def text_to_image(
|
||||
self,
|
||||
prompt,
|
||||
batch_size=1,
|
||||
num_steps=25,
|
||||
unconditional_guidance_scale=7.5,
|
||||
seed=None,
|
||||
):
|
||||
encoded_text = self.encode_text(prompt)
|
||||
|
||||
return self.generate_image(
|
||||
encoded_text,
|
||||
batch_size=batch_size,
|
||||
num_steps=num_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
def encode_text(self, prompt):
|
||||
"""Encodes a prompt into a latent text encoding.
|
||||
The encoding produced by this method should be used as the
|
||||
`encoded_text` parameter of `StableDiffusion.generate_image`. Encoding
|
||||
text separately from generating an image can be used to arbitrarily
|
||||
modify the text encoding priot to image generation, e.g. for walking
|
||||
between two prompts.
|
||||
Args:
|
||||
prompt: a string to encode, must be 77 tokens or shorter.
|
||||
Example:
|
||||
```python
|
||||
from keras_cv.models import StableDiffusion
|
||||
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
|
||||
encoded_text = model.encode_text("Tacos at dawn")
|
||||
img = model.generate_image(encoded_text)
|
||||
```
|
||||
"""
|
||||
# Tokenize prompt (i.e. starting context)
|
||||
inputs = self.tokenizer.encode(prompt)
|
||||
if len(inputs) > MAX_PROMPT_LENGTH:
|
||||
raise ValueError(
|
||||
f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)"
|
||||
)
|
||||
phrase = inputs + [49407] * (MAX_PROMPT_LENGTH - len(inputs))
|
||||
phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
|
||||
|
||||
context = self.text_encoder.predict_on_batch(
|
||||
[phrase, self._get_pos_ids()]
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
encoded_text,
|
||||
batch_size=1,
|
||||
num_steps=25,
|
||||
unconditional_guidance_scale=7.5,
|
||||
diffusion_noise=None,
|
||||
seed=None,
|
||||
):
|
||||
"""Generates an image based on encoded text.
|
||||
The encoding passed to this method should be derived from
|
||||
`StableDiffusion.encode_text`.
|
||||
Args:
|
||||
encoded_text: Tensor of shape (`batch_size`, 77, 768), or a Tensor
|
||||
of shape (77, 768). When the batch axis is omitted, the same encoded
|
||||
text will be used to produce every generated image.
|
||||
batch_size: number of images to generate. Default: 1.
|
||||
num_steps: number of diffusion steps (controls image quality).
|
||||
Default: 25.
|
||||
unconditional_guidance_scale: float controling how closely the image
|
||||
should adhere to the prompt. Larger values result in more
|
||||
closely adhering to the prompt, but will make the image noisier.
|
||||
Default: 7.5.
|
||||
diffusion_noise: Tensor of shape (`batch_size`, img_height // 8,
|
||||
img_width // 8, 4), or a Tensor of shape (img_height // 8,
|
||||
img_width // 8, 4). Optional custom noise to seed the diffusion
|
||||
process. When the batch axis is omitted, the same noise will be
|
||||
used to seed diffusion for every generated image.
|
||||
seed: integer which is used to seed the random generation of
|
||||
diffusion noise, only to be specified if `diffusion_noise` is
|
||||
None.
|
||||
Example:
|
||||
```python
|
||||
from keras_cv.models import StableDiffusion
|
||||
batch_size = 8
|
||||
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
|
||||
e_tacos = model.encode_text("Tacos at dawn")
|
||||
e_watermelons = model.encode_text("Watermelons at dusk")
|
||||
e_interpolated = tf.linspace(e_tacos, e_watermelons, batch_size)
|
||||
images = model.generate_image(e_interpolated, batch_size=batch_size)
|
||||
```
|
||||
"""
|
||||
if diffusion_noise is not None and seed is not None:
|
||||
raise ValueError(
|
||||
"`diffusion_noise` and `seed` should not both be passed to "
|
||||
"`generate_image`. `seed` is only used to generate diffusion "
|
||||
"noise when it's not already user-specified."
|
||||
)
|
||||
|
||||
encoded_text = tf.squeeze(encoded_text)
|
||||
if encoded_text.shape.rank == 2:
|
||||
encoded_text = tf.repeat(
|
||||
tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
|
||||
)
|
||||
|
||||
context = encoded_text
|
||||
unconditional_context = tf.repeat(
|
||||
self._get_unconditional_context(), batch_size, axis=0
|
||||
)
|
||||
context = tf.concat([context, unconditional_context], 0)
|
||||
|
||||
if diffusion_noise is not None:
|
||||
diffusion_noise = tf.squeeze(diffusion_noise)
|
||||
if diffusion_noise.shape.rank == 3:
|
||||
diffusion_noise = tf.repeat(
|
||||
tf.expand_dims(diffusion_noise, axis=0), batch_size, axis=0
|
||||
)
|
||||
latent = diffusion_noise
|
||||
else:
|
||||
latent = self._get_initial_diffusion_noise(batch_size, seed)
|
||||
|
||||
# Iterative reverse diffusion stage
|
||||
timesteps = tf.range(1, 1000, 1000 // num_steps)
|
||||
alphas, alphas_prev = self._get_initial_alphas(timesteps)
|
||||
progbar = keras.utils.Progbar(len(timesteps))
|
||||
iteration = 0
|
||||
for index, timestep in list(enumerate(timesteps))[::-1]:
|
||||
latent_prev = latent # Set aside the previous latent vector
|
||||
t_emb = self._get_timestep_embedding(timestep, batch_size)
|
||||
|
||||
# Prepare the latent and unconditional latent to be run with a single forward call
|
||||
latent = tf.concat([latent, latent], 0)
|
||||
t_emb = tf.concat([t_emb, t_emb], 0)
|
||||
latent_numpy = self.diffusion_model.forward(
|
||||
[latent.numpy(), t_emb.numpy(), context.numpy()]
|
||||
)
|
||||
latent = tf.convert_to_tensor(latent_numpy, dtype=tf.float32)
|
||||
latent, unconditional_latent = tf.split(latent, 2)
|
||||
|
||||
latent = unconditional_latent + unconditional_guidance_scale * (
|
||||
latent - unconditional_latent
|
||||
)
|
||||
a_t, a_prev = alphas[index], alphas_prev[index]
|
||||
pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(
|
||||
a_t
|
||||
)
|
||||
latent = (
|
||||
latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
|
||||
)
|
||||
iteration += 1
|
||||
progbar.update(iteration)
|
||||
|
||||
# Decoding stage
|
||||
decoded = self.decoder.predict_on_batch(latent)
|
||||
decoded = ((decoded + 1) / 2) * 255
|
||||
return np.clip(decoded, 0, 255).astype("uint8")
|
||||
|
||||
def _get_unconditional_context(self):
|
||||
unconditional_tokens = tf.convert_to_tensor(
|
||||
[_UNCONDITIONAL_TOKENS], dtype=tf.int32
|
||||
)
|
||||
unconditional_context = self.text_encoder.predict_on_batch(
|
||||
[unconditional_tokens, self._get_pos_ids()]
|
||||
)
|
||||
|
||||
return unconditional_context
|
||||
|
||||
def _get_timestep_embedding(
|
||||
self, timestep, batch_size, dim=320, max_period=10000
|
||||
):
|
||||
half = dim // 2
|
||||
freqs = tf.math.exp(
|
||||
-math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
|
||||
)
|
||||
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
|
||||
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
|
||||
embedding = tf.reshape(embedding, [1, -1])
|
||||
return tf.repeat(embedding, batch_size, axis=0)
|
||||
|
||||
def _get_initial_alphas(self, timesteps):
|
||||
alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
|
||||
alphas_prev = [1.0] + alphas[:-1]
|
||||
|
||||
return alphas, alphas_prev
|
||||
|
||||
def _get_initial_diffusion_noise(self, batch_size, seed):
|
||||
return tf.random.normal(
|
||||
(batch_size, self.img_height // 8, self.img_width // 8, 4),
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_pos_ids():
|
||||
return tf.convert_to_tensor(
|
||||
[list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SD = SharkStableDiffusion(device=args.device)
|
||||
images = SD.text_to_image(args.prompt, num_steps=args.steps)
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
save_fname = args.prompt + ".jpg"
|
||||
if args.save_path is not None:
|
||||
save_fname = args.save_path
|
||||
pil_images[0].save(save_fname)
|
||||
@@ -1,2 +0,0 @@
|
||||
*.vmfb
|
||||
*.jpg
|
||||
@@ -1,56 +0,0 @@
|
||||
# STABLE DIFFUSION
|
||||
|
||||
## Installation
|
||||
|
||||
Follow setup instructions in the main [README.md](https://github.com/nod-ai/SHARK#readme) for regular usage.
|
||||
|
||||
## Debug commands and other advanced usage follows.
|
||||
|
||||
```shell
|
||||
python main.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
|
||||
|
||||
```
|
||||
|
||||
## dump all dispatch .spv and isa using amdllpc
|
||||
|
||||
```shell
|
||||
python main.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
|
||||
```
|
||||
|
||||
## Compile and save the .vmfb (using vulkan fp16 as an example):
|
||||
|
||||
```shell
|
||||
python shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
|
||||
```
|
||||
|
||||
## Capture an RGP trace
|
||||
|
||||
```shell
|
||||
python shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
|
||||
```
|
||||
|
||||
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
|
||||
|
||||
```shell
|
||||
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --device=vulkan --function_input=1x4x64x64xf16
|
||||
```
|
||||
|
||||
## Run the unet module with iree-benchmark-module (same config as above):
|
||||
```shell
|
||||
##if you want to use .npz inputs:
|
||||
unzip ~/.local/shark_tank/<your unet>/inputs.npz
|
||||
|
||||
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --function_input=@arr_0.npy --function_input=1xf16 --function_input=@arr_2.npy --function_input=@arr_3.npy --function_input=@arr_4.npy
|
||||
```
|
||||
|
||||
## Using other supported Stable Diffusion variants with SHARK:
|
||||
|
||||
Currently we support the following fine-tuned versions of Stable Diffusion:
|
||||
- [AnythingV3](https://huggingface.co/Linaqruf/anything-v3.0)
|
||||
- [Analog Diffusion](https://huggingface.co/wavymulder/Analog-Diffusion)
|
||||
|
||||
use the flag `--variant=` to specify the model to be used.
|
||||
|
||||
```shell
|
||||
python .\shark\examples\shark_inference\stable_diffusion\main.py --variant=anythingv3 --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden"
|
||||
```
|
||||
@@ -1,25 +0,0 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
text=["a photo of a cat", "a photo of a dog"],
|
||||
images=image,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
|
||||
outputs = model(**inputs)
|
||||
logits_per_image = (
|
||||
outputs.logits_per_image
|
||||
) # this is the image-text similarity score
|
||||
probs = logits_per_image.softmax(
|
||||
dim=1
|
||||
) # we can take the softmax to get the label probabilities
|
||||
@@ -1,254 +0,0 @@
|
||||
import os
|
||||
|
||||
os.environ["AMD_ENABLE_LLPC"] = "1"
|
||||
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import torch
|
||||
from PIL import Image
|
||||
import torchvision.transforms as T
|
||||
from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from stable_args import args
|
||||
|
||||
# This has to come before importing cache objects
|
||||
if args.clear_all:
|
||||
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
|
||||
from glob import glob
|
||||
import shutil
|
||||
|
||||
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
|
||||
for vmfb in vmfbs:
|
||||
if os.path.exists(vmfb):
|
||||
os.remove(vmfb)
|
||||
home = os.path.expanduser("~")
|
||||
if os.name == "nt": # Windows
|
||||
appdata = os.getenv("LOCALAPPDATA")
|
||||
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
|
||||
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
|
||||
elif os.name == "unix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
|
||||
|
||||
from utils import set_init_device_flags
|
||||
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
from schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
import time
|
||||
import sys
|
||||
from shark.iree_utils.compile_utils import dump_isas
|
||||
|
||||
# Helper function to profile the vulkan device.
|
||||
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
|
||||
if args.vulkan_debug_utils and "vulkan" in args.device:
|
||||
import iree
|
||||
|
||||
print(f"Profiling and saving to {file_path}.")
|
||||
vulkan_device = iree.runtime.get_device(args.device)
|
||||
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
|
||||
return vulkan_device
|
||||
return None
|
||||
|
||||
|
||||
def end_profiling(device):
|
||||
if device:
|
||||
return device.end_profiling()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
|
||||
prompt = args.prompts
|
||||
neg_prompt = args.negative_prompts
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
if args.version == "v2_1":
|
||||
height = 768
|
||||
width = 768
|
||||
|
||||
num_inference_steps = args.steps # Number of denoising steps
|
||||
|
||||
# Scale for classifier-free guidance
|
||||
guidance_scale = torch.tensor(args.guidance_scale).to(torch.float32)
|
||||
|
||||
# Handle out of range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
seed = args.seed
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(
|
||||
seed
|
||||
) # Seed generator to create the inital latent noise
|
||||
|
||||
# TODO: Add support for batch_size > 1.
|
||||
batch_size = len(prompt)
|
||||
if batch_size != 1:
|
||||
sys.exit("More than one prompt is not supported yet.")
|
||||
if batch_size != len(neg_prompt):
|
||||
sys.exit("prompts and negative prompts must be of same length")
|
||||
|
||||
set_init_device_flags()
|
||||
clip = get_clip()
|
||||
unet = get_unet()
|
||||
vae = get_vae()
|
||||
if args.dump_isa:
|
||||
dump_isas(args.dispatch_benchmarks_dir)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
cpu_scheduling = True
|
||||
if args.version == "v2_1":
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
|
||||
)
|
||||
|
||||
scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
if args.version == "v2_1base" and args.variant == "stablediffusion":
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
|
||||
)
|
||||
|
||||
if args.use_compiled_scheduler:
|
||||
scheduler = SharkEulerDiscreteScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
scheduler.compile()
|
||||
cpu_scheduling = False
|
||||
else:
|
||||
scheduler = EulerDiscreteScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
# create a random initial latent.
|
||||
latents = torch.randn(
|
||||
(batch_size, 4, height // 8, width // 8),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
# Warmup phase to improve performance.
|
||||
if args.warmup_count >= 1:
|
||||
vae_warmup_input = torch.clone(latents).detach().numpy()
|
||||
clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
|
||||
for i in range(args.warmup_count):
|
||||
vae("forward", (vae_warmup_input,))
|
||||
clip("forward", (clip_warmup_input,))
|
||||
|
||||
start = time.time()
|
||||
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=args.max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
neg_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
|
||||
|
||||
clip_inf_start = time.time()
|
||||
text_embeddings = clip("forward", (text_input,))
|
||||
clip_inf_end = time.time()
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler.is_scale_input_called = True
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
|
||||
avg_ms = 0
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps), disable=args.hide_steps):
|
||||
step_start = time.time()
|
||||
if not args.hide_steps:
|
||||
print(f"i = {i} t = {t}", end="")
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
if cpu_scheduling:
|
||||
latent_model_input = latent_model_input.detach().numpy()
|
||||
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
|
||||
noise_pred = unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
noise_pred = torch.from_numpy(noise_pred.to_host())
|
||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||
else:
|
||||
latents = scheduler.step(noise_pred, t, latents)
|
||||
step_time = time.time() - step_start
|
||||
avg_ms += step_time
|
||||
step_ms = int((step_time) * 1000)
|
||||
if not args.hide_steps:
|
||||
print(f" ({step_ms}ms)")
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
if args.use_base_vae:
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents
|
||||
if cpu_scheduling:
|
||||
latents_numpy = latents.detach().numpy()
|
||||
profile_device = start_profiling(file_path="vae.rdc")
|
||||
vae_start = time.time()
|
||||
images = vae("forward", (latents_numpy,))
|
||||
vae_end = time.time()
|
||||
end_profiling(profile_device)
|
||||
if args.use_base_vae:
|
||||
image = torch.from_numpy(images)
|
||||
image = (image.detach().cpu() * 255.0).numpy()
|
||||
images = image.round()
|
||||
end_time = time.time()
|
||||
|
||||
avg_ms = 1000 * avg_ms / args.steps
|
||||
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
|
||||
vae_inf_time = (vae_end - vae_start) * 1000
|
||||
total_time = end_time - start
|
||||
print(f"\nAverage step time: {avg_ms}ms/it")
|
||||
print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
|
||||
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
|
||||
print(f"\nTotal image generation time: {total_time}sec")
|
||||
|
||||
transform = T.ToPILImage()
|
||||
pil_images = [
|
||||
transform(image) for image in torch.from_numpy(images).to(torch.uint8)
|
||||
]
|
||||
for i in range(batch_size):
|
||||
pil_images[i].save(f"{args.prompts[i]}_{i}.jpg")
|
||||
@@ -1,285 +0,0 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from utils import compile_through_fx
|
||||
from stable_args import args
|
||||
import torch
|
||||
|
||||
model_config = {
|
||||
"v2_1": "stabilityai/stable-diffusion-2-1",
|
||||
"v2_1base": "stabilityai/stable-diffusion-2-1-base",
|
||||
"v1_4": "CompVis/stable-diffusion-v1-4",
|
||||
}
|
||||
|
||||
# clip has 2 variants of max length 77 or 64.
|
||||
model_clip_max_length = 64 if args.max_length == 64 else 77
|
||||
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
model_clip_max_length = 77
|
||||
elif args.variant == "openjourney":
|
||||
model_clip_max_length = 64
|
||||
|
||||
model_variant = {
|
||||
"stablediffusion": "SD",
|
||||
"anythingv3": "Linaqruf/anything-v3.0",
|
||||
"dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
|
||||
"openjourney": "prompthero/openjourney",
|
||||
"analogdiffusion": "wavymulder/Analog-Diffusion",
|
||||
}
|
||||
|
||||
model_input = {
|
||||
"v2_1": {
|
||||
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
|
||||
"vae": (torch.randn(1, 4, 96, 96),),
|
||||
"unet": (
|
||||
torch.randn(1, 4, 96, 96), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, model_clip_max_length, 1024), # embedding
|
||||
torch.tensor(1).to(torch.float32), # guidance_scale
|
||||
),
|
||||
},
|
||||
"v2_1base": {
|
||||
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
|
||||
"vae": (torch.randn(1, 4, 64, 64),),
|
||||
"unet": (
|
||||
torch.randn(1, 4, 64, 64), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, model_clip_max_length, 1024), # embedding
|
||||
torch.tensor(1).to(torch.float32), # guidance_scale
|
||||
),
|
||||
},
|
||||
"v1_4": {
|
||||
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
|
||||
"vae": (torch.randn(1, 4, 64, 64),),
|
||||
"unet": (
|
||||
torch.randn(1, 4, 64, 64),
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, model_clip_max_length, 768),
|
||||
torch.tensor(1).to(torch.float32),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# revision param for from_pretrained defaults to "main" => fp32
|
||||
model_revision = {
|
||||
"stablediffusion": "fp16" if args.precision == "fp16" else "main",
|
||||
"anythingv3": "diffusers",
|
||||
"analogdiffusion": "main",
|
||||
"openjourney": "main",
|
||||
"dreamlike": "main",
|
||||
}
|
||||
|
||||
|
||||
def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
if args.variant == "stablediffusion":
|
||||
if args.version != "v1_4":
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_config[args.version], subfolder="text_encoder"
|
||||
)
|
||||
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_variant[args.variant],
|
||||
subfolder="text_encoder",
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
model_input[args.version]["clip"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
|
||||
def get_base_vae_mlir(model_name="vae", extra_args=[]):
|
||||
class BaseVaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
subfolder="vae",
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
return (x / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
vae = BaseVaeModel()
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda()
|
||||
for inputs in model_input[args.version]["vae"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["vae"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["vae"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
subfolder="vae",
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
input = 1 / 0.18215 * input
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
x = (x / 2 + 0.5).clamp(0, 1)
|
||||
x = x * 255.0
|
||||
return x.round()
|
||||
|
||||
vae = VaeModel()
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda()
|
||||
for inputs in model_input[args.version]["vae"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["vae"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["vae"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
subfolder="unet",
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, latent, timestep, text_embedding, guidance_scale):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latents = torch.cat([latent] * 2)
|
||||
unet_out = self.unet.forward(
|
||||
latents, timestep, text_embedding, return_dict=False
|
||||
)[0]
|
||||
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel()
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input[args.version]["unet"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["unet"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input["v1_4"]["unet"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["unet"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} is not yet added")
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_unet
|
||||
@@ -1,99 +0,0 @@
|
||||
import sys
|
||||
from model_wrappers import (
|
||||
get_base_vae_mlir,
|
||||
get_vae_mlir,
|
||||
get_unet_mlir,
|
||||
get_clip_mlir,
|
||||
)
|
||||
from resources import models_db
|
||||
from stable_args import args
|
||||
from utils import get_shark_model
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
if BATCH_SIZE != 1:
|
||||
sys.exit("Only batch size 1 is supported.")
|
||||
|
||||
|
||||
def get_params(bucket_key, model_key, model, is_tuned, precision):
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
try:
|
||||
bucket = models_db[0][bucket_key]
|
||||
model_name = models_db[1][model_key]
|
||||
iree_flags += models_db[2][model][is_tuned][precision][
|
||||
"default_compilation_flags"
|
||||
]
|
||||
except KeyError:
|
||||
raise Exception(
|
||||
f"{bucket}/{model_key} is not present in the models database"
|
||||
)
|
||||
|
||||
if (
|
||||
"specified_compilation_flags"
|
||||
in models_db[2][model][is_tuned][precision]
|
||||
):
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else args.device.split("://")[0]
|
||||
)
|
||||
if (
|
||||
device
|
||||
not in models_db[2][model][is_tuned][precision][
|
||||
"specified_compilation_flags"
|
||||
]
|
||||
):
|
||||
device = "default_device"
|
||||
iree_flags += models_db[2][model][is_tuned][precision][
|
||||
"specified_compilation_flags"
|
||||
][device]
|
||||
|
||||
return bucket, model_name, iree_flags
|
||||
|
||||
|
||||
def get_unet():
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
bucket_key = f"{args.variant}/{is_tuned}"
|
||||
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "unet", is_tuned, args.precision
|
||||
)
|
||||
if not args.use_tuned and args.import_mlir:
|
||||
return get_unet_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae():
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
is_base = "/base" if args.use_base_vae else ""
|
||||
bucket_key = f"{args.variant}/{is_tuned}"
|
||||
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "vae", is_tuned, args.precision
|
||||
)
|
||||
if not args.use_tuned and args.import_mlir:
|
||||
if args.use_base_vae:
|
||||
return get_base_vae_mlir(model_name, iree_flags)
|
||||
return get_vae_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_clip():
|
||||
bucket_key = f"{args.variant}/untuned"
|
||||
model_key = f"{args.variant}/{args.version}/clip/fp32/length_{args.max_length}/untuned"
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "clip", "untuned", "fp32"
|
||||
)
|
||||
if args.import_mlir:
|
||||
return get_clip_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
@@ -1,44 +0,0 @@
|
||||
Compile / Run Instructions:
|
||||
|
||||
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
|
||||
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
|
||||
|
||||
Compile Commands FP32/FP16:
|
||||
|
||||
```shell
|
||||
Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use –iree-input-type=mhlo for tf models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
CPU:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
```
|
||||
|
||||
|
||||
|
||||
Run / Benchmark Command (FP32 - NCHW):
|
||||
(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.)
|
||||
|
||||
```shell
|
||||
## Vulkan AMD:
|
||||
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --device=vulkan --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
|
||||
|
||||
## CUDA:
|
||||
iree-benchmark-module --module_file=/path/to/vmfb --entry_function=forward --device=cuda --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
|
||||
|
||||
## CPU:
|
||||
iree-benchmark-module --module_file=/path/to/vmfb --entry_function=forward --device=local-task --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
|
||||
|
||||
```
|
||||
|
||||
Run via vulkan_gui for RGP Profiling:
|
||||
|
||||
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
|
||||
```shell
|
||||
./build/vulkan_gui/iree-vulkan-gui --module_file=/path/to/unet.vmfb --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
|
||||
```
|
||||
@@ -1,31 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
prompt_examples = []
|
||||
prompts_loc = resource_path("resources/prompts.json")
|
||||
if os.path.exists(prompts_loc):
|
||||
with open(prompts_loc, encoding="utf-8") as fopen:
|
||||
prompt_examples = json.load(fopen)
|
||||
|
||||
if not prompt_examples:
|
||||
print("Unable to fetch prompt examples.")
|
||||
|
||||
|
||||
models_db = []
|
||||
models_loc = resource_path("resources/model_db.json")
|
||||
if os.path.exists(models_loc):
|
||||
with open(models_loc, encoding="utf-8") as fopen:
|
||||
models_db = json.load(fopen)
|
||||
|
||||
if len(models_db) != 3:
|
||||
sys.exit("Error: Unable to load models database.")
|
||||
@@ -1,164 +0,0 @@
|
||||
[
|
||||
{
|
||||
"stablediffusion/untuned":"gs://shark_tank/stable_diffusion",
|
||||
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
|
||||
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
|
||||
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
|
||||
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
|
||||
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
|
||||
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
|
||||
"openjourney/tuned":"gs://shark_tank/sd_tuned",
|
||||
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
|
||||
},
|
||||
{
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
|
||||
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
|
||||
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
|
||||
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32",
|
||||
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64",
|
||||
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16",
|
||||
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae2_19dec_fp16",
|
||||
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
|
||||
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
|
||||
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
|
||||
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
|
||||
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
|
||||
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
|
||||
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
|
||||
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
|
||||
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
|
||||
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
|
||||
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
|
||||
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
|
||||
"openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
|
||||
"openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
|
||||
"openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
|
||||
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
|
||||
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
|
||||
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
|
||||
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
|
||||
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
|
||||
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
|
||||
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
|
||||
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
|
||||
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
|
||||
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
|
||||
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
|
||||
},
|
||||
{
|
||||
"unet": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": []
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": []
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32"
|
||||
],
|
||||
"specified_compilation_flags": {
|
||||
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
|
||||
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
|
||||
}
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform"
|
||||
]
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops"
|
||||
]
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -1,8 +0,0 @@
|
||||
[["A high tech solarpunk utopia in the Amazon rainforest"],
|
||||
["A pikachu fine dining with a view to the Eiffel Tower"],
|
||||
["A mecha robot in a favela in expressionist style"],
|
||||
["an insect robot preparing a delicious meal"],
|
||||
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
|
||||
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
|
||||
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
|
||||
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]
|
||||
@@ -1,133 +0,0 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from utils import compile_through_fx, get_shark_model
|
||||
from stable_args import args
|
||||
import torch
|
||||
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
|
||||
model_input = {
|
||||
"euler": {
|
||||
"latent": torch.randn(1, 4, 64, 64),
|
||||
"output": torch.randn(1, 4, 64, 64),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"dt": torch.tensor(1).to(torch.float32),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
):
|
||||
super().__init__(
|
||||
num_train_timesteps,
|
||||
beta_start,
|
||||
beta_end,
|
||||
beta_schedule,
|
||||
trained_betas,
|
||||
prediction_type,
|
||||
)
|
||||
|
||||
def compile(self):
|
||||
example_latent = model_input["euler"]["latent"]
|
||||
example_output = model_input["euler"]["output"]
|
||||
if args.precision == "fp16":
|
||||
example_latent = example_latent.half()
|
||||
example_output = example_output.half()
|
||||
example_sigma = model_input["euler"]["sigma"]
|
||||
example_dt = model_input["euler"]["dt"]
|
||||
|
||||
class ScalingModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, latent, sigma):
|
||||
return latent / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
class SchedulerStepModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma, latent, dt):
|
||||
pred_original_sample = latent - sigma * noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
return latent + derivative * dt
|
||||
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
if args.import_mlir:
|
||||
scaling_model = ScalingModel()
|
||||
self.scaling_model = compile_through_fx(
|
||||
scaling_model,
|
||||
(example_latent, example_sigma),
|
||||
model_name="euler_scale_model_input_" + args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
step_model = SchedulerStepModel()
|
||||
self.step_model = compile_through_fx(
|
||||
step_model,
|
||||
(example_output, example_sigma, example_latent, example_dt),
|
||||
model_name="euler_step_" + args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
else:
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_scale_model_input_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
|
||||
)
|
||||
|
||||
def scale_model_input(self, sample, timestep):
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
return self.scaling_model(
|
||||
"forward",
|
||||
(
|
||||
sample,
|
||||
sigma,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
def step(self, noise_pred, timestep, latent):
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
dt = self.sigmas[step_index + 1] - sigma
|
||||
return self.step_model(
|
||||
"forward",
|
||||
(
|
||||
noise_pred,
|
||||
sigma,
|
||||
latent,
|
||||
dt,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
@@ -1,105 +0,0 @@
|
||||
import os
|
||||
from shark.model_annotation import model_annotation, create_context
|
||||
from shark.iree_utils._common import run_cmd, iree_target_map
|
||||
from shark.shark_downloader import (
|
||||
download_model,
|
||||
download_public_file,
|
||||
WORKDIR,
|
||||
)
|
||||
from shark.parser import shark_args
|
||||
from stable_args import args
|
||||
from opt_params import get_params
|
||||
from utils import set_init_device_flags
|
||||
|
||||
|
||||
# Downloads the model (Unet or VAE fp16) from shark_tank
|
||||
set_init_device_flags()
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
bucket_key = f"{args.variant}/untuned"
|
||||
use_winograd = True
|
||||
if args.annotation_model == "unet":
|
||||
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
|
||||
elif args.annotation_model == "vae":
|
||||
is_base = "/base" if args.use_base_vae else ""
|
||||
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/untuned{is_base}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, args.annotation_model, "untuned", args.precision
|
||||
)
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=bucket,
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
# Downloads the tuned config files from shark_tank
|
||||
config_bucket = "gs://shark_tank/sd_tuned/configs/"
|
||||
if use_winograd:
|
||||
config_name = f"{args.annotation_model}_winograd.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
winograd_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
download_public_file(full_gs_url, winograd_config_dir, True)
|
||||
|
||||
if args.annotation_model == "unet":
|
||||
if args.variant in ["anythingv3", "analogdiffusion"]:
|
||||
args.max_length = 77
|
||||
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
|
||||
# Annotate the model with Winograd attribute on selected conv ops
|
||||
if use_winograd:
|
||||
with create_context() as ctx:
|
||||
winograd_model = model_annotation(
|
||||
ctx,
|
||||
input_contents=mlir_model,
|
||||
config_path=winograd_config_dir,
|
||||
search_op="conv",
|
||||
winograd=use_winograd,
|
||||
)
|
||||
with open(
|
||||
f"{args.annotation_output}/{model_name}_tuned_torch.mlir", "w"
|
||||
) as f:
|
||||
f.write(str(winograd_model))
|
||||
|
||||
# For Unet annotate the model with tuned lowering configs
|
||||
if args.annotation_model == "unet":
|
||||
if use_winograd:
|
||||
input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
|
||||
else:
|
||||
input_mlir = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
|
||||
dump_after = "iree-flow-pad-linalg-ops"
|
||||
|
||||
# Dump IR after padding/img2col/winograd passes
|
||||
run_cmd(
|
||||
f"iree-compile {input_mlir} "
|
||||
"--iree-input-type=tm_tensor "
|
||||
f"--iree-hal-target-backends={iree_target_map(args.device)} "
|
||||
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
|
||||
"--iree-stream-resource-index-bits=64 "
|
||||
"--iree-vm-target-index-bits=64 "
|
||||
"--iree-flow-enable-padding-linalg-ops "
|
||||
"--iree-flow-linalg-ops-padding-size=32 "
|
||||
"--iree-flow-enable-conv-img2col-transform "
|
||||
f"--mlir-print-ir-after={dump_after} "
|
||||
"--compile-to=flow "
|
||||
f"2>{args.annotation_output}/dump_after_winograd.mlir "
|
||||
)
|
||||
|
||||
# Annotate the model with lowering configs in the config file
|
||||
with create_context() as ctx:
|
||||
tuned_model = model_annotation(
|
||||
ctx,
|
||||
input_contents=f"{args.annotation_output}/dump_after_winograd.mlir",
|
||||
config_path=lowering_config_dir,
|
||||
search_op="all",
|
||||
)
|
||||
|
||||
# Remove the intermediate mlir and save the final annotated model
|
||||
os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
|
||||
output_path = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
with open(output_path, "w") as f:
|
||||
f.write(str(tuned_model))
|
||||
print(f"Saved the annotated mlir in {output_path}.")
|
||||
@@ -1,250 +0,0 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def path_expand(s):
|
||||
return Path(s).expanduser().resolve()
|
||||
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=["cyberpunk forest by Salvador Dali"],
|
||||
help="text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative-prompts",
|
||||
nargs="+",
|
||||
default=[""],
|
||||
help="text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="the no. of steps to do the sampling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed to use.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="the value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=64,
|
||||
help="max length of the tokenizer output, options are 64 and 77.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Model Config and Usage Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="vulkan", help="device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="v2_1base",
|
||||
help="Specify version of stable diffusion model",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp16", help="precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="saves the compiled flatbuffer to the local directory",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_base_vae",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Do conversion from the VAE output to pixel space on cpu.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--variant",
|
||||
default="stablediffusion",
|
||||
help="We now support multiple vairants of SD finetuned for different dataset. you can use the following anythingv3, ...", # TODO add more once supported
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="SharkEulerDiscrete",
|
||||
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree-vulkan-target-triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="4147483648",
|
||||
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for disabling vulkan validation layers when benchmarking",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Misc. Debug and Optimization flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--use_compiled_scheduler",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="use the default scheduler precompiled into the model if available",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--local_tank_cache",
|
||||
default="",
|
||||
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dump_isa",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks",
|
||||
default=None,
|
||||
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="temp_dispatch_benchmarks",
|
||||
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--enable_rgp",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for inserting debug frames between iterations for use with rgp.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hide_steps",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for hiding the details of iteration/sec for each step.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--warmup_count",
|
||||
type=int,
|
||||
default=0,
|
||||
help="flag setting warmup count for clip and vae [>= 0].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--clear_all",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Web UI flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--progress_bar",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for removing the pregress bar animation during image generation",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### SD model auto-annotation flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_output",
|
||||
type=path_expand,
|
||||
default="./",
|
||||
help="Directory to save the annotated mlir file",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_model",
|
||||
type=str,
|
||||
default="unet",
|
||||
help="Options are unet and vae.",
|
||||
)
|
||||
|
||||
args = p.parse_args()
|
||||
@@ -1,139 +0,0 @@
|
||||
# Stable Diffusion optimized for AMD RDNA2/RDNA3 GPUs
|
||||
|
||||
Before you start, please be aware that this is beta software that relies on a special AMD driver. Like all StableDiffusion GUIs published so far, you need some technical expertise to set it up. We apologize in advance if you bump into issues. If that happens, please don't hesitate to ask our Discord community for help! If you still can't get it to work, we're sorry, and please be assured that we (Nod and AMD) are working hard to improve the user experience in coming months.
|
||||
If it works well for you, please "star" the following GitHub projects... this is one of the best ways to help and spread the word!
|
||||
|
||||
* https://github.com/nod-ai/SHARK
|
||||
* https://github.com/iree-org/iree
|
||||
|
||||
## Install this specific AMD Drivers (AMD latest may not have all the fixes).
|
||||
|
||||
### AMD KB Drivers for RDNA2 and RDNA3:
|
||||
|
||||
*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*
|
||||
|
||||
First, download this special driver in a folder of your choice. We recommend you keep that driver around since you may need to re-install it later, if Windows Update decides to overwrite it:
|
||||
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree
|
||||
|
||||
KNOWN ISSUES with this special AMD driver:
|
||||
* `Windows Update` may (depending how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work, then a few days later, it slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix this problem, please check the installed driver's version, and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
|
||||
* Some people using this special driver experience mouse pointer accuracy issues, if you use a larger-than-default mouse pointer. The clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".
|
||||
|
||||
## Installation
|
||||
|
||||
Download the latest Windows SHARK SD binary [423 here](https://github.com/nod-ai/SHARK/releases/download/20230101.423/shark_sd_20230101_423.exe) in a folder of your choice. If you want nighly builds you can look for them in the github releases page. Please read carefully the following notes:
|
||||
|
||||
Notes:
|
||||
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR, that can get outdated if you run multiple EXE from the same folder. You can use `--clean_all` flag once to clean all the old files.
|
||||
* Your browser may warn you about downloading an .exe file
|
||||
* If you recently updated the driver or this binary (EXE file), we recommend you:
|
||||
* clear all the local artifacts with `--clean_all` OR
|
||||
* clear the Vulkan shader cache: For Windows users this can be done by clearing the contents of `C:\Users\%username%\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
|
||||
* clear the `huggingface` cache. In Windows, this is `C:\Users\%username%\.cache\huggingface`.
|
||||
|
||||
## Running
|
||||
|
||||
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE to start the web browser)
|
||||
* The first run may take about 10-15 minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
|
||||
* If successful, you will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
|
||||
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/?__theme=dark.
|
||||
|
||||
## Stopping
|
||||
|
||||
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment. The application should stop.
|
||||
* Please make sure to do the above step before you attempt to update the EXE to a new version.
|
||||
|
||||
# Results
|
||||
|
||||
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
|
||||
|
||||
|
||||
Here are some samples generated:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
|
||||
<details>
|
||||
<summary>Advanced Installation </summary>
|
||||
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
|
||||
|
||||
* Install Git for Windows from [here](https://git-scm.com/download/win)
|
||||
|
||||
#### Allow the install script to run in Powershell
|
||||
```powershell
|
||||
set-executionpolicy remotesigned
|
||||
```
|
||||
|
||||
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
|
||||
```powershell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
cd SHARK
|
||||
./setup_venv.ps1 #You can re-run this script to get the latest version
|
||||
```
|
||||
|
||||
### Linux
|
||||
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
cd SHARK
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
```
|
||||
|
||||
### Run Stable Diffusion on your device - WebUI
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\Users\nod\SHARK> cd web
|
||||
(shark.venv) PS C:\Users\nod\SHARK\web> python index.py
|
||||
```
|
||||
#### Linux Users
|
||||
```shell
|
||||
(shark.venv) > cd web
|
||||
(shark.venv) > python index.py
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - Commandline
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> python .\shark\examples\shark_inference\stable_diffusion\main.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
#### Linux
|
||||
```shell
|
||||
python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
```
|
||||
|
||||
The output on a 6900XT would like:
|
||||
|
||||
```shell
|
||||
44it [00:08, 5.14it/s]i = 44 t = 120 (191ms)
|
||||
45it [00:08, 5.15it/s]i = 45 t = 100 (191ms)
|
||||
46it [00:08, 5.16it/s]i = 46 t = 80 (191ms)
|
||||
47it [00:09, 5.16it/s]i = 47 t = 60 (193ms)
|
||||
48it [00:09, 5.15it/s]i = 48 t = 40 (195ms)
|
||||
49it [00:09, 5.12it/s]i = 49 t = 20 (196ms)
|
||||
50it [00:09, 5.14it/s]
|
||||
Average step time: 192.8154182434082ms/it
|
||||
Total image generation runtime (s): 10.390909433364868
|
||||
(shark.venv) PS C:\g\shark>
|
||||
```
|
||||
|
||||
|
||||
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
|
||||
</details>
|
||||
<details>
|
||||
<summary>Discord link</summary>
|
||||
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
|
||||
</details>
|
||||
@@ -1,231 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from stable_args import args
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
else:
|
||||
if args.save_vmfb:
|
||||
print("Saving to {}".format(vmfb_path))
|
||||
else:
|
||||
print(
|
||||
"No vmfb found. Compiling and saving to {}".format(
|
||||
vmfb_path
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.parser import shark_args
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=tank_url,
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into a shark_module.
|
||||
def compile_through_fx(model, inputs, model_name, extra_args=[]):
|
||||
|
||||
mlir_module, func_name = import_with_fx(model, inputs)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
|
||||
]
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
"""
|
||||
Inputs: driver_name
|
||||
Returns a list of all the available devices for a given driver sorted by
|
||||
the iree path names of the device as in --list_devices option in iree.
|
||||
"""
|
||||
from iree.runtime import get_driver
|
||||
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return device_list_src
|
||||
|
||||
|
||||
def get_device_mapping(driver, key_combination=3):
|
||||
"""This method ensures consistent device ordering when choosing
|
||||
specific devices for execution
|
||||
Args:
|
||||
driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Returns:
|
||||
dict: map to possible device names user can input mapped to desired combination of name/path.
|
||||
"""
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
driver = iree_device_map(driver)
|
||||
device_list = get_all_devices(driver)
|
||||
device_map = dict()
|
||||
|
||||
def get_output_value(dev_dict):
|
||||
if key_combination == 1:
|
||||
return f"{driver}://{dev_dict['path']}"
|
||||
if key_combination == 2:
|
||||
return dev_dict["name"]
|
||||
if key_combination == 3:
|
||||
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
|
||||
|
||||
# mapping driver name to default device (driver://0)
|
||||
device_map[f"{driver}"] = get_output_value(device_list[0])
|
||||
for i, device in enumerate(device_list):
|
||||
# mapping with index
|
||||
device_map[f"{driver}://{i}"] = get_output_value(device)
|
||||
# mapping with full path
|
||||
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
|
||||
return device_map
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user selected execution device
|
||||
Args:
|
||||
device (str): user
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Raises:
|
||||
ValueError:
|
||||
Returns:
|
||||
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
|
||||
"""
|
||||
driver = device.split("://")[0]
|
||||
device_map = get_device_mapping(driver, key_combination)
|
||||
try:
|
||||
device_mapping = device_map[device]
|
||||
except KeyError:
|
||||
raise ValueError(f"Device '{device}' is not a valid device.")
|
||||
return device_mapping
|
||||
|
||||
|
||||
def set_init_device_flags():
|
||||
if "vulkan" in args.device:
|
||||
# set runtime flags for vulkan.
|
||||
set_iree_runtime_flags()
|
||||
|
||||
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
|
||||
device_name, args.device = map_device_to_name_path(args.device)
|
||||
if not args.iree_vulkan_target_triple:
|
||||
triple = get_vulkan_target_triple(device_name)
|
||||
if triple is not None:
|
||||
args.iree_vulkan_target_triple = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
|
||||
)
|
||||
elif "cuda" in args.device:
|
||||
args.device = "cuda"
|
||||
elif "cpu" in args.device:
|
||||
args.device = "cpu"
|
||||
|
||||
# set max_length based on availability.
|
||||
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
args.max_length = 77
|
||||
elif args.variant == "openjourney":
|
||||
args.max_length = 64
|
||||
|
||||
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
|
||||
if (
|
||||
args.variant in ["openjourney", "dreamlike"]
|
||||
or args.precision != "fp16"
|
||||
or "vulkan" not in args.device
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
):
|
||||
args.use_tuned = False
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
elif args.use_base_vae and args.variant != "stablediffusion":
|
||||
args.use_tuned = False
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
if args.use_tuned:
|
||||
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
|
||||
|
||||
|
||||
# Utility to get list of devices available.
|
||||
def get_available_devices():
|
||||
def get_devices_by_name(driver_name):
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
driver_name = iree_device_map(driver_name)
|
||||
device_list_dict = get_all_devices(driver_name)
|
||||
print(f"{driver_name} devices are available.")
|
||||
except:
|
||||
print(f"{driver_name} devices are not available.")
|
||||
else:
|
||||
for i, device in enumerate(device_list_dict):
|
||||
device_list.append(f"{driver_name}://{i} => {device['name']}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
vulkan_devices = get_devices_by_name("vulkan")
|
||||
available_devices.extend(vulkan_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
available_devices.append("cpu")
|
||||
return available_devices
|
||||
@@ -1,35 +0,0 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import T5Tokenizer, TFT5Model
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
# Create a set of inputs
|
||||
t5_inputs = [
|
||||
tf.TensorSpec(shape=[1, 10], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[1, 10], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class T5Module(tf.Module):
|
||||
def __init__(self):
|
||||
super(T5Module, self).__init__()
|
||||
self.m = TFT5Model.from_pretrained("t5-small")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, decoder_input_ids=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, decoder_input_ids):
|
||||
return self.m.predict(input_ids, decoder_input_ids)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
text = "I love the distilled version of models."
|
||||
inputs = tokenizer(text, return_tensors="tf").input_ids
|
||||
|
||||
shark_module = SharkInference(T5Module(), (inputs, inputs))
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward((inputs, inputs)))
|
||||
@@ -1,43 +0,0 @@
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.train(False)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
## The vision models present here: https://pytorch.org/vision/stable/models.html
|
||||
vision_models_list = [
|
||||
models.resnet18(pretrained=True),
|
||||
models.alexnet(pretrained=True),
|
||||
models.vgg16(pretrained=True),
|
||||
models.squeezenet1_0(pretrained=True),
|
||||
models.densenet161(pretrained=True),
|
||||
models.inception_v3(pretrained=True),
|
||||
models.shufflenet_v2_x1_0(pretrained=True),
|
||||
models.mobilenet_v2(pretrained=True),
|
||||
models.mobilenet_v3_small(pretrained=True),
|
||||
models.resnext50_32x4d(pretrained=True),
|
||||
models.wide_resnet50_2(pretrained=True),
|
||||
models.mnasnet1_0(pretrained=True),
|
||||
models.efficientnet_b0(pretrained=True),
|
||||
models.regnet_y_400mf(pretrained=True),
|
||||
models.regnet_x_400mf(pretrained=True),
|
||||
]
|
||||
|
||||
for i, vision_model in enumerate(vision_models_list):
|
||||
shark_module = SharkInference(
|
||||
VisionModule(vision_model),
|
||||
(input,),
|
||||
)
|
||||
shark_module.compile()
|
||||
shark_module.forward((input,))
|
||||
@@ -1,39 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
|
||||
|
||||
class UnetModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = torch.hub.load(
|
||||
"mateuszbuda/brain-segmentation-pytorch",
|
||||
"unet",
|
||||
in_channels=3,
|
||||
out_channels=1,
|
||||
init_features=32,
|
||||
pretrained=True,
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, input):
|
||||
return self.model(input)
|
||||
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
UnetModule(),
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=False
|
||||
)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
@@ -1,21 +0,0 @@
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from pipeline_shark_stable_diffusion_upscale import (
|
||||
SharkStableDiffusionUpscalePipeline,
|
||||
)
|
||||
import torch
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipeline = SharkStableDiffusionUpscalePipeline(model_id)
|
||||
|
||||
# let's download an image
|
||||
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
|
||||
response = requests.get(url)
|
||||
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
low_res_img = low_res_img.resize((128, 128))
|
||||
|
||||
prompt = "a white cat"
|
||||
|
||||
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
|
||||
upscaled_image.save("upsampled_cat.png")
|
||||
@@ -1,99 +0,0 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from utils import compile_through_fx
|
||||
import torch
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
|
||||
model_input = {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 128, 128),),
|
||||
"unet": (
|
||||
torch.randn(2, 7, 128, 128).half(), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, 77, 1024).half(), # embedding
|
||||
torch.randn(2).to(torch.int64), # noise_level
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
)
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
model_input["clip"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
|
||||
def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
return x
|
||||
|
||||
vae = VaeModel()
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
model_input["vae"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
revision="fp16",
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, latent, timestep, text_embedding, noise_level):
|
||||
unet_out = self.unet.forward(
|
||||
latent,
|
||||
timestep,
|
||||
text_embedding,
|
||||
noise_level,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
return unet_out
|
||||
|
||||
unet = UnetModel()
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple([inputs.cuda() for inputs in model_input["unet"]])
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_unet
|
||||
@@ -1,53 +0,0 @@
|
||||
import sys
|
||||
from model_wrappers import (
|
||||
get_vae_mlir,
|
||||
get_unet_mlir,
|
||||
get_clip_mlir,
|
||||
)
|
||||
from upscaler_args import args
|
||||
from utils import get_shark_model
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
if BATCH_SIZE != 1:
|
||||
sys.exit("Only batch size 1 is supported.")
|
||||
|
||||
|
||||
unet_flag = [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform",
|
||||
]
|
||||
|
||||
vae_flag = [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
]
|
||||
|
||||
clip_flag = [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
]
|
||||
|
||||
bucket = "gs://shark_tank/stable_diffusion/"
|
||||
|
||||
|
||||
def get_unet():
|
||||
model_name = "upscaler_unet"
|
||||
if args.import_mlir:
|
||||
return get_unet_mlir(model_name, unet_flag)
|
||||
return get_shark_model(bucket, model_name, unet_flag)
|
||||
|
||||
|
||||
def get_vae():
|
||||
model_name = "upscaler_vae"
|
||||
if args.import_mlir:
|
||||
return get_vae_mlir(model_name, vae_flag)
|
||||
return get_shark_model(bucket, model_name, vae_flag)
|
||||
|
||||
|
||||
def get_clip():
|
||||
model_name = "upscaler_clip"
|
||||
if args.import_mlir:
|
||||
return get_clip_mlir(model_name, clip_flag)
|
||||
return get_shark_model(bucket, model_name, clip_flag)
|
||||
@@ -1,490 +0,0 @@
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers import logging
|
||||
from diffusers.pipeline_utils import ImagePipelineOutput
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
if isinstance(image, torch.Tensor):
|
||||
return image
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
w, h = image[0].size
|
||||
w, h = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
|
||||
image = [np.array(i.resize((w, h)))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = 2.0 * image - 1.0
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
return image
|
||||
|
||||
|
||||
def shark_run_wrapper(model, *args):
|
||||
np_inputs = tuple([x.detach().numpy() for x in args])
|
||||
outputs = model("forward", np_inputs)
|
||||
return torch.from_numpy(outputs)
|
||||
|
||||
|
||||
class SharkStableDiffusionUpscalePipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_id,
|
||||
):
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
||||
model_id, subfolder="tokenizer"
|
||||
)
|
||||
self.low_res_scheduler = DDPMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
self.scheduler = DDIMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
self.vae = get_vae()
|
||||
self.unet = get_unet()
|
||||
self.text_encoder = get_clip()
|
||||
self.max_noise_level = (350,)
|
||||
self._execution_device = "cpu"
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
Args:
|
||||
prompt (`str` or `list(int)`):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
"""
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(
|
||||
prompt, padding="longest", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
||||
-1
|
||||
] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
# if (
|
||||
# hasattr(self.text_encoder.config, "use_attention_mask")
|
||||
# and self.text_encoder.config.use_attention_mask
|
||||
# ):
|
||||
# attention_mask = text_inputs.attention_mask.to(device)
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
text_embeddings = shark_run_wrapper(
|
||||
self.text_encoder, text_input_ids.to(device)
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(
|
||||
bs_embed * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# if (
|
||||
# hasattr(self.text_encoder.config, "use_attention_mask")
|
||||
# and self.text_encoder.config.use_attention_mask
|
||||
# ):
|
||||
# attention_mask = uncond_input.attention_mask.to(device)
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
uncond_embeddings = shark_run_wrapper(
|
||||
self.text_encoder,
|
||||
uncond_input.input_ids.to(device),
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(
|
||||
1, num_images_per_prompt, 1
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
return text_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.08333 * latents
|
||||
image = shark_run_wrapper(self.vae, latents)
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def check_inputs(self, prompt, image, noise_level, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(
|
||||
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
||||
)
|
||||
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor
|
||||
if isinstance(image, list) or isinstance(image, torch.Tensor):
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = len(prompt)
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
image_batch_size = image.shape[0]
|
||||
if batch_size != image_batch_size:
|
||||
raise ValueError(
|
||||
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
|
||||
" Please make sure that passed `prompt` matches the batch size of `image`."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images):
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
if images.shape[-1] == 1:
|
||||
# special case for grayscale (single channel) images
|
||||
pil_images = [
|
||||
Image.fromarray(image.squeeze(), mode="L") for image in images
|
||||
]
|
||||
else:
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
if latents is None:
|
||||
if device == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(
|
||||
shape, generator=generator, device="cpu", dtype=dtype
|
||||
).to(device)
|
||||
else:
|
||||
latents = torch.randn(
|
||||
shape, generator=generator, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(
|
||||
f"Unexpected latents shape, got {latents.shape}, expected {shape}"
|
||||
)
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[
|
||||
torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]
|
||||
],
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
noise_level: int = 20,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[
|
||||
Union[torch.Generator, List[torch.Generator]]
|
||||
] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[
|
||||
Callable[[int, int, torch.FloatTensor], None]
|
||||
] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
):
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, image, noise_level, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
)
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image)
|
||||
image = image.to(dtype=text_embeddings.dtype, device=device)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor(
|
||||
[noise_level], dtype=torch.long, device=device
|
||||
)
|
||||
if device == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(
|
||||
image.shape,
|
||||
generator=generator,
|
||||
device="cpu",
|
||||
dtype=text_embeddings.dtype,
|
||||
).to(device)
|
||||
else:
|
||||
noise = torch.randn(
|
||||
image.shape,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=text_embeddings.dtype,
|
||||
)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
|
||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
|
||||
noise_level = torch.cat([noise_level] * image.shape[0])
|
||||
|
||||
# 6. Prepare latent variables
|
||||
height, width = image.shape[2:]
|
||||
# num_channels_latents = self.vae.config.latent_channels
|
||||
num_channels_latents = 4
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Check that sizes of image and latents match
|
||||
num_channels_image = image.shape[1]
|
||||
# if (
|
||||
# num_channels_latents + num_channels_image
|
||||
# != self.unet.config.in_channels
|
||||
# ):
|
||||
# raise ValueError(
|
||||
# f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
# f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
# f" `num_channels_image`: {num_channels_image} "
|
||||
# f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
||||
# " `pipeline.unet` or your `image` input."
|
||||
# )
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = (
|
||||
len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
)
|
||||
for i, t in tqdm(enumerate(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * 2)
|
||||
if do_classifier_free_guidance
|
||||
else latents
|
||||
)
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
latent_model_input = torch.cat([latent_model_input, image], dim=1)
|
||||
|
||||
timestep = torch.tensor([t]).to(torch.float32)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = shark_run_wrapper(
|
||||
self.unet,
|
||||
latent_model_input.half(),
|
||||
timestep,
|
||||
text_embeddings.half(),
|
||||
noise_level,
|
||||
)
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# # call the callback, if provided
|
||||
# if i == len(timesteps) - 1 or (
|
||||
# (i + 1) > num_warmup_steps
|
||||
# and (i + 1) % self.scheduler.order == 0
|
||||
# ):
|
||||
# progress_bar.update()
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
# 10. Post-processing
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
# self.vae.to(dtype=torch.float32)
|
||||
image = self.decode_latents(latents.float())
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user