Mirror of https://github.com/nod-ai/SHARK-Studio.git, synced 2026-01-12 07:18:27 -05:00

Compare commits: rdna...minilmLoad (1 commit)

| Author | SHA1 | Date |
|---|---|---|
| | 14a56ca9b0 | |
37  .github/workflows/gh-pages-releases.yml (vendored)

@@ -1,37 +0,0 @@
# See: https://github.com/llvm/torch-mlir/issues/1374
name: Publish releases page

on:
  workflow_dispatch:

jobs:
  scrape_and_publish_releases:
    name: "Scrape and publish releases"
    runs-on: ubuntu-latest

    # Don't run this in everyone's forks.
    if: github.repository == 'nod-ai/SHARK'

    steps:
      - name: Checking out repository
        uses: actions/checkout@v2
        with:
          token: ${{ secrets.NODAI_INVOCATION_TOKEN }}
      - name: Run scrape releases script
        run: python ./build_tools/scrape_releases.py nod-ai SHARK > /tmp/index.html
        shell: bash
      - run: git fetch --all
      - run: git switch github-pages
      - run: git config --global user.email "none@none.com"
      - run: git config --global user.name "nod-ai"
      - run: mv /tmp/index.html package-index/index.html
      - run: git add package-index/index.html

      # Only try to make a commit if the file has changed.
      - run: git diff --cached --exit-code || git commit -m "Update releases."

      - name: GitHub Push
        uses: ad-m/github-push-action@v0.6.0
        with:
          github_token: ${{ secrets.NODAI_INVOCATION_TOKEN }}
          branch: github-pages
50  .github/workflows/nightly.yml (vendored)

@@ -11,12 +11,11 @@ on:
jobs:
  build:

    runs-on: a100
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10"]
        backend: [IREE, SHARK]

    steps:
    - uses: actions/checkout@v3

@@ -39,10 +38,6 @@ jobs:
        tag_name="${package_version}"
        echo "package_version=${package_version}" >> $GITHUB_ENV
        echo "tag_name=${tag_name}" >> $GITHUB_ENV
    - name: Set Environment Variables
      run: |
        echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
        echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
    - name: Create Release
      id: create_release
      uses: actions/create-release@v1

@@ -54,60 +49,34 @@ jobs:
        body: |
          Automatic snapshot release of nod.ai SHARK.
        draft: true
        prerelease: false
        prerelease: false
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install flake8 pytest toml
        if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
        python -m pip install flake8 pytest yapf toml
        if [ -f requirements.txt ]; then pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/SHARK-Runtime/releases; fi
    - name: Lint with flake8
      run: |
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
    - name: Build and validate the IREE package
      if: ${{ matrix.backend == 'IREE' }}
      run: |
        cd $GITHUB_WORKSPACE
        USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
        source iree.venv/bin/activate
        package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
        SHARK_PACKAGE_VERSION=${package_version} \
        pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://iree-org.github.io/iree/pip-release-links.html
        # Install the built wheel
        pip install ./wheelhouse/nodai*
        # Validate the Models
        /bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
        pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" tank/test_models.py |
        tail -n 1 |
        tee -a pytest_results.txt
        if !(grep -Fxq " failed" pytest_results.txt)
        then
        export SHA=$(git log -1 --format='%h')
        gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
        gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/latest/
        fi
        rm -rf ./wheelhouse/nodai*
        yapf -i --style .style.yapf shark/*.py

    - name: Build and validate the SHARK Runtime package
      if: ${{ matrix.backend == 'SHARK' }}
    - name: Build and validate the package
      run: |
        cd $GITHUB_WORKSPACE
        ./setup_venv.sh
        IMPORTER=1 ./setup_venv.sh
        source shark.venv/bin/activate
        package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
        SHARK_PACKAGE_VERSION=${package_version} \
        pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
        pip wheel -v -w wheelhouse . --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/SHARK-Runtime/releases
        # Install the built wheel
        pip install ./wheelhouse/nodai*
        # Validate the Models
        pytest --ci --ci_sha=${SHORT_SHA} tank/test_models.py |
        tail -n 1 |
        tee -a pytest_results.txt
        pytest -k 'not benchmark' --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py --ignore=shark/tests/test_shark_importer.py --ignore=tank/tf/

    - name: Upload Release Assets
      if: ${{ matrix.backend == 'SHARK' }}
      id: upload-release-assets
      uses: dwenegar/upload-release-assets@v1
      env:

@@ -117,7 +86,6 @@ jobs:
        assets_path: ./wheelhouse/nodai_*.whl

    - name: Publish Release
      if: ${{ matrix.backend == 'SHARK' }}
      id: publish_release
      uses: eregon/publish-release@v1
      env:
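The nightly wheel gets its version from the build date plus the workflow run number (`SHARK_PACKAGE_VERSION` above). A rough local equivalent, assuming bash 4.2+ for the `printf '%(...)T'` date format and substituting an arbitrary `0` for `${{ github.run_number }}`, would be:

```shell
# Build a local SHARK wheel with a CI-style date-based version string.
package_version="$(printf '%(%Y%m%d)T').0"
SHARK_PACKAGE_VERSION=${package_version} \
  pip wheel -v -w wheelhouse . --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
  -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/SHARK-Runtime/releases
```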
160  .github/workflows/test-models.yml (vendored)

@@ -1,7 +1,7 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Validate Models on Shark Runtime
name: Validate torch-models on Shark Runtime

on:
  push:

@@ -10,127 +10,93 @@ on:
    branches: [ main ]
  workflow_dispatch:

# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build-validate:
    strategy:
      fail-fast: true
      matrix:
        os: [icelake, a100, MacStudio, ubuntu-latest]
        suite: [cpu,cuda,vulkan]
        python-version: ["3.10"]
        include:
          - os: ubuntu-latest
            suite: lint
        exclude:
          - os: ubuntu-latest
            suite: vulkan
          - os: ubuntu-latest
            suite: cuda
          - os: ubuntu-latest
            suite: cpu
          - os: MacStudio
            suite: vulkan
          - os: MacStudio
            suite: cuda
          - os: MacStudio
            suite: cpu
          - os: icelake
            suite: vulkan
          - os: icelake
            suite: cuda
          - os: a100
            suite: cpu
  build-linux:

    runs-on: ${{ matrix.os }}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10"]

    steps:
    - uses: actions/checkout@v3

    - name: Set Environment Variables
      run: |
        echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
        echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV

    - name: Set up Python Version File ${{ matrix.python-version }}
      if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
      run: |
        # See https://github.com/actions/setup-python/issues/433
        echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version

    - name: Set up Python ${{ matrix.python-version }}
      if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
      uses: actions/setup-python@v4
      uses: actions/setup-python@v3
      with:
        python-version: '${{ matrix.python-version }}'
        #cache: 'pip'
        #cache-dependency-path: |
        #  **/requirements-importer.txt
        #  **/requirements.txt

        python-version: ${{ matrix.python-version }}

    - name: Setup pip cache
      uses: actions/cache@v3
      with:
        path: ~/.cache/pip
        key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
        restore-keys: |
          ${{ runner.os }}-pip-

    - name: Install dependencies
      if: matrix.suite == 'lint'
      run: |
        python -m pip install --upgrade pip
        python -m pip install flake8 pytest toml black

        python -m pip install flake8 pytest yapf toml

    - name: Lint with flake8
      if: matrix.suite == 'lint'
      run: |
        # black format check
        black --version
        black --line-length 79 --check .
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude lit.cfg.py
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude lit.cfg.py
        yapf -i --style .style.yapf shark/*.py

    - name: Validate Models on CPU
      if: matrix.suite == 'cpu'
    - name: Validate Models
      run: |
        cd $GITHUB_WORKSPACE
        PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
        IMPORTER=1 ./setup_venv.sh
        source shark.venv/bin/activate
        pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cpu --update_tank
        gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
        gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
        pytest -k 'not benchmark' --ignore=tank/tf/ --ignore=shark/tests/test_shark_importer.py

  perf-macOS:
    runs-on: MacStudio
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10"]

    - name: Validate Models on NVIDIA GPU
      if: matrix.suite == 'cuda'
    steps:
    - uses: actions/checkout@v3
    - name: Validate Models dependencies
      run: |
        cd $GITHUB_WORKSPACE
        PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
        PYTHON=python3.10 IMPORTER=1 ./setup_venv.sh
        source shark.venv/bin/activate
        pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cuda --update_tank
        gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
        gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
        pytest -k 'not benchmark' --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py --ignore=tank/tf/ --ignore=shark/tests/test_shark_importer.py

  perf-linux:
    runs-on: a100
    timeout-minutes: 45
    continue-on-error: true
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10"]

    - name: Validate Vulkan Models (MacOS)
      if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
    steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
      with:
        python-version: ${{ matrix.python-version }}

    - name: Setup pip cache
      uses: actions/cache@v3
      with:
        path: ~/.cache/pip
        key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
        restore-keys: |
          ${{ runner.os }}-pip-

    - name: Validate Models
      run: |
        cd $GITHUB_WORKSPACE
        PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
        IMPORTER=1 ./setup_venv.sh
        source shark.venv/bin/activate
        echo "VULKAN SDK PATH wo setup: $VULKAN_SDK"
        cd /Users/anush/VulkanSDK/1.3.224.1/
        source setup-env.sh
        cd $GITHUB_WORKSPACE
        echo "VULKAN SDK PATH with setup: $VULKAN_SDK"
        echo $PATH
        pip list | grep -E "torch|iree"
        pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" tank/test_models.py -k vulkan --update_tank

    - name: Validate Vulkan Models (a100)
      if: matrix.suite == 'vulkan' && matrix.os != 'MacStudio'
      run: |
        cd $GITHUB_WORKSPACE
        PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
        source shark.venv/bin/activate
        pytest --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
        pytest --ignore=shark/tests/test_shark_importer.py --ignore=tank/tf/
5  .gitignore (vendored)

@@ -162,12 +162,7 @@ cython_debug/

# Shark related artefacts
*venv/
shark_tmp/

# ORT related artefacts
cache_models/
onnx_models/

#web logging
web/logs/
web/stored_results/stable_diffusion/
218  LICENSE

@@ -1,218 +0,0 @@
(Removed: the full text of the Apache License 2.0, January 2004, together with the LLVM Exceptions to the Apache 2.0 License.)
247  README.md

@@ -14,16 +14,16 @@ High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerator
## Installation

<details>
<summary>Installation (Linux, macOS and Windows)</summary>

<summary>Installation (Linux and macOS)</summary>

### Setup a new pip Virtual Environment

This step sets up a new VirtualEnv for Python

```shell
python --version #Check you have 3.10 on Linux, macOS or Windows Powershell
python --version #Check you have 3.7->3.10 on Linux or 3.10 on macOS
python -m venv shark_venv
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
source shark_venv/bin/activate

# If you are using conda create and activate a new conda env

@@ -31,37 +31,32 @@ source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
python -m pip install --upgrade pip
```

*macOS Metal* users please install https://sdk.lunarg.com/sdk/download/latest/mac/vulkan-sdk.dmg and enable "System wide install"
*macOS Metal* users please install https://sdk.lunarg.com/sdk/download/latest/mac/vulkan-sdk.dmg

### Install SHARK

This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10

```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://github.com/nod-ai/SHARK/releases -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/shark-runtime/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```

### Run shark tank model tests.
```shell
pytest tank/test_models.py
```
See tank/README.md for a more detailed walkthrough of our pytest suite and CLI.
If you are on an Intel macOS machine you need this [workaround](https://github.com/nod-ai/SHARK/issues/102) for an upstream issue.

### Download and run Resnet50 sample

```shell
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/resnet50_script.py
#Install deps for test script
pip install --pre torch torchvision torchaudio tqdm pillow gsutil --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python ./resnet50_script.py --device="cpu" #use cuda or vulkan or metal
pip install --pre torch torchvision torchaudio tqdm pillow --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python ./resnet50_script.py --device="cpu" #use cuda or vulkan or metal
```

### Download and run BERT (MiniLM) sample
```shell
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/minilm_jit.py
#Install deps for test script
pip install transformers torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
```
</details>
@@ -72,125 +67,55 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
## Check out the code

```shell
git clone https://github.com/nod-ai/SHARK.git
git clone https://github.com/nod-ai/SHARK.git
```

## Setup your Python VirtualEnvironment and Dependencies

### Windows Users

```shell
# Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
# Requires Python 3.10 and Powershell
./setup_venv.ps1
shark.venv/Scripts/activate
```

### Linux / macOS Users

```shell
# Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
source shark.venv/bin/activate
# Please activate the venv after installation.
```

### Run a demo script
```shell
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
# Or a pytest
pytest tank/test_models.py -k "MiniLM"
```

</details>

<details>
<summary>Development, Testing and Benchmarks</summary>

If you want to use Python 3.10 and the TF import tools, you can set environment variables like:
Set `USE_IREE=1` to use upstream IREE
```
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
```
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
with Python bindings and set your PYTHONPATH as mentioned [here](https://google.github.io/iree/bindings/python/)
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
for Torch-MLIR.

### How to use your locally built Torch-MLIR with SHARK
### Run all model tests on CPU/GPU/VULKAN/Metal
```shell
1.) Run `./setup_venv.sh` in SHARK and activate the `shark.venv` virtual env.
2.) Run `pip uninstall torch-mlir`.
3.) Go to your local Torch-MLIR directory.
4.) Activate the mlir_venv virtual environment.
5.) Run `pip uninstall -r requirements.txt`.
6.) Run `pip install -r requirements.txt`.
7.) Build Torch-MLIR.
8.) Activate the shark.venv virtual environment from the Torch-MLIR directory.
9.) Run `export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples` in the Torch-MLIR directory.
10.) Go to the SHARK directory.
```
Now SHARK will use your locally built Torch-MLIR repo; a condensed command sketch of these steps follows below.
pytest shark/tests/models
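Condensed into plain commands, the steps above look roughly like this (a sketch only: it assumes your Torch-MLIR checkout lives at `~/torch-mlir`, that it is already built, and that SHARK's `shark.venv` was created by `./setup_venv.sh`; the `PYTHONPATH` value is the one quoted in step 9):

```shell
# In the SHARK checkout: drop the packaged torch-mlir wheel from shark.venv.
source shark.venv/bin/activate
pip uninstall torch-mlir

# In the local Torch-MLIR checkout (~/torch-mlir is an assumed path), point the
# active venv at the locally built Python packages.
cd ~/torch-mlir
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples

# Back in the SHARK directory, tests now pick up the locally built Torch-MLIR.
cd -
pytest tank/test_models.py -k "MiniLM"
```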
## Benchmarking Dispatches

To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your command line arguments.
If you only want to compile specific dispatches, you can specify them with a space-separated string instead of `"All"`, e.g. `--dispatch_benchmarks="0 1 2 10"` (see the invocation sketch after the output list below).

If you want to incorporate this into a Python script instead, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` arguments when initializing `SharkInference`, and the benchmarks will be generated when the module is compiled, e.g.:

```
shark_module = SharkInference(
    mlir_model,
    func_name,
    device=args.device,
    mlir_dialect="tm_tensor",
    dispatch_benchmarks="all",
    dispatch_benchmarks_dir="results"
)
# If on Linux for quicker results:
pytest shark/tests/models -n auto
```

Output will include:
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
- An .mlir file containing the dispatch benchmark
- A compiled .vmfb file containing the dispatch benchmark
- An .mlir file containing just the hal executable
- A compiled .vmfb file of the hal executable
- A .txt file containing benchmark output
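For instance, a command-line invocation combining both flags could look like the following (a sketch only: `resnet50_script.py` is reused from the sample above and assumed to forward SHARK's standard command-line flags; the output directory name is arbitrary):

```shell
# Benchmark only dispatches 0, 1, 2 and 10 and write the artifacts to ./dispatch_results
python ./resnet50_script.py --device="cpu" \
  --dispatch_benchmarks="0 1 2 10" \
  --dispatch_benchmarks_dir=dispatch_results
```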
See tank/README.md for instructions on how to run model tests and benchmarks from the SHARK tank.

### Run all model benchmark tests on CPU/GPU/VULKAN/Metal
```shell
pytest shark/tests/benchmarks
```
</details>


<details>
<summary>API Reference</summary>

### Shark Inference API

```
from shark_runner import SharkInference

from shark.shark_importer import SharkImporter

# SharkImporter imports mlir file from the torch, tensorflow or tf-lite module.

mlir_importer = SharkImporter(
    torch_module,
    (input),
    frontend="torch", #tf, #tf-lite
)
torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)

# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.

from shark.shark_inference import SharkInference
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
shark_module = SharkInference(
    module = model class.
    (input,) = inputs to model (must be a torch-tensor)
    dynamic (boolean) = Pass the input shapes as static or dynamic.
    device = `cpu`, `gpu` or `vulkan` is supported.
    tracing_required = (boolean) = Jit trace the module with the given input, useful in the case where jit.script doesn't work. )
shark_module.set_frontend("pytorch") # Use tensorflow, mhlo, linalg, tosa
shark_module.compile()
result = shark_module.forward((input))

result = shark_module.forward(inputs)
```
@@ -210,30 +135,104 @@ mhlo_ir = r"""builtin.module {

arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")

shark_module = SharkInference(mhlo_ir, (arg0, arg1))
shark_module.set_frontend("mhlo")
shark_module.compile()
result = shark_module.forward((arg0, arg1))
print(shark_module.forward((arg0, arg1)))
```
</details>


## Supported and Validated Models

SHARK is maintained to support the latest innovations in ML Models:
<details>
<summary>PyTorch Models</summary>

| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------|----------|-------------|
| BERT | :green_heart: | :green_heart: | :green_heart: |
| DistilBERT | :green_heart: | :green_heart: | :green_heart: |
| GPT2 | :green_heart: | :green_heart: | :green_heart: |
| BLOOM | :green_heart: | :green_heart: | :green_heart: |
| Stable Diffusion | :green_heart: | :green_heart: | :green_heart: |
| Vision Transformer | :green_heart: | :green_heart: | :green_heart: |
| ResNet50 | :green_heart: | :green_heart: | :green_heart: |
### Huggingface PyTorch Models

For a complete list of the models supported in SHARK, please refer to [tank/README.md](https://github.com/nod-ai/SHARK/blob/main/tank/README.md).
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :heavy_check_mark: (JIT) | :heavy_check_mark: | | |
| Albert | :heavy_check_mark: (JIT) | :heavy_check_mark: | | |
| BigBird | :heavy_check_mark: (AOT) | | | |
| DistilBERT | :heavy_check_mark: (JIT) | :heavy_check_mark: | | |
| GPT2 | :x: (AOT) | | | |

### Torchvision Models

| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|--------------------|----------------------|----------|----------|-------------|
| AlexNet | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
| DenseNet121 | :heavy_check_mark: (Script) | | | |
| MNasNet1_0 | :heavy_check_mark: (Script) | | | |
| MobileNetV2 | :heavy_check_mark: (Script) | | | |
| MobileNetV3 | :heavy_check_mark: (Script) | | | |
| Unet | :x: (Script) | | | |
| Resnet18 | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
| Resnet50 | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
| Resnet101 | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
| Resnext50_32x4d | :heavy_check_mark: (Script) | | | |
| ShuffleNet_v2 | :x: (Script) | | | |
| SqueezeNet | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
| EfficientNet | :heavy_check_mark: (Script) | | | |
| Regnet | :heavy_check_mark: (Script) | | | |
| Resnest | :x: (Script) | | | |
| Vision Transformer | :heavy_check_mark: (Script) | | | |
| VGG 16 | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
| Wide Resnet | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
| RAFT | :x: (JIT) | | | |

For more information refer to [MODEL TRACKING SHEET](https://docs.google.com/spreadsheets/d/15PcjKeHZIrB5LfDyuw7DGEEE8XnQEX2aX8lm8qbxV8A/edit#gid=0)

### PyTorch Training Models

| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :x: | :x: | | |
| FullyConnected | :heavy_check_mark: | :heavy_check_mark: | | |

</details>

<details>
<summary>JAX Models</summary>

### JAX Models

| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| DALL-E | :x: | :x: | | |
| FullyConnected | :heavy_check_mark: | :heavy_check_mark: | | |

</details>

<details>
<summary>TFLite Models</summary>

### TFLite Models

| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :x: | :x: | | |
| FullyConnected | :heavy_check_mark: | :heavy_check_mark: | | |

</details>

<details>
<summary>TF Models</summary>

### Tensorflow Models

| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|----------|----------|-------------|
| BERT | :x: | :x: | | |
| FullyConnected | :heavy_check_mark: | :heavy_check_mark: | | |

</details>

## Related Projects


<details>
<summary>IREE Project Channels</summary>

@@ -244,7 +243,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
* [iree-discuss email list](https://groups.google.com/forum/#!forum/iree-discuss):
Announcements, general and low-priority discussion
</details>


<details>
<summary>MLIR and Torch-MLIR Project Channels</summary>
@@ -6,16 +6,16 @@ parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
required=True,
|
||||
help='Specifies name of HF model to benchmark. (For exmaple "microsoft/MiniLM-L12-H384-uncased"',
|
||||
help=
|
||||
"Specifies name of HF model to benchmark. (For exmaple \"microsoft/MiniLM-L12-H384-uncased\""
|
||||
)
|
||||
load_args, unknown = parser.parse_known_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_name = load_args.model_name
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
shark_module = SharkHFBenchmarkRunner(
|
||||
model_name, (test_input,), jit_trace=True
|
||||
)
|
||||
shark_module = SharkHFBenchmarkRunner(model_name, (test_input,),
|
||||
jit_trace=True)
|
||||
shark_module.benchmark_c()
|
||||
shark_module.benchmark_python((test_input,))
|
||||
shark_module.benchmark_torch(test_input)
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import torch
|
||||
from shark.shark_benchmark_runner import SharkBenchmarkRunner
|
||||
from shark.shark_runner import SharkBenchmarkRunner
|
||||
from shark.parser import shark_args
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from onnxruntime.transformers.benchmark import (
|
||||
run_pytorch,
|
||||
run_tensorflow,
|
||||
run_onnxruntime,
|
||||
)
|
||||
from onnxruntime.transformers.benchmark import run_pytorch, run_tensorflow, run_onnxruntime
|
||||
from onnxruntime.transformers.huggingface_models import MODELS
|
||||
from onnxruntime.transformers.benchmark_helper import ConfigModifier, Precision
|
||||
import os
|
||||
@@ -14,6 +10,7 @@ import psutil
|
||||
|
||||
|
||||
class OnnxFusionOptions(object):
|
||||
|
||||
def __init__(self):
|
||||
self.disable_gelu = False
|
||||
self.disable_layer_norm = False
|
||||
@@ -28,13 +25,17 @@ class OnnxFusionOptions(object):
|
||||
|
||||
|
||||
class HuggingFaceLanguage(torch.nn.Module):
|
||||
|
||||
def __init__(self, hf_model_name):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
hf_model_name, # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
num_labels=
|
||||
2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=
|
||||
False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=
|
||||
False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
@@ -61,16 +62,8 @@ class SharkHFBenchmarkRunner(SharkBenchmarkRunner):
|
||||
)
|
||||
self.model_name = model_name
|
||||
model = HuggingFaceLanguage(model_name)
|
||||
SharkBenchmarkRunner.__init__(
|
||||
self,
|
||||
model,
|
||||
input,
|
||||
dynamic,
|
||||
self.device,
|
||||
jit_trace,
|
||||
from_aot,
|
||||
frontend,
|
||||
)
|
||||
SharkBenchmarkRunner.__init__(self, model, input, dynamic, self.device,
|
||||
jit_trace, from_aot, frontend)
|
||||
|
||||
def benchmark_torch(self, inputs):
|
||||
use_gpu = self.device == "gpu"
|
||||
@@ -81,20 +74,10 @@ class SharkHFBenchmarkRunner(SharkBenchmarkRunner):
|
||||
sequence_lengths = [inputs.shape[-1]]
|
||||
cache_dir = os.path.join(".", "cache_models")
|
||||
verbose = False
|
||||
result = run_pytorch(
|
||||
use_gpu,
|
||||
[self.model_name],
|
||||
None,
|
||||
config_modifier,
|
||||
Precision.FLOAT32,
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
False,
|
||||
cache_dir,
|
||||
verbose,
|
||||
)
|
||||
result = run_pytorch(use_gpu, [self.model_name], None, config_modifier,
|
||||
Precision.FLOAT32, num_threads, batch_sizes,
|
||||
sequence_lengths, shark_args.num_iterations, False,
|
||||
cache_dir, verbose)
|
||||
print(
|
||||
f"ONNX Pytorch-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
@@ -109,19 +92,10 @@ class SharkHFBenchmarkRunner(SharkBenchmarkRunner):
|
||||
sequence_lengths = [inputs.shape[-1]]
|
||||
cache_dir = os.path.join(".", "cache_models")
|
||||
verbose = False
|
||||
result = run_tensorflow(
|
||||
use_gpu,
|
||||
[self.model_name],
|
||||
None,
|
||||
config_modifier,
|
||||
Precision.FLOAT32,
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
cache_dir,
|
||||
verbose,
|
||||
)
|
||||
result = run_tensorflow(use_gpu, [self.model_name], None,
|
||||
config_modifier, Precision.FLOAT32, num_threads,
|
||||
batch_sizes, sequence_lengths,
|
||||
shark_args.num_iterations, cache_dir, verbose)
|
||||
print(
|
||||
f"ONNX TF-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
@@ -131,8 +105,7 @@ class SharkHFBenchmarkRunner(SharkBenchmarkRunner):
|
||||
print(
|
||||
f"{self.model_name} is currently not supported in ORT's HF. Check \
|
||||
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/huggingface_models.py \
|
||||
for currently supported models. Exiting benchmark ONNX."
|
||||
)
|
||||
for currently supported models. Exiting benchmark ONNX.")
|
||||
return
|
||||
use_gpu = self.device == "gpu"
|
||||
num_threads = psutil.cpu_count(logical=False)
|
||||
@@ -148,34 +121,17 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
use_raw_attention_mask = True
|
||||
model_fusion_statistics = {}
|
||||
overwrite = False
|
||||
model_source = "pt" # Either "pt" or "tf"
|
||||
model_source = "pt" #Either "pt" or "tf"
|
||||
provider = None
|
||||
config_modifier = ConfigModifier(None)
|
||||
onnx_args = OnnxFusionOptions()
|
||||
result = run_onnxruntime(
|
||||
use_gpu,
|
||||
provider,
|
||||
[self.model_name],
|
||||
None,
|
||||
config_modifier,
|
||||
Precision.FLOAT32,
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
input_counts,
|
||||
optimize_onnx,
|
||||
validate_onnx,
|
||||
cache_dir,
|
||||
onnx_dir,
|
||||
verbose,
|
||||
overwrite,
|
||||
disable_ort_io_binding,
|
||||
use_raw_attention_mask,
|
||||
model_fusion_statistics,
|
||||
model_source,
|
||||
onnx_args,
|
||||
)
|
||||
use_gpu, provider, [self.model_name], None, config_modifier,
|
||||
Precision.FLOAT32, num_threads, batch_sizes, sequence_lengths,
|
||||
shark_args.num_iterations, input_counts, optimize_onnx,
|
||||
validate_onnx, cache_dir, onnx_dir, verbose, overwrite,
|
||||
disable_ort_io_binding, use_raw_attention_mask,
|
||||
model_fusion_statistics, model_source, onnx_args)
|
||||
print(
|
||||
f"ONNX ORT-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils import check_device_drivers
|
||||
|
||||
import torch
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import torchvision.models as models
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BertTokenizer,
|
||||
TFBertModel,
|
||||
)
|
||||
from transformers import AutoModelForSequenceClassification, BertTokenizer, TFBertModel
|
||||
import importlib
|
||||
import pytest
|
||||
import unittest
|
||||
|
||||
torch.manual_seed(0)
|
||||
gpus = tf.config.experimental.list_physical_devices("GPU")
|
||||
gpus = tf.config.experimental.list_physical_devices('GPU')
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
|
||||
##################### Tensorflow Hugging Face LM Models ###################################
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
@@ -27,11 +23,12 @@ BATCH_SIZE = 1
|
||||
tf_bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32)
|
||||
]
|
||||
|
||||
|
||||
class TFHuggingFaceLanguage(tf.Module):
|
||||
|
||||
def __init__(self, hf_model_name):
|
||||
super(TFHuggingFaceLanguage, self).__init__()
|
||||
# Create a BERT trainer with the created network.
|
||||
@@ -39,8 +36,7 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m.predict = lambda x, y, z: self.m.call(
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False)
|
||||
|
||||
@tf.function(input_signature=tf_bert_input)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
@@ -51,21 +47,15 @@ def get_TFhf_model(name):
|
||||
model = TFHuggingFaceLanguage(name)
|
||||
tokenizer = BertTokenizer.from_pretrained(name)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
encoded_input = tokenizer(text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
test_input = (
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
)
|
||||
tf.convert_to_tensor(encoded_input[key]), 0)
|
||||
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"])
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
@@ -74,13 +64,17 @@ def get_TFhf_model(name):
|
||||
|
||||
|
||||
class HuggingFaceLanguage(torch.nn.Module):
|
||||
|
||||
def __init__(self, hf_model_name):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
hf_model_name, # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
num_labels=
|
||||
2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=
|
||||
False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=
|
||||
False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
@@ -102,6 +96,7 @@ def get_hf_model(name):
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
@@ -122,56 +117,46 @@ def get_vision_model(torch_model):
|
||||
############################# Benchmark Tests ####################################
|
||||
|
||||
pytest_benchmark_param = pytest.mark.parametrize(
|
||||
("dynamic", "device"),
|
||||
('dynamic', 'device'),
|
||||
[
|
||||
pytest.param(False, "cpu"),
|
||||
pytest.param(False, 'cpu'),
|
||||
# TODO: Language models are failing for dynamic case..
|
||||
pytest.param(True, "cpu", marks=pytest.mark.skip),
|
||||
pytest.param(True, 'cpu', marks=pytest.mark.skip),
|
||||
pytest.param(False,
|
||||
'gpu',
|
||||
marks=pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")),
|
||||
pytest.param(True,
|
||||
'gpu',
|
||||
marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"gpu",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
),
|
||||
),
|
||||
pytest.param(True, "gpu", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"vulkan",
|
||||
'vulkan',
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
),
|
||||
),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)),
|
||||
pytest.param(
|
||||
True,
|
||||
"vulkan",
|
||||
'vulkan',
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)),
|
||||
])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("iree.tools") is None,
|
||||
reason="Cannot find tools to import TF",
|
||||
)
|
||||
@pytest.mark.skipif(importlib.util.find_spec("iree.tools") is None,
|
||||
reason="Cannot find tools to import TF")
|
||||
@pytest_benchmark_param
|
||||
def test_bench_minilm_torch(dynamic, device):
|
||||
model, test_input, act_out = get_hf_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(test_input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
"microsoft/MiniLM-L12-H384-uncased")
|
||||
shark_module = SharkInference(model, (test_input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True)
|
||||
try:
|
||||
# If becnhmarking succesful, assert success/True.
|
||||
shark_module.compile()
|
||||
@@ -182,21 +167,17 @@ def test_bench_minilm_torch(dynamic, device):
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("iree.tools") is None,
|
||||
reason="Cannot find tools to import TF",
|
||||
)
|
||||
@pytest.mark.skipif(importlib.util.find_spec("iree.tools") is None,
|
||||
reason="Cannot find tools to import TF")
|
||||
@pytest_benchmark_param
|
||||
def test_bench_distilbert(dynamic, device):
|
||||
model, test_input, act_out = get_TFhf_model("distilbert-base-uncased")
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
test_input,
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
shark_module = SharkInference(model,
|
||||
test_input,
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True)
|
||||
try:
|
||||
# If benchmarking is successful, assert success/True.
|
||||
shark_module.set_frontend("tensorflow")
|
||||
@@ -212,14 +193,12 @@ def test_bench_distilbert(dynamic, device):
|
||||
@pytest_benchmark_param
|
||||
def test_bench_xlm_roberta(dynamic, device):
|
||||
model, test_input, act_out = get_TFhf_model("xlm-roberta-base")
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
test_input,
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
shark_module = SharkInference(model,
|
||||
test_input,
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True)
|
||||
try:
|
||||
# If benchmarking is successful, assert success/True.
|
||||
shark_module.set_frontend("tensorflow")
|
||||
|
||||
@@ -9,31 +9,25 @@ torch.manual_seed(0)
|
||||
|
||||
# Test running benchmark module without failing.
|
||||
pytest_benchmark_param = pytest.mark.parametrize(
|
||||
("dynamic", "device"),
|
||||
('dynamic', 'device'),
|
||||
[
|
||||
pytest.param(False, "cpu"),
|
||||
pytest.param(False, 'cpu'),
|
||||
# TODO: Language models are failing for the dynamic case.
|
||||
pytest.param(True, "cpu", marks=pytest.mark.skip),
|
||||
],
|
||||
)
|
||||
pytest.param(True, 'cpu', marks=pytest.mark.skip),
|
||||
])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("onnxruntime") is None,
|
||||
reason="Cannot find ONNXRUNTIME.",
|
||||
)
|
||||
@pytest.mark.skipif(importlib.util.find_spec("onnxruntime") is None,
|
||||
reason="Cannot find ONNXRUNTIME.")
|
||||
@pytest_benchmark_param
|
||||
def test_HFbench_minilm_torch(dynamic, device):
|
||||
model_name = "bert-base-uncased"
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
try:
|
||||
shark_module = SharkHFBenchmarkRunner(
|
||||
model_name,
|
||||
(test_input,),
|
||||
jit_trace=True,
|
||||
dynamic=dynamic,
|
||||
device=device,
|
||||
)
|
||||
shark_module = SharkHFBenchmarkRunner(model_name, (test_input,),
|
||||
jit_trace=True,
|
||||
dynamic=dynamic,
|
||||
device=device)
|
||||
shark_module.benchmark_c()
|
||||
shark_module.benchmark_python((test_input,))
|
||||
shark_module.benchmark_torch(test_input)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
IMPORTER=1 ./setup_venv.sh
|
||||
source $GITHUB_WORKSPACE/shark.venv/bin/activate
|
||||
python generate_sharktank.py --upload=False --ci_tank_dir=True
|
||||
@@ -1,37 +0,0 @@
|
||||
"""Scrapes the github releases API to generate a static pip-install-able releases page.
|
||||
|
||||
See https://github.com/llvm/torch-mlir/issues/1374
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
# Parse arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("owner", type=str)
|
||||
parser.add_argument("repo", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get releases
|
||||
response = requests.get(
|
||||
f"https://api.github.com/repos/{args.owner}/{args.repo}/releases"
|
||||
)
|
||||
body = json.loads(response.content)
|
||||
|
||||
# Parse releases
|
||||
releases = []
|
||||
for row in body:
|
||||
for asset in row["assets"]:
|
||||
releases.append((asset["name"], asset["browser_download_url"]))
|
||||
|
||||
# Output HTML
|
||||
html = """<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
"""
|
||||
for name, url in releases:
|
||||
html += f" <a href='{url}'>{name}</a><br />\n"
|
||||
html += """ </body>
|
||||
</html>"""
|
||||
print(html)
|
||||
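The script above issues a single request to the releases endpoint. Note that the GitHub API paginates releases (30 per page by default), so a repository with many releases would need pagination; a rough sketch of that extension, not part of the original script:

```python
# Sketch only: fetch every page of releases instead of just the first one.
import requests

def fetch_all_releases(owner: str, repo: str):
    releases, page = [], 1
    while True:
        resp = requests.get(
            f"https://api.github.com/repos/{owner}/{repo}/releases",
            params={"per_page": 100, "page": page},
        )
        batch = resp.json()
        if not batch:
            return releases
        releases.extend(batch)
        page += 1
```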
62
conftest.py
@@ -1,62 +0,0 @@
|
||||
def pytest_addoption(parser):
|
||||
# Attaches SHARK command-line arguments to the pytest machinery.
|
||||
parser.addoption(
|
||||
"--benchmark",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Pass option to benchmark and write results.csv",
|
||||
)
|
||||
parser.addoption(
|
||||
"--onnx_bench",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Add ONNX benchmark results to pytest benchmarks.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--tf32",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Use TensorFloat-32 calculations.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--save_repro",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Pass option to save reproduction artifacts to SHARK/shark_tmp/test_case/",
|
||||
)
|
||||
parser.addoption(
|
||||
"--save_fails",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Save reproduction artifacts for a test case only if it fails. Default is False.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--ci",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Enables uploading of reproduction artifacts upon test case failure during iree-compile or validation. Must be passed with --ci_sha option ",
|
||||
)
|
||||
parser.addoption(
|
||||
"--update_tank",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Update local shark tank with latest artifacts.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--ci_sha",
|
||||
action="store",
|
||||
default="None",
|
||||
help="Passes the github SHA of the CI workflow to include in google storage directory for reproduction artifacts.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--local_tank_cache",
|
||||
action="store",
|
||||
default="",
|
||||
help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--tank_url",
|
||||
type=str,
|
||||
default="gs://shark_tank/latest",
|
||||
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
|
||||
)
|
||||
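A side note on these options: with `action="store_true"`, one would normally expect a boolean default of `False`; the string `"False"` used above is a non-empty string and therefore truthy. A hypothetical example (not part of the repo) of how a fixture might consume one of these options:

```python
# Illustrative only: reading a SHARK command-line option from a test.
import pytest

@pytest.fixture
def benchmark_enabled(request):
    # getoption() returns True when --benchmark was passed on the command line;
    # otherwise it returns the default declared in pytest_addoption().
    return request.config.getoption("--benchmark")
```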
3
cpp/.gitignore
vendored
@@ -1,3 +0,0 @@
|
||||
*.mlir
|
||||
*.vmfb
|
||||
*.ini
|
||||
@@ -1,52 +0,0 @@
|
||||
# Copyright 2022 The IREE Authors
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
cmake_minimum_required(VERSION 3.21...3.23)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Project configuration
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
project(iree-samples C CXX)
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Core project dependency
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
message(STATUS "Fetching core IREE repo (this may take a few minutes)...")
|
||||
# Note: for log output, set -DFETCHCONTENT_QUIET=OFF,
|
||||
# see https://gitlab.kitware.com/cmake/cmake/-/issues/18238#note_440475
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
iree
|
||||
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
|
||||
GIT_TAG shark
|
||||
GIT_SUBMODULES_RECURSE OFF
|
||||
GIT_SHALLOW OFF
|
||||
GIT_PROGRESS ON
|
||||
USES_TERMINAL_DOWNLOAD ON
|
||||
)
|
||||
|
||||
# Extend module path to find MLIR CMake modules.
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_BINARY_DIR}/lib/cmake/mlir")
|
||||
|
||||
# Disable core project features not needed for these out of tree samples.
|
||||
set(IREE_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(IREE_BUILD_SAMPLES OFF CACHE BOOL "" FORCE)
|
||||
|
||||
FetchContent_MakeAvailable(iree)
|
||||
FetchContent_GetProperties(iree SOURCE_DIR IREE_SOURCE_DIR)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Individual samples
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
add_subdirectory(vulkan_gui)
|
||||
@@ -1,82 +0,0 @@
|
||||
# SHARK C/C++ Samples
|
||||
|
||||
These C/C++ samples can be built using CMake. The samples depend on the main
|
||||
SHARK-Runtime project's C/C++ sources, including both the runtime and the compiler.
|
||||
|
||||
Individual samples may require additional dependencies. Watch CMake's output
|
||||
for information about which dependencies you are missing for individual samples.
|
||||
|
||||
On Windows we recommend using https://github.com/microsoft/vcpkg to download packages for
|
||||
your system. The general setup flow looks like this:
|
||||
|
||||
*Install and activate SHARK*
|
||||
|
||||
```bash
|
||||
source shark.venv/bin/activate #follow main repo instructions to setup your venv
|
||||
```
|
||||
|
||||
*Install Dependencies*
|
||||
|
||||
```bash
|
||||
vcpkg install [library] --triplet [your platform]
|
||||
vcpkg integrate install
|
||||
|
||||
# Then pass `-DCMAKE_TOOLCHAIN_FILE=[check logs for path]` when configuring CMake
|
||||
```
|
||||
|
||||
On Ubuntu Linux you can install the SDL2 dependency with
|
||||
|
||||
```bash
|
||||
sudo apt install libsdl2-dev
|
||||
```
|
||||
|
||||
*Build*
|
||||
```bash
|
||||
cd cpp
|
||||
cmake -GNinja -B build/
|
||||
cmake --build build/
|
||||
```
|
||||
|
||||
*Prepare the model*
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvm-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
```
|
||||
*Prepare the input*
|
||||
|
||||
```bash
|
||||
python save_img.py
|
||||
```
|
||||
Note that this requires TensorFlow, e.g.
|
||||
```bash
|
||||
python -m pip install tensorflow
|
||||
```
|
||||
|
||||
*Run the vulkan_gui*
|
||||
```bash
|
||||
./build/vulkan_gui/iree-samples-resnet-vulkan-gui
|
||||
```
|
||||
|
||||
## Other models
|
||||
A tool for benchmarking other models is built and can be invoked with a command like the following
|
||||
```bash
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=path/to/.vmfb --function_input=...
|
||||
```
|
||||
See `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation of the function input format. For example, the Stable Diffusion UNet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
The VAE and CLIP autoencoder are also available:
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
Binary file not shown.
|
@@ -1,19 +0,0 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
|
||||
|
||||
def load_and_preprocess_image(fname: str):
|
||||
image = tf.io.read_file(fname)
|
||||
image = tf.image.decode_image(image, channels=3)
|
||||
image = tf.image.resize(image, (224, 224))
|
||||
image = image[tf.newaxis, :]
|
||||
# preprocessing pipeline
|
||||
input_tensor = tf.keras.applications.resnet50.preprocess_input(image)
|
||||
return input_tensor
|
||||
|
||||
|
||||
data = load_and_preprocess_image("dog_imagenet.jpg").numpy()
|
||||
|
||||
data.tofile("dog.bin")
|
||||
@@ -1,84 +0,0 @@
|
||||
# Copyright 2022 The IREE Authors
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
if(NOT IREE_TARGET_BACKEND_LLVM_CPU OR
|
||||
NOT IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF)
|
||||
message(STATUS "Missing LLVM backend and/or embeddded elf loader, skipping vision_inference sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# vcpkg install stb
|
||||
# tested with version 2021-09-10
|
||||
find_package(Stb)
|
||||
if(NOT Stb_FOUND)
|
||||
message(STATUS "Could not find Stb, skipping vision inference sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Compile mnist.mlir to mnist.vmfb.
|
||||
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
|
||||
set(_COMPILE_ARGS)
|
||||
list(APPEND _COMPILE_ARGS "--iree-input-type=mhlo")
|
||||
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
|
||||
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
|
||||
list(APPEND _COMPILE_ARGS "-o")
|
||||
list(APPEND _COMPILE_ARGS "mnist.vmfb")
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb
|
||||
COMMAND ${_COMPILE_TOOL_EXECUTABLE} ${_COMPILE_ARGS}
|
||||
DEPENDS ${_COMPILE_TOOL_EXECUTABLE} "${IREE_SOURCE_DIR}/samples/models/mnist.mlir"
|
||||
)
|
||||
# Embed mnist.vmfb into a C file as mnist_bytecode_module_c.[h/c]
|
||||
set(_EMBED_DATA_EXECUTABLE $<TARGET_FILE:generate_embed_data>)
|
||||
set(_EMBED_ARGS)
|
||||
list(APPEND _EMBED_ARGS "--output_header=mnist_bytecode_module_c.h")
|
||||
list(APPEND _EMBED_ARGS "--output_impl=mnist_bytecode_module_c.c")
|
||||
list(APPEND _EMBED_ARGS "--identifier=iree_samples_vision_inference_mnist_bytecode_module")
|
||||
list(APPEND _EMBED_ARGS "--flatten")
|
||||
list(APPEND _EMBED_ARGS "${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb")
|
||||
add_custom_command(
|
||||
OUTPUT "mnist_bytecode_module_c.h" "mnist_bytecode_module_c.c"
|
||||
COMMAND ${_EMBED_DATA_EXECUTABLE} ${_EMBED_ARGS}
|
||||
DEPENDS ${_EMBED_DATA_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb
|
||||
)
|
||||
# Define a library target for mnist_bytecode_module_c.
|
||||
add_library(iree_samples_vision_inference_mnist_bytecode_module_c OBJECT)
|
||||
target_sources(iree_samples_vision_inference_mnist_bytecode_module_c
|
||||
PRIVATE
|
||||
mnist_bytecode_module_c.h
|
||||
mnist_bytecode_module_c.c
|
||||
)
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "iree-run-mnist-module")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
"image_util.h"
|
||||
"image_util.c"
|
||||
"iree-run-mnist-module.c"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "iree-run-mnist-module")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_include_directories(${_NAME} PRIVATE
|
||||
${Stb_INCLUDE_DIR}
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
iree_base_base
|
||||
iree_base_tracing
|
||||
iree_hal_hal
|
||||
iree_runtime_runtime
|
||||
iree_samples_vision_inference_mnist_bytecode_module_c
|
||||
)
|
||||
|
||||
# Define a target that copies the test image into the build directory.
|
||||
add_custom_target(iree_samples_vision_inference_test_image
|
||||
COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/mnist_test.png" "${CMAKE_CURRENT_BINARY_DIR}/mnist_test.png")
|
||||
add_dependencies(${_NAME} iree_samples_vision_inference_test_image)
|
||||
|
||||
message(STATUS "Configured vision_inference sample successfully")
|
||||
@@ -1,8 +0,0 @@
|
||||
# Vision Inference Sample (C code)
|
||||
|
||||
This sample demonstrates how to run an MNIST handwritten digit detection vision
|
||||
model on an image using IREE's C API.
|
||||
|
||||
A similar sample is implemented using a Python script and IREE's command line
|
||||
tools over in the primary iree repository at
|
||||
https://github.com/iree-org/iree/tree/main/samples/vision_inference
|
||||
@@ -1,224 +0,0 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
#include "image_util.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/base/tracing.h"
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
|
||||
const uint8_t* pixel_data, iree_host_size_t buffer_length,
|
||||
const float* input_range, iree_host_size_t range_length,
|
||||
float* out_buffer) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
if (range_length != 2) {
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"range defined as 2-element [min, max] array.");
|
||||
}
|
||||
float input_scale = fabsf(input_range[1] - input_range[0]) / 2.0f;
|
||||
float input_offset = (input_range[0] + input_range[1]) / 2.0f;
|
||||
const float kUint8Mean = 127.5f;
|
||||
for (int i = 0; i < buffer_length; ++i) {
|
||||
out_buffer[i] =
|
||||
(((float)(pixel_data[i])) - kUint8Mean) / kUint8Mean * input_scale +
|
||||
input_offset;
|
||||
}
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_ok_status();
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_load_pixel_data_impl(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length) {
|
||||
int img_dims[3];
|
||||
if (stbi_info(filename.data, img_dims, &(img_dims[1]), &(img_dims[2])) == 0) {
|
||||
return iree_make_status(IREE_STATUS_NOT_FOUND, "can't load image %.*s",
|
||||
(int)filename.size, filename.data);
|
||||
}
|
||||
if (!(element_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 ||
|
||||
element_type == IREE_HAL_ELEMENT_TYPE_SINT_8 ||
|
||||
element_type == IREE_HAL_ELEMENT_TYPE_UINT_8)) {
|
||||
char element_type_str[16];
|
||||
IREE_RETURN_IF_ERROR(iree_hal_format_element_type(
|
||||
element_type, sizeof(element_type_str), element_type_str, NULL));
|
||||
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
|
||||
"element type %s not supported", element_type_str);
|
||||
}
|
||||
switch (shape_rank) {
|
||||
case 2: { // Assume tensor <height x width>
|
||||
if (img_dims[2] != 1 || (shape[0] != img_dims[1]) ||
|
||||
(shape[1] != img_dims[0])) {
|
||||
return iree_make_status(
|
||||
IREE_STATUS_INVALID_ARGUMENT,
|
||||
"image size: %dx%dx%d, expected: %" PRIdim "x%" PRIdim, img_dims[0],
|
||||
img_dims[1], img_dims[2], shape[1], shape[0]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 3: { // Assume tensor <height x width x channel>
|
||||
if (shape[0] != img_dims[1] || shape[1] != img_dims[0] ||
|
||||
shape[2] != img_dims[2]) {
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"image size: %dx%dx%d, expected: %" PRIdim
|
||||
"x%" PRIdim "x%" PRIdim,
|
||||
img_dims[0], img_dims[1], img_dims[2], shape[1],
|
||||
shape[0], shape[2]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 4: { // Assume tensor <batch x height x width x channel>
|
||||
if (shape[1] != img_dims[1] || shape[2] != img_dims[0] ||
|
||||
shape[3] != img_dims[2]) {
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"image size: %dx%dx%d, expected: %" PRIdim
|
||||
"x%" PRIdim "x%" PRIdim,
|
||||
img_dims[0], img_dims[1], img_dims[2], shape[2],
|
||||
shape[1], shape[3]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return iree_make_status(
|
||||
IREE_STATUS_INVALID_ARGUMENT,
|
||||
"Input buffer shape rank %" PRIhsz " not supported", shape_rank);
|
||||
}
|
||||
// Drop the alpha channel if present.
|
||||
int req_ch = (img_dims[2] >= 3) ? 3 : 0;
|
||||
*out_pixel_data = stbi_load(filename.data, img_dims, &(img_dims[1]),
|
||||
&(img_dims[2]), req_ch);
|
||||
if (*out_pixel_data == NULL) {
|
||||
return iree_make_status(IREE_STATUS_NOT_FOUND, "can't load image %.*s",
|
||||
(int)filename.size, filename.data);
|
||||
}
|
||||
*out_buffer_length =
|
||||
img_dims[0] * img_dims[1] * (img_dims[2] > 3 ? 3 : img_dims[2]);
|
||||
return iree_ok_status();
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_load_pixel_data(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
iree_status_t result = iree_tools_utils_load_pixel_data_impl(
|
||||
filename, shape, shape_rank, element_type, out_pixel_data,
|
||||
out_buffer_length);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return result;
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, iree_hal_buffer_view_t** out_buffer_view) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
*out_buffer_view = NULL;
|
||||
if (element_type != IREE_HAL_ELEMENT_TYPE_SINT_8 &&
|
||||
element_type != IREE_HAL_ELEMENT_TYPE_UINT_8) {
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"element type should be i8 or u8");
|
||||
}
|
||||
|
||||
iree_status_t result;
|
||||
uint8_t* pixel_data = NULL;
|
||||
iree_host_size_t buffer_length;
|
||||
result = iree_tools_utils_load_pixel_data(
|
||||
filename, shape, shape_rank, element_type, &pixel_data, &buffer_length);
|
||||
if (iree_status_is_ok(result)) {
|
||||
iree_host_size_t element_byte =
|
||||
iree_hal_element_dense_byte_count(element_type);
|
||||
// SINT_8 and UINT_8 perform direct buffer wrap.
|
||||
result = iree_hal_buffer_view_allocate_buffer(
|
||||
allocator, shape_rank, shape, element_type,
|
||||
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
|
||||
(iree_hal_buffer_params_t){
|
||||
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
|
||||
.access = IREE_HAL_MEMORY_ACCESS_READ,
|
||||
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
|
||||
IREE_HAL_BUFFER_USAGE_TRANSFER,
|
||||
},
|
||||
iree_make_const_byte_span(pixel_data, element_byte * buffer_length),
|
||||
out_buffer_view);
|
||||
}
|
||||
stbi_image_free(pixel_data);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return result;
|
||||
}
|
||||
|
||||
typedef struct iree_tools_utils_buffer_view_load_params_t {
|
||||
const uint8_t* pixel_data;
|
||||
iree_host_size_t pixel_data_length;
|
||||
const float* input_range;
|
||||
iree_host_size_t input_range_length;
|
||||
} iree_tools_utils_buffer_view_load_params_t;
|
||||
static iree_status_t iree_tools_utils_buffer_view_load_image_rescaled(
|
||||
iree_hal_buffer_mapping_t* mapping, void* user_data) {
|
||||
iree_tools_utils_buffer_view_load_params_t* params =
|
||||
(iree_tools_utils_buffer_view_load_params_t*)user_data;
|
||||
return iree_tools_utils_pixel_rescaled_to_buffer(
|
||||
params->pixel_data, params->pixel_data_length, params->input_range,
|
||||
params->input_range_length, (float*)mapping->contents.data);
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image_rescaled(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, const float* input_range,
|
||||
iree_host_size_t input_range_length,
|
||||
iree_hal_buffer_view_t** out_buffer_view) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
*out_buffer_view = NULL;
|
||||
if (element_type != IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"element type should be f32");
|
||||
}
|
||||
|
||||
// Classic row-major image layout.
|
||||
iree_hal_encoding_type_t encoding_type =
|
||||
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
|
||||
|
||||
// Load pixel data from the file into a new host memory allocation (the only
|
||||
// interface stb_image provides). A real application would want to use the
|
||||
// generation callback to directly decode the image into the target mapped
|
||||
// device buffer.
|
||||
uint8_t* pixel_data = NULL;
|
||||
iree_host_size_t buffer_length = 0;
|
||||
IREE_RETURN_AND_END_ZONE_IF_ERROR(
|
||||
z0, iree_tools_utils_load_pixel_data(filename, shape, shape_rank,
|
||||
element_type, &pixel_data,
|
||||
&buffer_length));
|
||||
|
||||
iree_tools_utils_buffer_view_load_params_t params = {
|
||||
.pixel_data = pixel_data,
|
||||
.pixel_data_length = buffer_length,
|
||||
.input_range = input_range,
|
||||
.input_range_length = input_range_length,
|
||||
};
|
||||
iree_status_t status = iree_hal_buffer_view_generate_buffer(
|
||||
allocator, shape_rank, shape, element_type, encoding_type,
|
||||
(iree_hal_buffer_params_t){
|
||||
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
|
||||
IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
|
||||
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
|
||||
IREE_HAL_BUFFER_USAGE_TRANSFER |
|
||||
IREE_HAL_BUFFER_USAGE_MAPPING,
|
||||
},
|
||||
iree_tools_utils_buffer_view_load_image_rescaled, ¶ms,
|
||||
out_buffer_view);
|
||||
|
||||
stbi_image_free(pixel_data);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return status;
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
#ifndef IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
|
||||
#define IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
|
||||
|
||||
#include "iree/base/api.h"
|
||||
#include "iree/hal/api.h"
|
||||
#include "iree/hal/buffer_view.h"
|
||||
|
||||
#if __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// Loads the image at |filename| into |out_pixel_data| and sets
|
||||
// |out_buffer_length| to its length.
|
||||
//
|
||||
// The image dimensions must match the width, height, and channel in |shape|,
|
||||
// while 2 <= |shape_rank| <= 4 to match the image tensor format.
|
||||
//
|
||||
// The file must be in a format supported by stb_image.h.
|
||||
// The returned |out_pixel_data| buffer must be released by the caller.
|
||||
iree_status_t iree_tools_utils_load_pixel_data(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length);
|
||||
|
||||
// Parses the content of the image file |filename| into a HAL buffer view
|
||||
// |out_buffer_view|. |out_buffer_view| properties are defined by |shape|,
|
||||
// |shape_rank|, and |element_type|, while being allocated by |allocator|.
|
||||
//
|
||||
// The |element_type| has to be SINT_8 or UINT_8. For FLOAT_32, use
|
||||
// |iree_tools_utils_buffer_view_from_image_rescaled| instead.
|
||||
//
|
||||
// The returned |out_buffer_view| must be released by the caller.
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, iree_hal_buffer_view_t** out_buffer_view);
|
||||
|
||||
// Parses the content of the image file |filename| into a HAL buffer view
|
||||
// |out_buffer_view|. |out_buffer_view| properties are defined by |shape|,
|
||||
// |shape_rank|, and |element_type|, while being allocated by |allocator|.
|
||||
// The value in |out_buffer_view| is rescaled with |input_range|.
|
||||
//
|
||||
// The |element_type| has to be FLOAT_32. For SINT_8 or UINT_8, use
|
||||
// |iree_tools_utils_buffer_view_from_image| instead.
|
||||
//
|
||||
// The returned |out_buffer_view| must be released by the caller.
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image_rescaled(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, const float* input_range,
|
||||
iree_host_size_t input_range_length,
|
||||
iree_hal_buffer_view_t** out_buffer_view);
|
||||
|
||||
// Normalizes uint8_t |pixel_data| of size |pixel_count| into the float buffer
|
||||
// |out_buffer| with the range |input_range|.
|
||||
//
|
||||
// float32_x = (uint8_x - 127.5) / 127.5 * input_scale + input_offset, where
|
||||
// input_scale = abs(|input_range[1]| - |input_range[0]|) / 2
|
||||
// input_offset = (|input_range[0]| + |input_range[1]|) / 2
|
||||
//
|
||||
// |out_buffer| needs to be allocated before the call.
|
||||
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
|
||||
const uint8_t* pixel_data, iree_host_size_t pixel_count,
|
||||
const float* input_range, iree_host_size_t input_range_length,
|
||||
float* out_buffer);
|
||||
|
||||
#if __cplusplus
|
||||
}
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
|
||||
@@ -1,121 +0,0 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
// This sample uses image_util to load a hand-written image as an
|
||||
// iree_hal_buffer_view_t then passes it to the bytecode module built from
|
||||
// mnist.mlir on the CPU backend with the local-task driver.
|
||||
|
||||
#include <float.h>
|
||||
|
||||
#include "image_util.h"
|
||||
#include "iree/runtime/api.h"
|
||||
#include "mnist_bytecode_module_c.h"
|
||||
|
||||
iree_status_t Run(const iree_string_view_t image_path) {
|
||||
iree_runtime_instance_options_t instance_options;
|
||||
iree_runtime_instance_options_initialize(IREE_API_VERSION_LATEST,
|
||||
&instance_options);
|
||||
iree_runtime_instance_options_use_all_available_drivers(&instance_options);
|
||||
iree_runtime_instance_t* instance = NULL;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_instance_create(
|
||||
&instance_options, iree_allocator_system(), &instance));
|
||||
|
||||
// TODO(#5724): move device selection into the compiled modules.
|
||||
iree_hal_device_t* device = NULL;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device(
|
||||
instance, iree_make_cstring_view("local-task"), &device));
|
||||
|
||||
// Create one session per loaded module to hold the module state.
|
||||
iree_runtime_session_options_t session_options;
|
||||
iree_runtime_session_options_initialize(&session_options);
|
||||
iree_runtime_session_t* session = NULL;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_session_create_with_device(
|
||||
instance, &session_options, device,
|
||||
iree_runtime_instance_host_allocator(instance), &session));
|
||||
iree_hal_device_release(device);
|
||||
|
||||
const struct iree_file_toc_t* module_file =
|
||||
iree_samples_vision_inference_mnist_bytecode_module_create();
|
||||
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_session_append_bytecode_module_from_memory(
|
||||
session, iree_make_const_byte_span(module_file->data, module_file->size),
|
||||
iree_allocator_null()));
|
||||
|
||||
iree_runtime_call_t call;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(
|
||||
session, iree_make_cstring_view("module.predict"), &call));
|
||||
|
||||
// Prepare the input hal buffer view with image_util library.
|
||||
// The input of the mnist model is a single 28x28 pixel image as a
|
||||
// tensor<1x28x28x1xf32>, with pixels in [0.0, 1.0].
|
||||
iree_hal_buffer_view_t* buffer_view = NULL;
|
||||
iree_hal_dim_t buffer_shape[] = {1, 28, 28, 1};
|
||||
iree_hal_element_type_t hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32;
|
||||
float input_range[2] = {0.0f, 1.0f};
|
||||
IREE_RETURN_IF_ERROR(
|
||||
iree_tools_utils_buffer_view_from_image_rescaled(
|
||||
image_path, buffer_shape, IREE_ARRAYSIZE(buffer_shape),
|
||||
hal_element_type, iree_hal_device_allocator(device), input_range,
|
||||
IREE_ARRAYSIZE(input_range), &buffer_view),
|
||||
"load image");
|
||||
IREE_RETURN_IF_ERROR(
|
||||
iree_runtime_call_inputs_push_back_buffer_view(&call, buffer_view));
|
||||
iree_hal_buffer_view_release(buffer_view);
|
||||
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_call_invoke(&call, /*flags=*/0));
|
||||
|
||||
// Get the result buffers from the invocation.
|
||||
iree_hal_buffer_view_t* ret_buffer_view = NULL;
|
||||
IREE_RETURN_IF_ERROR(
|
||||
iree_runtime_call_outputs_pop_front_buffer_view(&call, &ret_buffer_view));
|
||||
|
||||
// Read back the results. The output of the mnist model is a 1x10 tensor of
|
||||
// prediction confidence values, one for each digit in [0, 9].
|
||||
float predictions[1 * 10] = {0.0f};
|
||||
IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
|
||||
iree_runtime_session_device(session),
|
||||
iree_hal_buffer_view_buffer(ret_buffer_view), 0, predictions,
|
||||
sizeof(predictions), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
|
||||
iree_infinite_timeout()));
|
||||
iree_hal_buffer_view_release(ret_buffer_view);
|
||||
|
||||
// Get the highest index from the output.
|
||||
float result_val = FLT_MIN;
|
||||
int result_idx = 0;
|
||||
for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(predictions); ++i) {
|
||||
if (predictions[i] > result_val) {
|
||||
result_val = predictions[i];
|
||||
result_idx = i;
|
||||
}
|
||||
}
|
||||
fprintf(stdout, "Detected number: %d\n", result_idx);
|
||||
|
||||
iree_runtime_call_deinitialize(&call);
|
||||
iree_runtime_session_release(session);
|
||||
iree_runtime_instance_release(instance);
|
||||
return iree_ok_status();
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc > 2) {
|
||||
fprintf(stderr, "Usage: iree-run-mnist-module <image file>\n");
|
||||
return -1;
|
||||
}
|
||||
iree_string_view_t image_path;
|
||||
if (argc == 1) {
|
||||
image_path = iree_make_cstring_view("mnist_test.png");
|
||||
} else {
|
||||
image_path = iree_make_cstring_view(argv[1]);
|
||||
}
|
||||
iree_status_t result = Run(image_path);
|
||||
if (!iree_status_is_ok(result)) {
|
||||
iree_status_fprint(stderr, result);
|
||||
iree_status_ignore(result);
|
||||
return -1;
|
||||
}
|
||||
iree_status_ignore(result);
|
||||
return 0;
|
||||
}
|
||||
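For reference, the pre- and post-processing steps this sample performs map to a few lines of NumPy. The sketch below is illustrative only and is not part of the sample:

```python
# Scale a 28x28 grayscale image into [0.0, 1.0] and take the argmax of the 1x10 output.
import numpy as np

def preprocess(pixels_u8: np.ndarray) -> np.ndarray:
    # pixels_u8: uint8 array of shape (28, 28); the model expects 1x28x28x1 float32.
    # Dividing by 255 is equivalent to the (x - 127.5) / 127.5 rescale into [0, 1] above.
    return (pixels_u8.astype(np.float32) / 255.0).reshape(1, 28, 28, 1)

def postprocess(predictions: np.ndarray) -> int:
    # predictions: float32 array of shape (1, 10) of per-digit confidences.
    return int(np.argmax(predictions))
```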
Binary file not shown.
|
@@ -1,116 +0,0 @@
|
||||
# Copyright 2022 The IREE Authors
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
if(NOT IREE_TARGET_BACKEND_VULKAN_SPIRV OR
|
||||
NOT IREE_HAL_DRIVER_VULKAN)
|
||||
message(STATUS "Missing Vulkan backend and/or driver, skipping vulkan_gui sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# This target statically links against Vulkan.
|
||||
# One way to achieve this is by installing the Vulkan SDK from
|
||||
# https://vulkan.lunarg.com/.
|
||||
include(FindVulkan)
|
||||
if(NOT Vulkan_FOUND)
|
||||
message(STATUS "Could not find Vulkan, skipping vulkan_gui sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# vcpkg install sdl2[vulkan]
|
||||
# tested with versions 2.0.14#4 - 2.0.22#1
|
||||
find_package(SDL2)
|
||||
if(NOT SDL2_FOUND)
|
||||
message(STATUS "Could not find SDL2, skipping vulkan_gui sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
imgui
|
||||
GIT_REPOSITORY https://github.com/ocornut/imgui
|
||||
GIT_TAG master
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(imgui)
|
||||
|
||||
# Dear ImGui
|
||||
set(IMGUI_DIR ${CMAKE_BINARY_DIR}/_deps/imgui-src)
|
||||
message("Looking for Imgui in ${IMGUI_DIR}")
|
||||
include_directories(${IMGUI_DIR} ${IMGUI_DIR}/backends ..)
|
||||
|
||||
|
||||
function(iree_vulkan_sample)
|
||||
|
||||
cmake_parse_arguments(
|
||||
_RULE
|
||||
""
|
||||
"NAME"
|
||||
"SRCS"
|
||||
${ARGN}
|
||||
)
|
||||
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "${_RULE_NAME}")
|
||||
set(SRCS "${_RULE_SRCS}")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
${SRCS}
|
||||
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
|
||||
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
|
||||
"${IMGUI_DIR}/imgui.cpp"
|
||||
"${IMGUI_DIR}/imgui_draw.cpp"
|
||||
"${IMGUI_DIR}/imgui_demo.cpp"
|
||||
"${IMGUI_DIR}/imgui_tables.cpp"
|
||||
"${IMGUI_DIR}/imgui_widgets.cpp"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_NAME}")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
SDL2::SDL2
|
||||
Vulkan::Vulkan
|
||||
iree_runtime_runtime
|
||||
iree_base_internal_main
|
||||
iree_hal_drivers_vulkan_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_vm_cc
|
||||
iree_tooling_vm_util_cc
|
||||
iree_tooling_context_util
|
||||
)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
|
||||
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
|
||||
else()
|
||||
set(_GUI_LINKOPTS "")
|
||||
endif()
|
||||
|
||||
target_link_options(${_NAME}
|
||||
PRIVATE
|
||||
${_GUI_LINKOPTS}
|
||||
)
|
||||
endfunction()
|
||||
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-samples-resnet-vulkan-gui
|
||||
|
||||
SRCS
|
||||
vulkan_resnet_inference_gui.cc
|
||||
)
|
||||
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-vulkan-gui
|
||||
|
||||
SRCS
|
||||
vulkan_inference_gui.cc
|
||||
)
|
||||
|
||||
message(STATUS "Configured vulkan_gui sample successfully")
|
||||
@@ -1,4 +0,0 @@
|
||||
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
%0 = "arith.mulf"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
Binary file not shown.
|
File diff suppressed because it is too large
@@ -1,957 +0,0 @@
|
||||
// Copyright 2019 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
// Vulkan Graphics + IREE API Integration Sample.
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_vulkan.h>
|
||||
#include <imgui.h>
|
||||
#include <imgui_impl_sdl.h>
|
||||
#include <imgui_impl_vulkan.h>
|
||||
#include <vulkan/vulkan.h>
|
||||
|
||||
|
||||
#include <cstring>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <array>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "iree/hal/drivers/vulkan/api.h"
|
||||
|
||||
// IREE's C API:
|
||||
#include "iree/base/api.h"
|
||||
#include "iree/hal/api.h"
|
||||
#include "iree/hal/drivers/vulkan/registration/driver_module.h"
|
||||
#include "iree/modules/hal/module.h"
|
||||
#include "iree/vm/api.h"
|
||||
#include "iree/vm/bytecode_module.h"
|
||||
#include "iree/vm/ref_cc.h"
|
||||
|
||||
// iree-run-module
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/base/status_cc.h"
|
||||
#include "iree/base/tracing.h"
|
||||
#include "iree/modules/hal/types.h"
|
||||
#include "iree/tooling/comparison.h"
|
||||
#include "iree/tooling/context_util.h"
|
||||
#include "iree/tooling/vm_util_cc.h"
|
||||
|
||||
// Other dependencies (helpers, etc.)
|
||||
#include "iree/base/internal/main.h"
|
||||
|
||||
#define IMGUI_UNLIMITED_FRAME_RATE
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
IREE_FLAG(string, entry_function, "",
|
||||
"Name of a function contained in the module specified by module_file "
|
||||
"to run.");
|
||||
|
||||
// TODO(benvanik): move --function_input= flag into a util.
|
||||
static iree_status_t parse_function_io(iree_string_view_t flag_name,
|
||||
void* storage,
|
||||
iree_string_view_t value) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
list->push_back(std::string(value.data, value.size));
|
||||
return iree_ok_status();
|
||||
}
|
||||
static void print_function_io(iree_string_view_t flag_name, void* storage,
|
||||
FILE* file) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
if (list->empty()) {
|
||||
fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
|
||||
} else {
|
||||
for (size_t i = 0; i < list->size(); ++i) {
|
||||
fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
|
||||
list->at(i).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
static std::vector<std::string> FLAG_function_inputs;
|
||||
IREE_FLAG_CALLBACK(
|
||||
parse_function_io, print_function_io, &FLAG_function_inputs, function_input,
|
||||
"An input (a) value or (b) buffer of the format:\n"
|
||||
" (a) scalar value\n"
|
||||
" value\n"
|
||||
" e.g.: --function_input=\"3.14\"\n"
|
||||
" (b) buffer:\n"
|
||||
" [shape]xtype=[value]\n"
|
||||
" e.g.: --function_input=\"2x2xi32=1 2 3 4\"\n"
|
||||
"Optionally, brackets may be used to separate the element values:\n"
|
||||
" 2x2xi32=[[1 2][3 4]]\n"
|
||||
"Raw binary files can be read to provide buffer contents:\n"
|
||||
" 2x2xi32=@some/file.bin\n"
|
||||
"numpy npy files (from numpy.save) can be read to provide 1+ values:\n"
|
||||
" @some.npy\n"
|
||||
"Each occurrence of the flag indicates an input in the order they were\n"
|
||||
"specified on the command line.");
|
||||
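Since `--function_input=@some.npy` accepts files written by `numpy.save`, module inputs can be prepared with a couple of lines of Python. The file name and shape below are illustrative (the shape matches the unet example earlier in this page), not prescribed by the tool:

```python
# Hypothetical input preparation for --function_input=@input.npy.
import numpy as np

arg0 = np.ones((2, 4, 64, 64), dtype=np.float32)  # illustrative shape
np.save("input.npy", arg0)
```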
|
||||
typedef struct iree_file_toc_t {
|
||||
const char* name; // the file's original name
|
||||
char* data; // beginning of the file
|
||||
size_t size; // length of the file
|
||||
} iree_file_toc_t;
|
||||
|
||||
bool load_file(const char* filename, char** pOut, size_t* pSize)
|
||||
{
|
||||
FILE* f = fopen(filename, "rb");
|
||||
if (f == NULL)
|
||||
{
|
||||
fprintf(stderr, "Can't open %s\n", filename);
|
||||
return false;
|
||||
}
|
||||
|
||||
fseek(f, 0L, SEEK_END);
|
||||
*pSize = ftell(f);
|
||||
fseek(f, 0L, SEEK_SET);
|
||||
|
||||
*pOut = (char*)malloc(*pSize);
|
||||
|
||||
size_t size = fread(*pOut, *pSize, 1, f);
|
||||
|
||||
fclose(f);
|
||||
|
||||
return size != 0;
|
||||
}
|
||||
|
||||
static VkAllocationCallbacks* g_Allocator = NULL;
|
||||
static VkInstance g_Instance = VK_NULL_HANDLE;
|
||||
static VkPhysicalDevice g_PhysicalDevice = VK_NULL_HANDLE;
|
||||
static VkDevice g_Device = VK_NULL_HANDLE;
|
||||
static uint32_t g_QueueFamily = (uint32_t)-1;
|
||||
static VkQueue g_Queue = VK_NULL_HANDLE;
|
||||
static VkPipelineCache g_PipelineCache = VK_NULL_HANDLE;
|
||||
static VkDescriptorPool g_DescriptorPool = VK_NULL_HANDLE;
|
||||
|
||||
static ImGui_ImplVulkanH_Window g_MainWindowData;
|
||||
static uint32_t g_MinImageCount = 2;
|
||||
static bool g_SwapChainRebuild = false;
|
||||
static int g_SwapChainResizeWidth = 0;
|
||||
static int g_SwapChainResizeHeight = 0;
|
||||
|
||||
static void check_vk_result(VkResult err) {
|
||||
if (err == 0) return;
|
||||
fprintf(stderr, "VkResult: %d\n", err);
|
||||
abort();
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan layers used for the given IREE
|
||||
// |extensibility_set| and |features|.
|
||||
std::vector<const char*> GetIreeLayers(
|
||||
iree_hal_vulkan_extensibility_set_t extensibility_set,
|
||||
iree_hal_vulkan_features_t features) {
|
||||
iree_host_size_t required_count;
|
||||
iree_hal_vulkan_query_extensibility_set(
|
||||
features, extensibility_set, /*string_capacity=*/0, &required_count,
|
||||
/*out_string_values=*/NULL);
|
||||
std::vector<const char*> layers(required_count);
|
||||
iree_hal_vulkan_query_extensibility_set(features, extensibility_set,
|
||||
layers.size(), &required_count,
|
||||
layers.data());
|
||||
return layers;
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan extensions used for the given IREE
|
||||
// |extensibility_set| and |features|.
|
||||
std::vector<const char*> GetIreeExtensions(
|
||||
iree_hal_vulkan_extensibility_set_t extensibility_set,
|
||||
iree_hal_vulkan_features_t features) {
|
||||
iree_host_size_t required_count;
|
||||
iree_hal_vulkan_query_extensibility_set(
|
||||
features, extensibility_set, /*string_capacity=*/0, &required_count,
|
||||
/*out_string_values=*/NULL);
|
||||
std::vector<const char*> extensions(required_count);
|
||||
iree_hal_vulkan_query_extensibility_set(features, extensibility_set,
|
||||
extensions.size(), &required_count,
|
||||
extensions.data());
|
||||
return extensions;
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan extensions used for the given IREE
|
||||
// |vulkan_features|.
|
||||
std::vector<const char*> GetDeviceExtensions(
|
||||
VkPhysicalDevice physical_device,
|
||||
iree_hal_vulkan_features_t vulkan_features) {
|
||||
std::vector<const char*> iree_required_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED,
|
||||
vulkan_features);
|
||||
std::vector<const char*> iree_optional_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
|
||||
vulkan_features);
|
||||
|
||||
uint32_t extension_count = 0;
|
||||
check_vk_result(vkEnumerateDeviceExtensionProperties(
|
||||
physical_device, nullptr, &extension_count, nullptr));
|
||||
std::vector<VkExtensionProperties> extension_properties(extension_count);
|
||||
check_vk_result(vkEnumerateDeviceExtensionProperties(
|
||||
physical_device, nullptr, &extension_count, extension_properties.data()));
|
||||
|
||||
// Merge extensions lists, including optional and required for simplicity.
|
||||
std::set<const char*> ext_set;
|
||||
ext_set.insert("VK_KHR_swapchain");
|
||||
ext_set.insert(iree_required_extensions.begin(),
|
||||
iree_required_extensions.end());
|
||||
for (int i = 0; i < iree_optional_extensions.size(); ++i) {
|
||||
const char* optional_extension = iree_optional_extensions[i];
|
||||
for (int j = 0; j < extension_count; ++j) {
|
||||
if (strcmp(optional_extension, extension_properties[j].extensionName) ==
|
||||
0) {
|
||||
ext_set.insert(optional_extension);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<const char*> extensions(ext_set.begin(), ext_set.end());
|
||||
return extensions;
|
||||
}
|
||||
|
||||
std::vector<const char*> GetInstanceLayers(
|
||||
iree_hal_vulkan_features_t vulkan_features) {
|
||||
// Query the layers that IREE wants / needs.
|
||||
std::vector<const char*> required_layers = GetIreeLayers(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_REQUIRED, vulkan_features);
|
||||
std::vector<const char*> optional_layers = GetIreeLayers(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, vulkan_features);
|
||||
|
||||
// Query the layers that are available on the Vulkan ICD.
|
||||
uint32_t layer_property_count = 0;
|
||||
check_vk_result(
|
||||
vkEnumerateInstanceLayerProperties(&layer_property_count, NULL));
|
||||
std::vector<VkLayerProperties> layer_properties(layer_property_count);
|
||||
check_vk_result(vkEnumerateInstanceLayerProperties(&layer_property_count,
|
||||
layer_properties.data()));
|
||||
|
||||
// Match between optional/required and available layers.
|
||||
std::vector<const char*> layers;
|
||||
for (const char* layer_name : required_layers) {
|
||||
bool found = false;
|
||||
for (const auto& layer_property : layer_properties) {
|
||||
if (std::strcmp(layer_name, layer_property.layerName) == 0) {
|
||||
found = true;
|
||||
layers.push_back(layer_name);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
fprintf(stderr, "Required layer %s not available\n", layer_name);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
for (const char* layer_name : optional_layers) {
|
||||
for (const auto& layer_property : layer_properties) {
|
||||
if (std::strcmp(layer_name, layer_property.layerName) == 0) {
|
||||
layers.push_back(layer_name);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return layers;
|
||||
}
|
||||
|
||||
std::vector<const char*> GetInstanceExtensions(
|
||||
SDL_Window* window, iree_hal_vulkan_features_t vulkan_features) {
|
||||
// Ask SDL for its list of required instance extensions.
|
||||
uint32_t sdl_extensions_count = 0;
|
||||
SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count, NULL);
|
||||
std::vector<const char*> sdl_extensions(sdl_extensions_count);
|
||||
SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count,
|
||||
sdl_extensions.data());
|
||||
|
||||
std::vector<const char*> iree_required_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED,
|
||||
vulkan_features);
|
||||
std::vector<const char*> iree_optional_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL,
|
||||
vulkan_features);
|
||||
|
||||
// Merge extensions lists, including optional and required for simplicity.
|
||||
std::set<const char*> ext_set;
|
||||
ext_set.insert(sdl_extensions.begin(), sdl_extensions.end());
|
||||
ext_set.insert(iree_required_extensions.begin(),
|
||||
iree_required_extensions.end());
|
||||
ext_set.insert(iree_optional_extensions.begin(),
|
||||
iree_optional_extensions.end());
|
||||
std::vector<const char*> extensions(ext_set.begin(), ext_set.end());
|
||||
return extensions;
|
||||
}
|
||||
|
||||
void SetupVulkan(iree_hal_vulkan_features_t vulkan_features,
|
||||
const char** instance_layers, uint32_t instance_layers_count,
|
||||
const char** instance_extensions,
|
||||
uint32_t instance_extensions_count,
|
||||
const VkAllocationCallbacks* allocator, VkInstance* instance,
|
||||
uint32_t* queue_family_index,
|
||||
VkPhysicalDevice* physical_device, VkQueue* queue,
|
||||
VkDevice* device, VkDescriptorPool* descriptor_pool) {
|
||||
VkResult err;
|
||||
|
||||
// Create Vulkan Instance
|
||||
{
|
||||
VkInstanceCreateInfo create_info = {};
|
||||
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
|
||||
create_info.enabledLayerCount = instance_layers_count;
|
||||
create_info.ppEnabledLayerNames = instance_layers;
|
||||
create_info.enabledExtensionCount = instance_extensions_count;
|
||||
create_info.ppEnabledExtensionNames = instance_extensions;
|
||||
err = vkCreateInstance(&create_info, allocator, instance);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Select GPU
|
||||
{
|
||||
uint32_t gpu_count;
|
||||
err = vkEnumeratePhysicalDevices(*instance, &gpu_count, NULL);
|
||||
check_vk_result(err);
|
||||
IM_ASSERT(gpu_count > 0);
|
||||
|
||||
VkPhysicalDevice* gpus =
|
||||
(VkPhysicalDevice*)malloc(sizeof(VkPhysicalDevice) * gpu_count);
|
||||
err = vkEnumeratePhysicalDevices(*instance, &gpu_count, gpus);
|
||||
check_vk_result(err);
|
||||
|
||||
// Use the first reported GPU for simplicity.
|
||||
*physical_device = gpus[0];
|
||||
|
||||
VkPhysicalDeviceProperties properties;
|
||||
vkGetPhysicalDeviceProperties(*physical_device, &properties);
|
||||
fprintf(stdout, "Selected Vulkan device: '%s'\n", properties.deviceName);
|
||||
free(gpus);
|
||||
}
|
||||
|
||||
// Select queue family. We want a single queue with graphics and compute for
|
||||
// simplicity, but we could also discover and use separate queues for each.
|
||||
{
|
||||
uint32_t count;
|
||||
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, NULL);
|
||||
VkQueueFamilyProperties* queues = (VkQueueFamilyProperties*)malloc(
|
||||
sizeof(VkQueueFamilyProperties) * count);
|
||||
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, queues);
|
||||
for (uint32_t i = 0; i < count; i++) {
|
||||
if (queues[i].queueFlags &
|
||||
(VK_QUEUE_GRAPHICS_BIT | VK_QUEUE_COMPUTE_BIT)) {
|
||||
*queue_family_index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
free(queues);
|
||||
IM_ASSERT(*queue_family_index != (uint32_t)-1);
|
||||
}
|
||||
|
||||
// Create Logical Device (with 1 queue)
|
||||
{
|
||||
std::vector<const char*> device_extensions =
|
||||
GetDeviceExtensions(*physical_device, vulkan_features);
|
||||
const float queue_priority[] = {1.0f};
|
||||
VkDeviceQueueCreateInfo queue_info = {};
|
||||
queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
|
||||
queue_info.queueFamilyIndex = *queue_family_index;
|
||||
queue_info.queueCount = 1;
|
||||
queue_info.pQueuePriorities = queue_priority;
|
||||
VkDeviceCreateInfo create_info = {};
|
||||
create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
|
||||
create_info.queueCreateInfoCount = 1;
|
||||
create_info.pQueueCreateInfos = &queue_info;
|
||||
create_info.enabledExtensionCount =
|
||||
static_cast<uint32_t>(device_extensions.size());
|
||||
create_info.ppEnabledExtensionNames = device_extensions.data();
|
||||
|
||||
// Enable timeline semaphores.
|
||||
VkPhysicalDeviceFeatures2 features2;
|
||||
memset(&features2, 0, sizeof(features2));
|
||||
features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
|
||||
create_info.pNext = &features2;
|
||||
VkPhysicalDeviceTimelineSemaphoreFeatures semaphore_features;
|
||||
memset(&semaphore_features, 0, sizeof(semaphore_features));
|
||||
semaphore_features.sType =
|
||||
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_TIMELINE_SEMAPHORE_FEATURES;
|
||||
semaphore_features.pNext = features2.pNext;
|
||||
features2.pNext = &semaphore_features;
|
||||
semaphore_features.timelineSemaphore = VK_TRUE;
|
||||
|
||||
err = vkCreateDevice(*physical_device, &create_info, allocator, device);
|
||||
check_vk_result(err);
|
||||
vkGetDeviceQueue(*device, *queue_family_index, 0, queue);
|
||||
}
|
||||
|
||||
// Create Descriptor Pool
|
||||
{
|
||||
VkDescriptorPoolSize pool_sizes[] = {
|
||||
{VK_DESCRIPTOR_TYPE_SAMPLER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT, 1000}};
|
||||
VkDescriptorPoolCreateInfo pool_info = {};
|
||||
pool_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
|
||||
pool_info.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
|
||||
pool_info.maxSets = 1000 * IREE_ARRAYSIZE(pool_sizes);
|
||||
pool_info.poolSizeCount = (uint32_t)IREE_ARRAYSIZE(pool_sizes);
|
||||
pool_info.pPoolSizes = pool_sizes;
|
||||
err =
|
||||
vkCreateDescriptorPool(*device, &pool_info, allocator, descriptor_pool);
|
||||
check_vk_result(err);
|
||||
}
|
||||
}
|
||||
|
||||
void SetupVulkanWindow(ImGui_ImplVulkanH_Window* wd,
|
||||
const VkAllocationCallbacks* allocator,
|
||||
VkInstance instance, uint32_t queue_family_index,
|
||||
VkPhysicalDevice physical_device, VkDevice device,
|
||||
VkSurfaceKHR surface, int width, int height,
|
||||
uint32_t min_image_count) {
|
||||
wd->Surface = surface;
|
||||
|
||||
// Check for WSI support
|
||||
VkBool32 res;
|
||||
vkGetPhysicalDeviceSurfaceSupportKHR(physical_device, queue_family_index,
|
||||
wd->Surface, &res);
|
||||
if (res != VK_TRUE) {
|
||||
fprintf(stderr, "Error no WSI support on physical device 0\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Select Surface Format
|
||||
const VkFormat requestSurfaceImageFormat[] = {
|
||||
VK_FORMAT_B8G8R8A8_UNORM, VK_FORMAT_R8G8B8A8_UNORM,
|
||||
VK_FORMAT_B8G8R8_UNORM, VK_FORMAT_R8G8B8_UNORM};
|
||||
const VkColorSpaceKHR requestSurfaceColorSpace =
|
||||
VK_COLORSPACE_SRGB_NONLINEAR_KHR;
|
||||
wd->SurfaceFormat = ImGui_ImplVulkanH_SelectSurfaceFormat(
|
||||
physical_device, wd->Surface, requestSurfaceImageFormat,
|
||||
(size_t)IREE_ARRAYSIZE(requestSurfaceImageFormat),
|
||||
requestSurfaceColorSpace);
|
||||
|
||||
// Select Present Mode
|
||||
#ifdef IMGUI_UNLIMITED_FRAME_RATE
|
||||
VkPresentModeKHR present_modes[] = {VK_PRESENT_MODE_MAILBOX_KHR,
|
||||
VK_PRESENT_MODE_IMMEDIATE_KHR,
|
||||
VK_PRESENT_MODE_FIFO_KHR};
|
||||
#else
|
||||
VkPresentModeKHR present_modes[] = {VK_PRESENT_MODE_FIFO_KHR};
|
||||
#endif
|
||||
wd->PresentMode = ImGui_ImplVulkanH_SelectPresentMode(
|
||||
physical_device, wd->Surface, &present_modes[0],
|
||||
IREE_ARRAYSIZE(present_modes));
|
||||
|
||||
// Create SwapChain, RenderPass, Framebuffer, etc.
|
||||
IM_ASSERT(min_image_count >= 2);
|
||||
ImGui_ImplVulkanH_CreateOrResizeWindow(instance, physical_device, device, wd,
|
||||
queue_family_index, allocator, width,
|
||||
height, min_image_count);
|
||||
|
||||
// Set clear color.
|
||||
ImVec4 clear_color = ImVec4(0.45f, 0.55f, 0.60f, 1.00f);
|
||||
memcpy(&wd->ClearValue.color.float32[0], &clear_color, 4 * sizeof(float));
|
||||
}
|
||||
|
||||
void RenderFrame(ImGui_ImplVulkanH_Window* wd, VkDevice device, VkQueue queue) {
|
||||
VkResult err;
|
||||
|
||||
VkSemaphore image_acquired_semaphore =
|
||||
wd->FrameSemaphores[wd->SemaphoreIndex].ImageAcquiredSemaphore;
|
||||
VkSemaphore render_complete_semaphore =
|
||||
wd->FrameSemaphores[wd->SemaphoreIndex].RenderCompleteSemaphore;
|
||||
err = vkAcquireNextImageKHR(device, wd->Swapchain, UINT64_MAX,
|
||||
image_acquired_semaphore, VK_NULL_HANDLE,
|
||||
&wd->FrameIndex);
|
||||
check_vk_result(err);
|
||||
|
||||
ImGui_ImplVulkanH_Frame* fd = &wd->Frames[wd->FrameIndex];
|
||||
{
|
||||
err = vkWaitForFences(
|
||||
device, 1, &fd->Fence, VK_TRUE,
|
||||
UINT64_MAX); // wait indefinitely instead of periodically checking
|
||||
check_vk_result(err);
|
||||
|
||||
err = vkResetFences(device, 1, &fd->Fence);
|
||||
check_vk_result(err);
|
||||
}
|
||||
{
|
||||
err = vkResetCommandPool(device, fd->CommandPool, 0);
|
||||
check_vk_result(err);
|
||||
VkCommandBufferBeginInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
err = vkBeginCommandBuffer(fd->CommandBuffer, &info);
|
||||
check_vk_result(err);
|
||||
}
|
||||
{
|
||||
VkRenderPassBeginInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_RENDER_PASS_BEGIN_INFO;
|
||||
info.renderPass = wd->RenderPass;
|
||||
info.framebuffer = fd->Framebuffer;
|
||||
info.renderArea.extent.width = wd->Width;
|
||||
info.renderArea.extent.height = wd->Height;
|
||||
info.clearValueCount = 1;
|
||||
info.pClearValues = &wd->ClearValue;
|
||||
vkCmdBeginRenderPass(fd->CommandBuffer, &info, VK_SUBPASS_CONTENTS_INLINE);
|
||||
}
|
||||
|
||||
// Record Imgui Draw Data and draw funcs into command buffer
|
||||
ImGui_ImplVulkan_RenderDrawData(ImGui::GetDrawData(), fd->CommandBuffer);
|
||||
|
||||
// Submit command buffer
|
||||
vkCmdEndRenderPass(fd->CommandBuffer);
|
||||
{
|
||||
VkPipelineStageFlags wait_stage =
|
||||
VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT;
|
||||
VkSubmitInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
info.waitSemaphoreCount = 1;
|
||||
info.pWaitSemaphores = &image_acquired_semaphore;
|
||||
info.pWaitDstStageMask = &wait_stage;
|
||||
info.commandBufferCount = 1;
|
||||
info.pCommandBuffers = &fd->CommandBuffer;
|
||||
info.signalSemaphoreCount = 1;
|
||||
info.pSignalSemaphores = &render_complete_semaphore;
|
||||
|
||||
err = vkEndCommandBuffer(fd->CommandBuffer);
|
||||
check_vk_result(err);
|
||||
err = vkQueueSubmit(queue, 1, &info, fd->Fence);
|
||||
check_vk_result(err);
|
||||
}
|
||||
}
|
||||
|
||||
void PresentFrame(ImGui_ImplVulkanH_Window* wd, VkQueue queue) {
|
||||
VkSemaphore render_complete_semaphore =
|
||||
wd->FrameSemaphores[wd->SemaphoreIndex].RenderCompleteSemaphore;
|
||||
VkPresentInfoKHR info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_PRESENT_INFO_KHR;
|
||||
info.waitSemaphoreCount = 1;
|
||||
info.pWaitSemaphores = &render_complete_semaphore;
|
||||
info.swapchainCount = 1;
|
||||
info.pSwapchains = &wd->Swapchain;
|
||||
info.pImageIndices = &wd->FrameIndex;
|
||||
VkResult err = vkQueuePresentKHR(queue, &info);
|
||||
check_vk_result(err);
|
||||
wd->SemaphoreIndex =
|
||||
(wd->SemaphoreIndex + 1) %
|
||||
wd->ImageCount; // Now we can use the next set of semaphores
|
||||
}
|
||||
|
||||
static void CleanupVulkan() {
|
||||
vkDestroyDescriptorPool(g_Device, g_DescriptorPool, g_Allocator);
|
||||
|
||||
vkDestroyDevice(g_Device, g_Allocator);
|
||||
vkDestroyInstance(g_Instance, g_Allocator);
|
||||
}
|
||||
|
||||
static void CleanupVulkanWindow() {
|
||||
ImGui_ImplVulkanH_DestroyWindow(g_Instance, g_Device, &g_MainWindowData,
|
||||
g_Allocator);
|
||||
}
|
||||
|
||||
namespace iree {
|
||||
|
||||
extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
|
||||
if (argc > 1) {
|
||||
// Avoid iree-run-module spinning endlessly on stdin if the user uses single
|
||||
// dashes for flags.
|
||||
printf(
|
||||
"[ERROR] unexpected positional argument (expected none)."
|
||||
" Did you use pass a flag with a single dash ('-')?"
|
||||
" Use '--' instead.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Create a window.
|
||||
if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
|
||||
fprintf(stderr, "Failed to initialize SDL\n");
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Setup window
|
||||
// clang-format off
|
||||
SDL_WindowFlags window_flags = (SDL_WindowFlags)(
|
||||
SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
|
||||
// clang-format on
|
||||
SDL_Window* window = SDL_CreateWindow(
|
||||
"IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED,
|
||||
SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
|
||||
if (window == nullptr)
|
||||
{
|
||||
const char* sdl_err = SDL_GetError();
|
||||
fprintf(stderr, "Error, SDL_CreateWindow returned: %s\n", sdl_err);
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Setup Vulkan
|
||||
iree_hal_vulkan_features_t iree_vulkan_features =
|
||||
static_cast<iree_hal_vulkan_features_t>(
|
||||
IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS |
|
||||
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
|
||||
std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features);
|
||||
std::vector<const char*> extensions =
|
||||
GetInstanceExtensions(window, iree_vulkan_features);
|
||||
SetupVulkan(iree_vulkan_features, layers.data(),
|
||||
static_cast<uint32_t>(layers.size()), extensions.data(),
|
||||
static_cast<uint32_t>(extensions.size()), g_Allocator,
|
||||
&g_Instance, &g_QueueFamily, &g_PhysicalDevice, &g_Queue,
|
||||
&g_Device, &g_DescriptorPool);
|
||||
|
||||
// Create Window Surface
|
||||
VkSurfaceKHR surface;
|
||||
VkResult err;
|
||||
if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) {
|
||||
fprintf(stderr, "Failed to create Vulkan surface.\n");
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Create Framebuffers
|
||||
int w, h;
|
||||
SDL_GetWindowSize(window, &w, &h);
|
||||
ImGui_ImplVulkanH_Window* wd = &g_MainWindowData;
|
||||
SetupVulkanWindow(wd, g_Allocator, g_Instance, g_QueueFamily,
|
||||
g_PhysicalDevice, g_Device, surface, w, h, g_MinImageCount);
|
||||
|
||||
// Setup Dear ImGui context
|
||||
IMGUI_CHECKVERSION();
|
||||
ImGui::CreateContext();
|
||||
ImGuiIO& io = ImGui::GetIO();
|
||||
(void)io;
|
||||
|
||||
ImGui::StyleColorsDark();
|
||||
|
||||
// Setup Platform/Renderer bindings
|
||||
ImGui_ImplSDL2_InitForVulkan(window);
|
||||
ImGui_ImplVulkan_InitInfo init_info = {};
|
||||
init_info.Instance = g_Instance;
|
||||
init_info.PhysicalDevice = g_PhysicalDevice;
|
||||
init_info.Device = g_Device;
|
||||
init_info.QueueFamily = g_QueueFamily;
|
||||
init_info.Queue = g_Queue;
|
||||
init_info.PipelineCache = g_PipelineCache;
|
||||
init_info.DescriptorPool = g_DescriptorPool;
|
||||
init_info.Allocator = g_Allocator;
|
||||
init_info.MinImageCount = g_MinImageCount;
|
||||
init_info.ImageCount = wd->ImageCount;
|
||||
init_info.CheckVkResultFn = check_vk_result;
|
||||
ImGui_ImplVulkan_Init(&init_info, wd->RenderPass);
|
||||
|
||||
// Upload Fonts
|
||||
{
|
||||
// Use any command queue
|
||||
VkCommandPool command_pool = wd->Frames[wd->FrameIndex].CommandPool;
|
||||
VkCommandBuffer command_buffer = wd->Frames[wd->FrameIndex].CommandBuffer;
|
||||
|
||||
err = vkResetCommandPool(g_Device, command_pool, 0);
|
||||
check_vk_result(err);
|
||||
VkCommandBufferBeginInfo begin_info = {};
|
||||
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
err = vkBeginCommandBuffer(command_buffer, &begin_info);
|
||||
check_vk_result(err);
|
||||
|
||||
ImGui_ImplVulkan_CreateFontsTexture(command_buffer);
|
||||
|
||||
VkSubmitInfo end_info = {};
|
||||
end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
end_info.commandBufferCount = 1;
|
||||
end_info.pCommandBuffers = &command_buffer;
|
||||
err = vkEndCommandBuffer(command_buffer);
|
||||
check_vk_result(err);
|
||||
err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
|
||||
check_vk_result(err);
|
||||
|
||||
err = vkDeviceWaitIdle(g_Device);
|
||||
check_vk_result(err);
|
||||
ImGui_ImplVulkan_DestroyFontUploadObjects();
|
||||
}
|
||||
|
||||
// Demo state.
|
||||
bool show_iree_window = true;
|
||||
// --------------------------------------------------------------------------
|
||||
// Setup IREE.
|
||||
|
||||
// Check API version.
|
||||
iree_api_version_t actual_version;
|
||||
iree_status_t status =
|
||||
iree_api_version_check(IREE_API_VERSION_LATEST, &actual_version);
|
||||
if (iree_status_is_ok(status)) {
|
||||
fprintf(stdout, "IREE runtime API version: %d\n", actual_version);
|
||||
} else {
|
||||
fprintf(stderr, "Unsupported runtime API version: %d\n", actual_version);
|
||||
abort();
|
||||
}
|
||||
|
||||
// Create a runtime Instance.
|
||||
iree_vm_instance_t* iree_instance = nullptr;
|
||||
IREE_CHECK_OK(
|
||||
iree_vm_instance_create(iree_allocator_system(), &iree_instance));
|
||||
|
||||
// Register HAL drivers and VM module types.
|
||||
IREE_CHECK_OK(iree_hal_vulkan_driver_module_register(
|
||||
iree_hal_driver_registry_default()));
|
||||
IREE_CHECK_OK(iree_hal_module_register_all_types(iree_instance));
|
||||
|
||||
// Create IREE Vulkan Driver and Device, sharing our VkInstance/VkDevice.
|
||||
fprintf(stdout, "Creating Vulkan driver/device\n");
|
||||
// Load symbols from our static `vkGetInstanceProcAddr` for IREE to use.
|
||||
iree_hal_vulkan_syms_t* iree_vk_syms = nullptr;
|
||||
IREE_CHECK_OK(iree_hal_vulkan_syms_create(
|
||||
reinterpret_cast<void*>(&vkGetInstanceProcAddr), iree_allocator_system(),
|
||||
&iree_vk_syms));
|
||||
// Create the driver sharing our VkInstance.
|
||||
iree_hal_driver_t* iree_vk_driver = nullptr;
|
||||
iree_string_view_t driver_identifier = iree_make_cstring_view("vulkan");
|
||||
iree_hal_vulkan_driver_options_t driver_options;
|
||||
driver_options.api_version = VK_API_VERSION_1_0;
|
||||
driver_options.requested_features = static_cast<iree_hal_vulkan_features_t>(
|
||||
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
|
||||
IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance(
|
||||
driver_identifier, &driver_options, iree_vk_syms, g_Instance,
|
||||
iree_allocator_system(), &iree_vk_driver));
|
||||
// Create a device sharing our VkDevice and queue.
|
||||
// We could also create a separate (possibly low priority) compute queue for
|
||||
// IREE, and/or provide a dedicated transfer queue.
|
||||
iree_string_view_t device_identifier = iree_make_cstring_view("vulkan");
|
||||
iree_hal_vulkan_queue_set_t compute_queue_set;
|
||||
compute_queue_set.queue_family_index = g_QueueFamily;
|
||||
compute_queue_set.queue_indices = 1 << 0;
|
||||
iree_hal_vulkan_queue_set_t transfer_queue_set;
|
||||
transfer_queue_set.queue_indices = 0;
|
||||
iree_hal_device_t* iree_vk_device = nullptr;
|
||||
IREE_CHECK_OK(iree_hal_vulkan_wrap_device(
|
||||
device_identifier, &driver_options.device_options, iree_vk_syms,
|
||||
g_Instance, g_PhysicalDevice, g_Device, &compute_queue_set,
|
||||
&transfer_queue_set, iree_allocator_system(), &iree_vk_device));
|
||||
// Create a HAL module using the HAL device.
|
||||
iree_vm_module_t* hal_module = nullptr;
|
||||
IREE_CHECK_OK(iree_hal_module_create(iree_instance, iree_vk_device,
|
||||
IREE_HAL_MODULE_FLAG_NONE,
|
||||
iree_allocator_system(), &hal_module));
|
||||
|
||||
|
||||
// Load bytecode module
|
||||
//iree_file_toc_t module_file_toc;
|
||||
//const char network_model[] = "resnet50_tf.vmfb";
|
||||
//fprintf(stdout, "Loading: %s\n", network_model);
|
||||
//if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
|
||||
//{
|
||||
// abort();
|
||||
// return 1;
|
||||
//}
|
||||
//fprintf(stdout, "module size: %zu\n", module_file_toc.size);
|
||||
|
||||
iree_vm_module_t* bytecode_module = nullptr;
|
||||
iree_status_t module_status = iree_tooling_load_module_from_flags(
|
||||
iree_instance, iree_allocator_system(), &bytecode_module);
|
||||
if (!iree_status_is_ok(module_status))
|
||||
return -1;
|
||||
//IREE_CHECK_OK(iree_vm_bytecode_module_create(
|
||||
// iree_instance,
|
||||
// iree_const_byte_span_t{
|
||||
// reinterpret_cast<const uint8_t*>(module_file_toc.data),
|
||||
// module_file_toc.size},
|
||||
// iree_allocator_null(), iree_allocator_system(), &bytecode_module));
|
||||
//// Query for details about what is in the loaded module.
|
||||
//iree_vm_module_signature_t bytecode_module_signature =
|
||||
// iree_vm_module_signature(bytecode_module);
|
||||
//fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
|
||||
// bytecode_module_signature.export_function_count);
|
||||
//for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
|
||||
// iree_vm_function_t function;
|
||||
// IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
|
||||
// bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
|
||||
// auto function_name = iree_vm_function_name(&function);
|
||||
// auto function_signature = iree_vm_function_signature(&function);
|
||||
|
||||
// fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
|
||||
// (int)function_name.size, function_name.data,
|
||||
// (int)function_signature.calling_convention.size,
|
||||
// function_signature.calling_convention.data);
|
||||
//}
|
||||
|
||||
// Allocate a context that will hold the module state across invocations.
|
||||
iree_vm_context_t* iree_context = nullptr;
|
||||
std::vector<iree_vm_module_t*> modules = {hal_module, bytecode_module};
|
||||
IREE_CHECK_OK(iree_vm_context_create_with_modules(
|
||||
iree_instance, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(),
|
||||
iree_allocator_system(), &iree_context));
|
||||
fprintf(stdout, "Context with modules is ready for use\n");
|
||||
|
||||
// Lookup the entry point function.
|
||||
iree_vm_function_t main_function;
|
||||
const char kMainFunctionName[] = "module.forward";
|
||||
IREE_CHECK_OK(iree_vm_context_resolve_function(
|
||||
iree_context,
|
||||
iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1},
|
||||
&main_function));
|
||||
iree_string_view_t main_function_name = iree_vm_function_name(&main_function);
|
||||
fprintf(stdout, "Resolved main function named '%.*s'\n",
|
||||
(int)main_function_name.size, main_function_name.data);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Write inputs into mappable buffers.
|
||||
iree_hal_allocator_t* allocator =
|
||||
iree_hal_device_allocator(iree_vk_device);
|
||||
//iree_hal_memory_type_t input_memory_type =
|
||||
// static_cast<iree_hal_memory_type_t>(
|
||||
// IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
|
||||
// IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
|
||||
//iree_hal_buffer_usage_t input_buffer_usage =
|
||||
// static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
|
||||
//iree_hal_buffer_params_t buffer_params;
|
||||
//buffer_params.type = input_memory_type;
|
||||
//buffer_params.usage = input_buffer_usage;
|
||||
//buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
|
||||
|
||||
// Wrap input buffers in buffer views.
|
||||
|
||||
vm::ref<iree_vm_list_t> inputs;
|
||||
iree_status_t input_status = ParseToVariantList(
|
||||
allocator,
|
||||
iree::span<const std::string>{FLAG_function_inputs.data(),
|
||||
FLAG_function_inputs.size()},
|
||||
iree_allocator_system(), &inputs);
|
||||
if (!iree_status_is_ok(input_status))
|
||||
return -1;
|
||||
//vm::ref<iree_vm_list_t> inputs;
|
||||
//IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
|
||||
|
||||
//iree_hal_buffer_view_t* input0_buffer_view = nullptr;
|
||||
//constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
|
||||
//IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
|
||||
// allocator,
|
||||
// /*shape_rank=*/4, /*shape=*/input_buffer_shape,
|
||||
// IREE_HAL_ELEMENT_TYPE_FLOAT_32,
|
||||
// IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
|
||||
// iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
|
||||
// &input0_buffer_view));
|
||||
|
||||
//auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
|
||||
//IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
|
||||
|
||||
// Prepare outputs list to accept results from the invocation.
|
||||
|
||||
vm::ref<iree_vm_list_t> outputs;
|
||||
constexpr iree_hal_dim_t kOutputCount = 1000;
|
||||
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, kOutputCount * sizeof(float), iree_allocator_system(), &outputs));
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Main loop.
|
||||
bool done = false;
|
||||
while (!done) {
|
||||
SDL_Event event;
|
||||
|
||||
while (SDL_PollEvent(&event)) {
|
||||
if (event.type == SDL_QUIT) {
|
||||
done = true;
|
||||
}
|
||||
|
||||
ImGui_ImplSDL2_ProcessEvent(&event);
|
||||
if (event.type == SDL_QUIT) done = true;
|
||||
if (event.type == SDL_WINDOWEVENT &&
|
||||
event.window.event == SDL_WINDOWEVENT_RESIZED &&
|
||||
event.window.windowID == SDL_GetWindowID(window)) {
|
||||
g_SwapChainResizeWidth = (int)event.window.data1;
|
||||
g_SwapChainResizeHeight = (int)event.window.data2;
|
||||
g_SwapChainRebuild = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (g_SwapChainRebuild) {
|
||||
g_SwapChainRebuild = false;
|
||||
ImGui_ImplVulkan_SetMinImageCount(g_MinImageCount);
|
||||
ImGui_ImplVulkanH_CreateOrResizeWindow(
|
||||
g_Instance, g_PhysicalDevice, g_Device, &g_MainWindowData,
|
||||
g_QueueFamily, g_Allocator, g_SwapChainResizeWidth,
|
||||
g_SwapChainResizeHeight, g_MinImageCount);
|
||||
g_MainWindowData.FrameIndex = 0;
|
||||
}
|
||||
|
||||
// Start the Dear ImGui frame
|
||||
ImGui_ImplVulkan_NewFrame();
|
||||
ImGui_ImplSDL2_NewFrame(window);
|
||||
ImGui::NewFrame();
|
||||
|
||||
// Custom window.
|
||||
{
|
||||
ImGui::Begin("IREE Vulkan Integration Demo", &show_iree_window);
|
||||
|
||||
ImGui::Separator();
|
||||
|
||||
// ImGui Inputs for two input tensors.
|
||||
// Run computation whenever any of the values changes.
|
||||
static bool dirty = true;
|
||||
if (dirty) {
|
||||
|
||||
// Synchronously invoke the function.
|
||||
IREE_CHECK_OK(iree_vm_invoke(iree_context, main_function,
|
||||
IREE_VM_INVOCATION_FLAG_NONE,
|
||||
/*policy=*/nullptr, inputs.get(),
|
||||
outputs.get(), iree_allocator_system()));
|
||||
|
||||
|
||||
// We want to run continuously so we can use tools like RenderDoc, RGP, etc.
|
||||
dirty = true;
|
||||
}
|
||||
|
||||
// Framerate counter.
|
||||
ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
|
||||
1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
|
||||
|
||||
ImGui::End();
|
||||
}
|
||||
|
||||
// Rendering
|
||||
ImGui::Render();
|
||||
RenderFrame(wd, g_Device, g_Queue);
|
||||
|
||||
PresentFrame(wd, g_Queue);
|
||||
}
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Cleanup
|
||||
iree_vm_module_release(hal_module);
|
||||
iree_vm_module_release(bytecode_module);
|
||||
iree_vm_context_release(iree_context);
|
||||
iree_hal_device_release(iree_vk_device);
|
||||
iree_hal_allocator_release(allocator);
|
||||
iree_hal_driver_release(iree_vk_driver);
|
||||
iree_hal_vulkan_syms_release(iree_vk_syms);
|
||||
iree_vm_instance_release(iree_instance);
|
||||
|
||||
err = vkDeviceWaitIdle(g_Device);
|
||||
check_vk_result(err);
|
||||
ImGui_ImplVulkan_Shutdown();
|
||||
ImGui_ImplSDL2_Shutdown();
|
||||
ImGui::DestroyContext();
|
||||
|
||||
CleanupVulkanWindow();
|
||||
CleanupVulkan();
|
||||
|
||||
SDL_DestroyWindow(window);
|
||||
SDL_Quit();
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace iree
File diff suppressed because it is too large
@@ -1,251 +0,0 @@
|
||||
# Lint as: python3
|
||||
"""SHARK Tank"""
|
||||
# Usage: python generate_sharktank.py; you have to provide a CSV file with [model_name, model_download_url]
|
||||
# will generate local shark tank folder like this:
|
||||
# HOME
|
||||
# /.local
|
||||
# /shark_tank
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
#
|
||||
|
||||
import os
|
||||
import csv
|
||||
import argparse
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
import tensorflow as tf
|
||||
import subprocess as sp
|
||||
import hashlib
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
visible_default = tf.config.list_physical_devices("GPU")
|
||||
try:
|
||||
tf.config.set_visible_devices([], "GPU")
|
||||
visible_devices = tf.config.get_visible_devices()
|
||||
for device in visible_devices:
|
||||
assert device.device_type != "GPU"
|
||||
except Exception:
|
||||
# Invalid device or cannot modify virtual devices once initialized.
|
||||
pass
|
||||
|
||||
|
||||
def create_hash(file_name):
|
||||
with open(file_name, "rb") as f:
|
||||
file_hash = hashlib.blake2b()
|
||||
while chunk := f.read(2**20):
|
||||
file_hash.update(chunk)
|
||||
|
||||
return file_hash.hexdigest()
|
||||
|
||||
|
||||
def save_torch_model(torch_model_list):
|
||||
from tank.model_utils import get_hf_model
|
||||
from tank.model_utils import get_vision_model
|
||||
from tank.model_utils import get_hf_img_cls_model
|
||||
|
||||
with open(torch_model_list) as csvfile:
|
||||
torch_reader = csv.reader(csvfile, delimiter=",")
|
||||
fields = next(torch_reader)
|
||||
for row in torch_reader:
|
||||
torch_model_name = row[0]
|
||||
tracing_required = row[1]
|
||||
model_type = row[2]
|
||||
is_dynamic = row[3]
|
||||
|
||||
tracing_required = False if tracing_required == "False" else True
|
||||
is_dynamic = False if is_dynamic == "False" else True
|
||||
|
||||
model = None
|
||||
input = None
|
||||
if model_type == "vision":
|
||||
model, input, _ = get_vision_model(torch_model_name)
|
||||
elif model_type == "hf":
|
||||
model, input, _ = get_hf_model(torch_model_name)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(torch_model_name)
|
||||
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
WORKDIR, str(torch_model_name) + "_torch"
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
torch_model_dir, torch_model_name + "_torch" + ".mlir"
|
||||
)
|
||||
)
|
||||
np.save(os.path.join(torch_model_dir, "hash"), np.array(mlir_hash))
|
||||
# Generate torch dynamic models.
|
||||
if is_dynamic:
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=True,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name + "_dynamic",
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list):
|
||||
from tank.model_utils_tf import (
|
||||
get_causal_image_model,
|
||||
get_causal_lm_model,
|
||||
get_keras_model,
|
||||
get_TFhf_model,
|
||||
)
|
||||
|
||||
with open(tf_model_list) as csvfile:
|
||||
tf_reader = csv.reader(csvfile, delimiter=",")
|
||||
fields = next(tf_reader)
|
||||
for row in tf_reader:
|
||||
tf_model_name = row[0]
|
||||
model_type = row[1]
|
||||
|
||||
model = None
|
||||
input = None
|
||||
print(f"Generating artifacts for model {tf_model_name}")
|
||||
if model_type == "hf":
|
||||
model, input, _ = get_causal_lm_model(tf_model_name)
|
||||
if model_type == "img":
|
||||
model, input, _ = get_causal_image_model(tf_model_name)
|
||||
if model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name)
|
||||
if model_type == "TFhf":
|
||||
model, input, _ = get_TFhf_model(tf_model_name)
|
||||
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
input,
|
||||
frontend="tf",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
|
||||
)
|
||||
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list):
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
|
||||
with open(tflite_model_list) as csvfile:
|
||||
tflite_reader = csv.reader(csvfile, delimiter=",")
|
||||
for row in tflite_reader:
|
||||
print("\n")
|
||||
tflite_model_name = row[0]
|
||||
tflite_model_link = row[1]
|
||||
print("tflite_model_name", tflite_model_name)
|
||||
print("tflite_model_link", tflite_model_link)
|
||||
tflite_model_name_dir = os.path.join(
|
||||
WORKDIR, str(tflite_model_name) + "_tflite"
|
||||
)
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
|
||||
# Preprocess to get SharkImporter input args
|
||||
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
|
||||
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
|
||||
inputs = tflite_preprocessor.get_inputs()
|
||||
tflite_interpreter = tflite_preprocessor.get_interpreter()
|
||||
|
||||
# Use SharkImporter to get SharkInference input args
|
||||
my_shark_importer = SharkImporter(
|
||||
module=tflite_interpreter,
|
||||
inputs=inputs,
|
||||
frontend="tflite",
|
||||
raw_model_file=raw_model_file_path,
|
||||
)
|
||||
my_shark_importer.import_debug(
|
||||
dir=tflite_model_name_dir,
|
||||
model_name=tflite_model_name,
|
||||
func_name="main",
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
tflite_model_name_dir,
|
||||
tflite_model_name + "_tflite" + ".mlir",
|
||||
)
|
||||
)
|
||||
np.save(
|
||||
os.path.join(tflite_model_name_dir, "hash"),
|
||||
np.array(mlir_hash),
|
||||
)
|
||||
|
||||
|
||||
# Validates whether the file is present or not.
|
||||
def is_valid_file(arg):
|
||||
if not os.path.exists(arg):
|
||||
return None
|
||||
else:
|
||||
return arg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--torch_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/torch_model_list.csv",
|
||||
help="""Contains the file with torch_model name and args.
|
||||
Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tf_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/tf_model_list.csv",
|
||||
help="Contains the file with tf model name and args.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tflite_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/tflite/tflite_model_list.csv",
|
||||
help="Contains the file with tf model name and args.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ci_tank_dir",
|
||||
type=bool,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument("--upload", type=bool, default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
home = str(Path.home())
|
||||
if args.ci_tank_dir == True:
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
else:
|
||||
WORKDIR = os.path.join(home, ".local/shark_tank/")
|
||||
|
||||
if args.torch_model_csv:
|
||||
save_torch_model(args.torch_model_csv)
|
||||
|
||||
if args.tf_model_csv:
|
||||
save_tf_model(args.tf_model_csv)
|
||||
|
||||
if args.tflite_model_csv:
|
||||
save_tflite_model(args.tflite_model_csv)
|
||||
|
||||
if args.upload:
|
||||
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
|
||||
print("uploading files to gs://shark_tank/" + git_hash)
|
||||
os.system(f"gsutil cp -r {WORKDIR}* gs://shark_tank/" + git_hash)
File diff suppressed because it is too large
@@ -4,9 +4,9 @@ requires = [
"wheel",
"packaging",

"numpy>=1.22.4",
"torch-mlir>=20221021.633",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
"numpy==1.22.4",
"torch-mlir>=20220428.420",
"iree-compiler>=20220427.13",
"iree-runtime>=20220427.13",
]
build-backend = "setuptools.build_meta"

@@ -1,8 +1,8 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
|
||||
numpy
|
||||
torch==1.14.0.dev20221021
|
||||
torch
|
||||
torchvision
|
||||
|
||||
tqdm
|
||||
@@ -19,16 +19,12 @@ tensorflow-macos
|
||||
tensorflow-metal
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
tensorflow-probability
|
||||
transformers==4.18.0
|
||||
#jax[cpu]
|
||||
|
||||
# tflitehub dependencies.
|
||||
Pillow
|
||||
|
||||
# web dependencies.
|
||||
gradio
|
||||
|
||||
# Testing and support.
|
||||
#lit
|
||||
#pyyaml
|
||||
|
||||
@@ -14,13 +14,10 @@ iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow==2.10
|
||||
keras==2.10
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
diffusers
|
||||
#tensorflow-probability
|
||||
tensorflow
|
||||
tf-models-nightly
|
||||
tensorflow-text-nightly
|
||||
transformers==4.18.0
|
||||
#jax[cpu]
|
||||
|
||||
|
||||
@@ -30,19 +27,13 @@ Pillow
|
||||
# Testing and support.
|
||||
lit
|
||||
pyyaml
|
||||
python-dateutil
|
||||
sacremoses
|
||||
|
||||
# web dependencies.
|
||||
gradio
|
||||
scipy
|
||||
|
||||
#ONNX and ORT for benchmarking
|
||||
#--extra-index-url https://test.pypi.org/simple/
|
||||
#protobuf
|
||||
#coloredlogs
|
||||
#flatbuffers
|
||||
#sympy
|
||||
#psutil
|
||||
#onnx-weekly
|
||||
#ort-nightly
|
||||
--extra-index-url https://test.pypi.org/simple/
|
||||
protobuf
|
||||
coloredlogs
|
||||
flatbuffers
|
||||
sympy
|
||||
psutil
|
||||
onnx-weekly
|
||||
ort-nightly
|
||||
|
||||
@@ -1,14 +1,9 @@
setuptools
wheel

# SHARK Runner
#SHARK Runner
tqdm

# SHARK Downloader
gsutil

# Testing
#Testing
pytest
pytest-xdist
Pillow
parameterized

17 setup.py
@@ -7,12 +7,6 @@ with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.4"
|
||||
backend_deps = []
|
||||
if "NO_BACKEND" in os.environ.keys():
|
||||
backend_deps = [
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
|
||||
setup(
|
||||
name="nodai-SHARK",
|
||||
@@ -32,12 +26,13 @@ setup(
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
packages=find_packages(exclude=("examples")),
|
||||
python_requires=">=3.9",
|
||||
packages=find_packages(exclude=('examples')),
|
||||
python_requires=">=3.7",
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"PyYAML",
|
||||
"torch-mlir>=20221021.633",
|
||||
]
|
||||
+ backend_deps,
|
||||
"torch-mlir>=20220428.420",
|
||||
"iree-compiler>=20220427.13",
|
||||
"iree-runtime>=20220427.13",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
#Write-Host "Installing python"
|
||||
|
||||
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow
|
||||
|
||||
#Write-Host "python installation completed successfully"
|
||||
|
||||
#Write-Host "Reload environment variables"
|
||||
#$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||
#Write-Host "Reloaded environment variables"
|
||||
|
||||
|
||||
# redirect stderr into stdout
|
||||
$p = &{python -V} 2>&1
|
||||
# check if an ErrorRecord was returned
|
||||
$version = if($p -is [System.Management.Automation.ErrorRecord])
|
||||
{
|
||||
# grab the version string from the error message
|
||||
$p.Exception.Message
|
||||
}
|
||||
else
|
||||
{
|
||||
# otherwise return as is
|
||||
$p
|
||||
}
|
||||
|
||||
Write-Host "Python version found is"
|
||||
Write-Host $p
|
||||
|
||||
|
||||
Write-Host "Installing Build Dependencies"
|
||||
python -m venv .\shark.venv\
|
||||
.\shark.venv\Scripts\activate
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu116 -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
pip install diffusers transformers scipy pillow gradio
|
||||
Write-Host "Build and installation completed successfully"
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
@@ -7,8 +7,6 @@
|
||||
# VENV_DIR=myshark.venv #create a venv called myshark.venv
|
||||
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
|
||||
# IMPORTER=1 #Install importer deps
|
||||
# BENCHMARK=1 #Install benchmark deps
|
||||
# NO_BACKEND=1 #Don't install iree or shark backend
|
||||
# if you run the script from a conda env it will install in your conda env
|
||||
|
||||
TD="$(cd $(dirname $0) && pwd)"
|
||||
@@ -76,15 +74,11 @@ fi
|
||||
$PYTHON -m pip install --upgrade pip || die "Could not upgrade pip"
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
|
||||
if [ "$torch_mlir_bin" = true ]; then
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "MacOS detected. Please install torch-mlir from source or .whl, as dependency problems may occur otherwise."
|
||||
$PYTHON -m pip install --find-links https://github.com/llvm/torch-mlir/releases torch-mlir --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch-mlir"
|
||||
else
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch-mlir"
|
||||
else
|
||||
echo "Could not install torch-mlir" >&2
|
||||
fi
|
||||
echo "Could not install torch-mlir" >&2
|
||||
fi
|
||||
else
|
||||
echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
|
||||
@@ -93,51 +87,26 @@ else
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "${USE_IREE}" ]]; then
|
||||
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
|
||||
RUNTIME="nod-ai/SHARK-Runtime"
|
||||
else
|
||||
RUNTIME="https://iree-org.github.io/iree/pip-release-links.html"
|
||||
fi
|
||||
if [[ -z "${NO_BACKEND}" ]]; then
|
||||
echo "Installing ${RUNTIME}..."
|
||||
$PYTHON -m pip install --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
|
||||
else
|
||||
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
|
||||
RUNTIME="google/iree"
|
||||
fi
|
||||
echo "Installing ${RUNTIME}..."
|
||||
$PYTHON -m pip install --find-links https://github.com/${RUNTIME}/releases iree-compiler iree-runtime
|
||||
|
||||
if [[ ! -z "${IMPORTER}" ]]; then
|
||||
echo "${Yellow}Installing importer tools.."
|
||||
if [[ $(uname -s) = 'Linux' ]]; then
|
||||
echo "${Yellow}Linux detected.. installing Linux importer tools"
|
||||
#Always get the importer tools from upstream IREE
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer.txt" -f https://iree-org.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer.txt" -f https://github.com/${RUNTIME}/releases --extra-index-url https://test.pypi.org/simple/ --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
elif [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "${Yellow}macOS detected.. installing macOS importer tools"
|
||||
#Conda seems to have some problems installing these packages and hope they get resolved upstream.
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer-macos.txt" -f ${RUNTIME} --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
$PYTHON -m pip install https://github.com/llvm/torch-mlir/releases/download/snapshot-20221024.636/torch_mlir-20221024.636-cp310-cp310-macosx_11_0_universal2.whl
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer-macos.txt" -f https://github.com/${RUNTIME}/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
fi
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME}
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu116
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu116."
|
||||
else
|
||||
echo "Could not install torch + cu116." >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ ! -z "${ONNX}" ]]; then
|
||||
echo "${Yellow}Installing ONNX and onnxruntime for benchmarks..."
|
||||
$PYTHON -m pip install onnx onnxruntime psutil
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully installed ONNX and ONNX runtime."
|
||||
else
|
||||
echo "Could not install ONNX." >&2
|
||||
fi
|
||||
fi
|
||||
$PYTHON -m pip install -e . --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases -f https://github.com/${RUNTIME}/releases
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
echo "${Green}Before running examples activate venv with:"
|
||||
|
||||
@@ -18,10 +18,12 @@ from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.nn.utils import _stateless
|
||||
|
||||
from torch import fx
|
||||
import copy
|
||||
import tempfile
|
||||
|
||||
|
||||
class MakeFxModule:
|
||||
|
||||
def __init__(self, model, inputs, labels=None, custom_inference_fn=None):
|
||||
self.model = model
|
||||
self.inputs = inputs
|
||||
@@ -51,28 +53,20 @@ class MakeFxModule:
|
||||
return fx_g
|
||||
|
||||
def generate_graph(self):
|
||||
fx_g = make_fx(
|
||||
self.custom_inference_fn,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
]
|
||||
),
|
||||
)(
|
||||
dict(self.model.named_parameters()),
|
||||
dict(self.model.named_buffers()),
|
||||
self.inputs,
|
||||
)
|
||||
fx_g = make_fx(self.custom_inference_fn,
|
||||
decomposition_table=get_decompositions([
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward
|
||||
]))(dict(self.model.named_parameters()),
|
||||
dict(self.model.named_buffers()), self.inputs)
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
fx_g = self.change_fx_graph_return_to_tuple(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
temp = tempfile.NamedTemporaryFile(
|
||||
suffix="_shark_ts", prefix="temp_ts_"
|
||||
)
|
||||
temp = tempfile.NamedTemporaryFile(suffix='_shark_ts',
|
||||
prefix='temp_ts_')
|
||||
ts_g.save(temp.name)
|
||||
new_ts = torch.jit.load(temp.name)
|
||||
self.training_graph = new_ts
|
||||
|
||||
78 shark/cuda_utils.py (new file)
@@ -0,0 +1,78 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import ctypes
|
||||
|
||||
#Some constants taken from cuda.h
|
||||
CUDA_SUCCESS = 0
|
||||
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16
|
||||
CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39
|
||||
CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13
|
||||
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
|
||||
|
||||
|
||||
def get_cuda_sm_cc():
|
||||
libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
|
||||
for libname in libnames:
|
||||
try:
|
||||
cuda = ctypes.CDLL(libname)
|
||||
except OSError:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise OSError("could not load any of: " + ' '.join(libnames))
|
||||
|
||||
nGpus = ctypes.c_int()
|
||||
name = b' ' * 100
|
||||
cc_major = ctypes.c_int()
|
||||
cc_minor = ctypes.c_int()
|
||||
|
||||
result = ctypes.c_int()
|
||||
device = ctypes.c_int()
|
||||
context = ctypes.c_void_p()
|
||||
error_str = ctypes.c_char_p()
|
||||
|
||||
result = cuda.cuInit(0)
|
||||
if result != CUDA_SUCCESS:
|
||||
cuda.cuGetErrorString(result, ctypes.byref(error_str))
|
||||
print("cuInit failed with error code %d: %s" %
|
||||
(result, error_str.value.decode()))
|
||||
return 1
|
||||
result = cuda.cuDeviceGetCount(ctypes.byref(nGpus))
|
||||
if result != CUDA_SUCCESS:
|
||||
cuda.cuGetErrorString(result, ctypes.byref(error_str))
|
||||
print("cuDeviceGetCount failed with error code %d: %s" %
|
||||
(result, error_str.value.decode()))
|
||||
return 1
|
||||
print("Found %d device(s)." % nGpus.value)
|
||||
for i in range(nGpus.value):
|
||||
result = cuda.cuDeviceGet(ctypes.byref(device), i)
|
||||
if result != CUDA_SUCCESS:
|
||||
cuda.cuGetErrorString(result, ctypes.byref(error_str))
|
||||
print("cuDeviceGet failed with error code %d: %s" %
|
||||
(result, error_str.value.decode()))
|
||||
return 1
|
||||
print("Device: %d" % i)
|
||||
if cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name),
|
||||
device) == CUDA_SUCCESS:
|
||||
print(" Name: %s" % (name.split(b'\0', 1)[0].decode()))
|
||||
if cuda.cuDeviceComputeCapability(ctypes.byref(cc_major),
|
||||
ctypes.byref(cc_minor),
|
||||
device) == CUDA_SUCCESS:
|
||||
print(" Compute Capability: %d.%d" %
|
||||
(cc_major.value, cc_minor.value))
|
||||
sm = f"sm_{cc_major.value}{cc_minor.value}"
|
||||
return sm
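# A minimal usage sketch (not part of the original file): print the compute
# capability string of the first CUDA device, e.g. to choose a codegen target.
if __name__ == "__main__":
    print(get_cuda_sm_cc())  # e.g. "sm_80" on an A100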
|
||||
@@ -1,70 +0,0 @@
|
||||
import torchdynamo
|
||||
import torch
|
||||
import torch_mlir
|
||||
from shark.sharkdynamo.utils import make_shark_compiler
|
||||
|
||||
|
||||
import warnings, logging
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
torchdynamo.config.log_level = logging.ERROR
|
||||
|
||||
|
||||
torchdynamo.reset()
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
|
||||
)
|
||||
def foo(t):
|
||||
return 2 * t
|
||||
|
||||
|
||||
example_input = torch.rand((2, 3))
|
||||
x = foo(example_input)
|
||||
print(x)
|
||||
|
||||
|
||||
torchdynamo.reset()
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
|
||||
)
|
||||
def foo(a, b):
|
||||
x = a / (a + 1)
|
||||
if b.sum() < 0:
|
||||
b = b * -1
|
||||
return x * b
|
||||
|
||||
|
||||
print(foo(torch.rand((2, 3)), -torch.rand((2, 3))))
|
||||
|
||||
|
||||
torchdynamo.reset()
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
|
||||
)
|
||||
def foo(a):
|
||||
for i in range(10):
|
||||
a += 1.0
|
||||
return a
|
||||
|
||||
|
||||
print(foo(torch.rand((1, 2))))
|
||||
|
||||
torchdynamo.reset()
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
|
||||
)
|
||||
def test_unsupported_types(t, y):
|
||||
return t, 2 * y
|
||||
|
||||
|
||||
str_input = "hello"
|
||||
tensor_input = torch.randn(2)
|
||||
print(test_unsupported_types(str_input, tensor_input))
|
||||
@@ -8,9 +8,7 @@ try:
|
||||
from torchdynamo.optimizations.backends import create_backend
|
||||
from torchdynamo.optimizations.subgraph import SubGraph
|
||||
except ModuleNotFoundError:
|
||||
print(
|
||||
"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo"
|
||||
)
|
||||
print("Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo")
|
||||
exit()
|
||||
|
||||
NUM_ITERS = 10
|
||||
@@ -26,9 +24,7 @@ def __torch_mlir(fx_graph, *args, **kwargs):
|
||||
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
assert len(node.args) == 1, "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple) and len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
@@ -45,12 +41,8 @@ def __torch_mlir(fx_graph, *args, **kwargs):
|
||||
if len(args) == 1 and isinstance(args[0], list):
|
||||
args = args[0]
|
||||
|
||||
linalg_module = compile(
|
||||
ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS
|
||||
)
|
||||
callable, _ = get_iree_compiled_module(
|
||||
linalg_module, "cuda", func_name="forward"
|
||||
)
|
||||
linalg_module = compile(ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS)
|
||||
callable, _ = get_iree_compiled_module(linalg_module, "cuda", func_name="forward")
|
||||
|
||||
def forward(*inputs):
|
||||
return callable(*inputs)
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
model = torch.hub.load(
|
||||
"pytorch/vision:v0.10.0", "squeezenet1_0", pretrained=True
|
||||
)
|
||||
model.eval()
|
||||
|
||||
# from PIL import Image
|
||||
# from torchvision import transforms
|
||||
# import urllib
|
||||
#
|
||||
# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
|
||||
# try: urllib.URLopener().retrieve(url, filename)
|
||||
# except: urllib.request.urlretrieve(url, filename)
|
||||
#
|
||||
#
|
||||
# input_image = Image.open(filename)
|
||||
# preprocess = transforms.Compose([
|
||||
# transforms.Resize(256),
|
||||
# transforms.CenterCrop(224),
|
||||
# transforms.ToTensor(),
|
||||
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
# ])
|
||||
# input_tensor = preprocess(input_image)
|
||||
# input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
|
||||
# print(input_batch.shape) # size = [1, 3, 224, 224]
|
||||
|
||||
# The above is code for generating sample inputs from an image. We can just use
|
||||
# random values for accuracy testing though
|
||||
input_batch = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
# Focus on CPU for now
|
||||
if False and torch.cuda.is_available():
|
||||
input_batch = input_batch.to("cuda")
|
||||
model.to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_batch)
|
||||
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
|
||||
golden_confidences = output[0]
|
||||
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
|
||||
golden_probabilities = torch.nn.functional.softmax(
|
||||
golden_confidences, dim=0
|
||||
).numpy()
|
||||
|
||||
golden_confidences = golden_confidences.numpy()
|
||||
|
||||
from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor
|
||||
|
||||
input_detached_clone = input_batch.clone()
|
||||
eager_input_batch = TorchMLIRLockstepTensor(input_detached_clone)
|
||||
|
||||
print("getting torch-mlir result")
|
||||
|
||||
output = model(eager_input_batch)
|
||||
|
||||
static_output = output.elem
|
||||
confidences = static_output[0]
|
||||
probabilities = torch.nn.functional.softmax(
|
||||
torch.from_numpy(confidences), dim=0
|
||||
).numpy()
|
||||
|
||||
print("The obtained result via shark is: ", confidences)
|
||||
print("The golden result is:", golden_confidences)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
golden_confidences, confidences, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
golden_probabilities, probabilities, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
@@ -9,24 +9,23 @@ from shark.shark_inference import SharkInference
|
||||
clip_vit_inputs = [
|
||||
tf.TensorSpec(shape=[2, 7], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[2, 7], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
|
||||
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32)
|
||||
]
|
||||
|
||||
|
||||
class CLIPModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(CLIPModule, self).__init__()
|
||||
self.m = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
self.m.predict = lambda x, y, z: self.m(
|
||||
input_ids=x, attention_mask=y, pixel_values=z
|
||||
)
|
||||
input_ids=x, attention_mask=y, pixel_values=z)
|
||||
|
||||
@tf.function(input_signature=clip_vit_inputs)
|
||||
def forward(self, input_ids, attention_mask, pixel_values):
|
||||
return self.m.predict(
|
||||
input_ids, attention_mask, pixel_values
|
||||
).logits_per_image
|
||||
return self.m.predict(input_ids, attention_mask,
|
||||
pixel_values).logits_per_image
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -36,30 +35,17 @@ if __name__ == "__main__":
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
text=["a photo of a cat", "a photo of a dog"],
|
||||
images=image,
|
||||
return_tensors="tf",
|
||||
padding=True,
|
||||
)
|
||||
inputs = processor(text=["a photo of a cat", "a photo of a dog"],
|
||||
images=image,
|
||||
return_tensors="tf",
|
||||
padding=True)
|
||||
|
||||
shark_module = SharkInference(
|
||||
CLIPModule(),
|
||||
(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["pixel_values"],
|
||||
),
|
||||
)
|
||||
(inputs["input_ids"], inputs["attention_mask"], inputs["pixel_values"]))
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
|
||||
print(
|
||||
shark_module.forward(
|
||||
(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["pixel_values"],
|
||||
)
|
||||
)
|
||||
)
|
||||
shark_module.forward((inputs["input_ids"], inputs["attention_mask"],
|
||||
inputs["pixel_values"])))
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from iree.compiler import compile_str
|
||||
from iree import runtime as ireert
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
|
||||
class AlbertModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForMaskedLM.from_pretrained("albert-base-v2")
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.model(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).logits
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
|
||||
text = "This [MASK] is very tasty."
|
||||
encoded_inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
|
||||
mlir_importer = SharkImporter(
|
||||
AlbertModule(),
|
||||
inputs,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
encoded_inputs["input_ids"] == tokenizer.mask_token_id
|
||||
)[1]
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
|
||||
for token in top_5_tokens:
|
||||
print(
|
||||
f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence with [MASK] to fill: ")
|
||||
encoded_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
encoded_inputs["input_ids"] == tokenizer.mask_token_id
|
||||
)[1]
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
top_5_tokens = (
|
||||
torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
|
||||
)
|
||||
for token in top_5_tokens:
|
||||
print(
|
||||
f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
break
|
||||
@@ -1,100 +0,0 @@
from PIL import Image
import requests

from transformers import TFAutoModelForMaskedLM, AutoTokenizer
import tensorflow as tf
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from iree.compiler import tf as tfc
from iree.compiler import compile_str
from iree import runtime as ireert
import os
import numpy as np
import sys

MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1

# Create a set of inputs
t5_inputs = [
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
]


class AlbertModule(tf.Module):
    def __init__(self):
        super(AlbertModule, self).__init__()
        self.m = TFAutoModelForMaskedLM.from_pretrained("albert-base-v2")
        self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)

    @tf.function(input_signature=t5_inputs)
    def forward(self, input_ids, attention_mask):
        return self.m.predict(input_ids, attention_mask)


if __name__ == "__main__":
    # Prepping Data
    tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
    # text = "This is a great [MASK]."
    text = "This [MASK] is very tasty."
    encoded_inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        return_tensors="tf",
    )
    inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
    mlir_importer = SharkImporter(
        AlbertModule(),
        inputs,
        frontend="tf",
    )
    minilm_mlir, func_name = mlir_importer.import_mlir(
        is_dynamic=False, tracing_required=False
    )
    shark_module = SharkInference(minilm_mlir, func_name, mlir_dialect="mhlo")
    shark_module.compile()
    output_idx = 0
    data_idx = 1
    token_logits = shark_module.forward(inputs)[output_idx][data_idx]
    mask_id = np.where(
        tf.squeeze(encoded_inputs["input_ids"]) == tokenizer.mask_token_id
    )
    mask_token_logits = token_logits[0, mask_id, :]
    top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[0:5]
    for token in top_5_tokens:
        print(
            f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
        )
    while True:
        try:
            new_text = input("Give me a sentence with [MASK] to fill: ")
            encoded_inputs = tokenizer(
                new_text,
                padding="max_length",
                truncation=True,
                max_length=MAX_SEQUENCE_LENGTH,
                return_tensors="tf",
            )
            inputs = (
                encoded_inputs["input_ids"],
                encoded_inputs["attention_mask"],
            )
            token_logits = shark_module.forward(inputs)[output_idx][data_idx]
            mask_id = np.where(
                tf.squeeze(encoded_inputs["input_ids"])
                == tokenizer.mask_token_id
            )
            mask_token_logits = token_logits[0, mask_id, :]
            top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[
                0:5
            ]
            for token in top_5_tokens:
                print(
                    f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
                )
        except KeyboardInterrupt:
            print("Exiting program.")
            sys.exit()
@@ -1,12 +0,0 @@
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model

mlir_model, func_name, inputs, golden_out = download_torch_model("bloom")

shark_module = SharkInference(
    mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
)
shark_module.compile()
result = shark_module.forward(inputs)
print("The obtained result via shark is: ", result)
print("The golden result is:", golden_out)
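The deleted BLOOM example prints the SHARK result and the golden output side by side without comparing them; if a numeric check is wanted, the tolerance-based comparison used by other examples in this tree could be appended. A sketch, assuming result and golden_out are NumPy arrays of matching shape:

import numpy as np

# Hedged check: raise if the SHARK output drifts from the golden output beyond tolerance.
np.testing.assert_allclose(golden_out, result, rtol=1e-2, atol=1e-3)
print("SHARK output matches the golden output within tolerance.")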
@@ -13,6 +13,7 @@ gpt2_inputs = [


class GPT2Module(tf.Module):

    def __init__(self):
        super(GPT2Module, self).__init__()
        self.m = TFGPT2Model.from_pretrained("distilgpt2")
@@ -29,12 +30,9 @@ if __name__ == "__main__":
    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    text = "I love the distilled version of models."

    inputs = tokenizer(text, return_tensors="tf")
    inputs = tokenizer(text, return_tensors='tf')
    shark_module = SharkInference(
        GPT2Module(), (inputs["input_ids"], inputs["attention_mask"])
    )
        GPT2Module(), (inputs["input_ids"], inputs["attention_mask"]))
    shark_module.set_frontend("tensorflow")
    shark_module.compile()
    print(
        shark_module.forward((inputs["input_ids"], inputs["attention_mask"]))
    )
    print(shark_module.forward((inputs["input_ids"], inputs["attention_mask"])))
@@ -12,26 +12,7 @@ mhlo_ir = r"""builtin.module {
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)

print("Running shark on cpu backend")
shark_module = SharkInference(
    mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
)

# Generate the random inputs and feed into the graph.
x = shark_module.generate_random_inputs()
shark_module = SharkInference(mhlo_ir, (arg0, arg1))
shark_module.set_frontend("mhlo")
shark_module.compile()
print(shark_module.forward(x))

print("Running shark on cuda backend")
shark_module = SharkInference(
    mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward(x))

print("Running shark on vulkan backend")
shark_module = SharkInference(
    mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward(x))
print(shark_module.forward((arg0, arg1)))
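The removed half of this hunk repeats the same construct-compile-forward boilerplate once per backend; a loop over device names expresses the same flow more compactly. A sketch using only the SharkInference arguments shown above; which devices actually work depends on the local IREE build:

for device in ("cpu", "cuda", "vulkan"):
    print(f"Running shark on {device} backend")
    module = SharkInference(
        mhlo_ir, function_name="forward", device=device, mlir_dialect="mhlo"
    )
    module.compile()
    # Random inputs are generated from the shapes recorded in the MLIR module itself.
    print(module.forward(module.generate_random_inputs()))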
@@ -7,13 +7,17 @@ tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")


class MiniLMSequenceClassification(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "microsoft/MiniLM-L12-H384-uncased",  # The pretrained model.
            num_labels=2,  # The number of output labels--2 for binary classification.
            output_attentions=False,  # Whether the model returns attentions weights.
            output_hidden_states=False,  # Whether the model returns all hidden-states.
            num_labels=
            2,  # The number of output labels--2 for binary classification.
            output_attentions=
            False,  # Whether the model returns attentions weights.
            output_hidden_states=
            False,  # Whether the model returns all hidden-states.
            torchscript=True,
        )

@@ -23,12 +27,9 @@ class MiniLMSequenceClassification(torch.nn.Module):

test_input = torch.randint(2, (1, 128))

shark_module = SharkInference(
    MiniLMSequenceClassification(),
    (test_input,),
    jit_trace=True,
    benchmark_mode=True,
)
shark_module = SharkInference(MiniLMSequenceClassification(), (test_input,),
                              jit_trace=True,
                              benchmark_mode=True)

shark_module.compile()
shark_module.forward((test_input,))
@@ -2,6 +2,10 @@ import tensorflow as tf
from transformers import BertModel, BertTokenizer, TFBertModel
from shark.shark_inference import SharkInference

gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1

@@ -9,22 +13,21 @@ BATCH_SIZE = 1
bert_input = [
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32)
]


class BertModule(tf.Module):

    def __init__(self):
        super(BertModule, self).__init__()
        # Create a BERT trainer with the created network.
        self.m = TFBertModel.from_pretrained(
            "microsoft/MiniLM-L12-H384-uncased", from_pt=True
        )
            "microsoft/MiniLM-L12-H384-uncased", from_pt=True)

        # Invoke the trainer model on the inputs. This causes the layer to be built.
        self.m.predict = lambda x, y, z: self.m.call(
            input_ids=x, attention_mask=y, token_type_ids=z, training=False
        )
            input_ids=x, attention_mask=y, token_type_ids=z, training=False)

    @tf.function(input_signature=bert_input)
    def forward(self, input_ids, attention_mask, token_type_ids):
@@ -34,28 +37,22 @@ class BertModule(tf.Module):
if __name__ == "__main__":
    # Prepping Data
    tokenizer = BertTokenizer.from_pretrained(
        "microsoft/MiniLM-L12-H384-uncased"
    )
        "microsoft/MiniLM-L12-H384-uncased")
    text = "Replace me by any text you'd like."
    encoded_input = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
    )
    encoded_input = tokenizer(text,
                              padding='max_length',
                              truncation=True,
                              max_length=MAX_SEQUENCE_LENGTH)
    for key in encoded_input:
        encoded_input[key] = tf.expand_dims(
            tf.convert_to_tensor(encoded_input[key]), 0
        )
            tf.convert_to_tensor(encoded_input[key]), 0)

    test_input = (
        encoded_input["input_ids"],
        encoded_input["attention_mask"],
        encoded_input["token_type_ids"],
    )
    test_input = (encoded_input["input_ids"], encoded_input["attention_mask"],
                  encoded_input["token_type_ids"])
    shark_module = SharkInference(
        BertModule(), test_input, benchmark_mode=True
    )
        BertModule(),
        test_input,
        benchmark_mode=True)
    shark_module.set_frontend("tensorflow")
    shark_module.compile()
    shark_module.benchmark_all(test_input)
@@ -1,24 +1,35 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model

torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")


mlir_model, func_name, inputs, golden_out = download_torch_model(
    "microsoft/MiniLM-L12-H384-uncased"
)
class MiniLMSequenceClassification(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "microsoft/MiniLM-L12-H384-uncased",  # The pretrained model.
            num_labels=
            2,  # The number of output labels--2 for binary classification.
            output_attentions=
            False,  # Whether the model returns attentions weights.
            output_hidden_states=
            False,  # Whether the model returns all hidden-states.
            torchscript=True,
        )

    def forward(self, tokens):
        return self.model.forward(tokens)[0]


shark_module = SharkInference(
    mlir_model, func_name, device="cpu", mlir_dialect="linalg"
)
test_input = torch.randint(2, (1, 128))

shark_module = SharkInference(MiniLMSequenceClassification(), (test_input,),
                              jit_trace=True)

shark_module.compile()
result = shark_module.forward(inputs)
print("The obtained result via shark is: ", result)
print("The golden result is:", golden_out)


# Let's generate random inputs, currently supported
# for static models.
rand_inputs = shark_module.generate_random_inputs()
rand_results = shark_module.forward(rand_inputs)

print("Running shark_module with random_inputs is: ", rand_results)
result = shark_module.forward((test_input,))
print("Obtained result", result)
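Because the rewritten script constructs the eager PyTorch module anyway, the SHARK result can also be checked against it directly. A sketch, not part of the diff: it reuses a single model instance as both the compile source and the reference so the randomly initialized classification head is identical, and it assumes forward returns a NumPy array for this single-output model.

import numpy as np

model = MiniLMSequenceClassification()
shark_module = SharkInference(model, (test_input,), jit_trace=True)
shark_module.compile()

with torch.no_grad():
    reference = model(test_input)
# Compare SHARK against eager PyTorch with a loose tolerance.
np.testing.assert_allclose(
    reference.numpy(), shark_module.forward((test_input,)), rtol=1e-2, atol=1e-3
)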
41 shark/examples/shark_inference/minilm_load_benchmark_tf.py Normal file
@@ -0,0 +1,41 @@
import tensorflow as tf
from transformers import BertModel, BertTokenizer, TFBertModel
from shark.shark_inference import SharkInference
from shark.shark_importer import shark_load
from shark.parser import parser
import os

gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

parser.add_argument(
    "--download_mlir_path",
    type=str,
    default="minilm_tf_inference.mlir",
    help="Specifies path to target mlir file that will be loaded.")
load_args, unknown = parser.parse_known_args()

MAX_SEQUENCE_LENGTH = 512

if __name__ == "__main__":
    # Prepping Data
    tokenizer = BertTokenizer.from_pretrained(
        "microsoft/MiniLM-L12-H384-uncased")
    text = "Replace me by any text you'd like."
    encoded_input = tokenizer(text,
                              padding='max_length',
                              truncation=True,
                              max_length=MAX_SEQUENCE_LENGTH)
    for key in encoded_input:
        encoded_input[key] = tf.expand_dims(
            tf.convert_to_tensor(encoded_input[key]), 0)
    model_name = "minilm_tf_inference"
    minilm_mlir = shark_load(model_name, load_args.download_mlir_path)
    test_input = (encoded_input["input_ids"], encoded_input["attention_mask"],
                  encoded_input["token_type_ids"])
    shark_module = SharkInference(
        minilm_mlir, test_input, benchmark_mode=True)
    shark_module.set_frontend("mhlo")
    shark_module.compile()
    shark_module.benchmark_all(test_input)
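A typical invocation of the new loader script, assuming an exported MLIR file is already on disk (the flag and its default come from the script above; the path shown is just the default value):

python shark/examples/shark_inference/minilm_load_benchmark_tf.py --download_mlir_path=minilm_tf_inference.mlir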
@@ -9,22 +9,21 @@ BATCH_SIZE = 1
bert_input = [
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32)
]


class BertModule(tf.Module):

    def __init__(self):
        super(BertModule, self).__init__()
        # Create a BERT trainer with the created network.
        self.m = TFBertModel.from_pretrained(
            "microsoft/MiniLM-L12-H384-uncased", from_pt=True
        )
            "microsoft/MiniLM-L12-H384-uncased", from_pt=True)

        # Invoke the trainer model on the inputs. This causes the layer to be built.
        self.m.predict = lambda x, y, z: self.m.call(
            input_ids=x, attention_mask=y, token_type_ids=z, training=False
        )
            input_ids=x, attention_mask=y, token_type_ids=z, training=False)

    @tf.function(input_signature=bert_input)
    def forward(self, input_ids, attention_mask, token_type_ids):
@@ -34,37 +33,24 @@ class BertModule(tf.Module):
if __name__ == "__main__":
    # Prepping Data
    tokenizer = BertTokenizer.from_pretrained(
        "microsoft/MiniLM-L12-H384-uncased"
    )
        "microsoft/MiniLM-L12-H384-uncased")
    text = "Replace me by any text you'd like."
    encoded_input = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
    )
    encoded_input = tokenizer(text,
                              padding='max_length',
                              truncation=True,
                              max_length=MAX_SEQUENCE_LENGTH)
    for key in encoded_input:
        encoded_input[key] = tf.expand_dims(
            tf.convert_to_tensor(encoded_input[key]), 0
        )
            tf.convert_to_tensor(encoded_input[key]), 0)

    shark_module = SharkInference(
        BertModule(),
        (
            encoded_input["input_ids"],
            encoded_input["attention_mask"],
            encoded_input["token_type_ids"],
        ),
    )
        (encoded_input["input_ids"], encoded_input["attention_mask"],
         encoded_input["token_type_ids"]))
    shark_module.set_frontend("tensorflow")
    shark_module.compile()

    print(
        shark_module.forward(
            (
                encoded_input["input_ids"],
                encoded_input["attention_mask"],
                encoded_input["token_type_ids"],
            )
        )
    )
        (encoded_input["input_ids"], encoded_input["attention_mask"],
         encoded_input["token_type_ids"])))
@@ -1,39 +0,0 @@
import torch
import torchvision.models as models
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter

torch.hub.list("zhanghang1989/ResNeSt", force_reload=True)


class ResnestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.hub.load(
            "zhanghang1989/ResNeSt", "resnest50", pretrained=True
        )
        self.model.eval()

    def forward(self, input):
        return self.model.forward(input)


input = torch.randn(1, 3, 224, 224)


mlir_importer = SharkImporter(
    ResnestModule(),
    (input,),
    frontend="torch",
)

(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
    tracing_required=True
)

print(golden_out)

shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input,))
print("Obtained result", result)
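Beyond printing the raw tensors, the deleted ResNeSt script could reduce both outputs to a top-1 class index for a quicker sanity check. A sketch, assuming result and golden_out are the classification logits returned as NumPy-compatible arrays:

import numpy as np

# Hedged sanity check: SHARK and the torch golden output should pick the same class.
print("SHARK top-1 class:", int(np.argmax(result)))
print("Golden top-1 class:", int(np.argmax(golden_out)))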
@@ -1,76 +0,0 @@
from shark.shark_inference import SharkInference
from shark.parser import shark_args

import torch
import numpy as np
import sys
import torchvision.models as models
import torch_mlir

torch.manual_seed(0)


class VisionModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = models.resnet50(pretrained=True)
        self.train(False)

    def forward(self, input):
        return self.model.forward(input)


model = VisionModule()
test_input = torch.randn(1, 3, 224, 224)
actual_out = model(test_input)

test_input_fp16 = test_input.to(device=torch.device("cuda"), dtype=torch.half)
model_fp16 = model.half()
model_fp16.eval()
model_fp16.to("cuda")
actual_out_fp16 = model_fp16(test_input_fp16)

ts_g = torch.jit.trace(model_fp16, [test_input_fp16])

module = torch_mlir.compile(
    ts_g,
    (test_input_fp16),
    torch_mlir.OutputType.LINALG_ON_TENSORS,
    use_tracing=True,
    verbose=False,
)

# from contextlib import redirect_stdout

# with open('resnet50_fp16_linalg_ir.mlir', 'w') as f:
#     with redirect_stdout(f):
#         print(module.operation.get_asm())

mlir_model = module
func_name = "forward"

shark_module = SharkInference(
    mlir_model, func_name, device="cuda", mlir_dialect="linalg"
)
shark_module.compile()


def shark_result(x):
    x_ny = x.cpu().detach().numpy()
    inputs = (x_ny,)
    result = shark_module.forward(inputs)
    return torch.from_numpy(result)


observed_out = shark_result(test_input_fp16)

print("Golden result:", actual_out_fp16)
print("SHARK result:", observed_out)

actual_out_fp16 = actual_out_fp16.to(device=torch.device("cpu"))

print(
    torch.testing.assert_allclose(
        actual_out_fp16, observed_out, rtol=1e-2, atol=1e-2
    )
)
@@ -5,28 +5,24 @@ import torchvision.models as models
from torchvision import transforms
import sys
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model


################################## Preprocessing inputs and model ############
def load_and_preprocess_image(url: str):
    headers = {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
        "User-Agent":
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
    }
    img = Image.open(
        requests.get(url, headers=headers, stream=True).raw
    ).convert("RGB")
    img = Image.open(requests.get(url, headers=headers,
                                  stream=True).raw).convert("RGB")
    # preprocessing pipeline
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img_preprocessed = preprocess(img)
    return torch.unsqueeze(img_preprocessed, 0)

@@ -48,6 +44,7 @@ def top3_possibilities(res):


class Resnet50Module(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.resnet = models.resnet50(pretrained=True)
@@ -64,18 +61,18 @@ labels = load_labels()

##############################################################################

input = torch.randn(1, 3, 224, 224)
print(input.shape)

## The img is passed to determine the input shape.
shark_module = SharkInference(Resnet50Module(), (img,))
shark_module.compile()

## Can pass any img or input to the forward module.
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")

shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
shark_module.compile()
path = shark_module.save_module()
shark_module.load_module(path)
result = shark_module.forward((img.detach().numpy(),))
results = shark_module.forward((img,))

print("The top 3 results obtained via shark_runner is:")
print(top3_possibilities(torch.from_numpy(result)))
print(top3_possibilities(torch.from_numpy(results)))

print()

@@ -1,392 +0,0 @@
|
||||
# Description: an implementation of a deep learning recommendation model (DLRM)
|
||||
# The model input consists of dense and sparse features. The former is a vector
|
||||
# of floating point values. The latter is a list of sparse indices into
|
||||
# embedding tables, which consist of vectors of floating point values.
|
||||
# The selected vectors are passed to mlp networks denoted by triangles,
|
||||
# in some cases the vectors are interacted through operators (Ops).
|
||||
#
|
||||
# output:
|
||||
# vector of values
|
||||
# model: |
|
||||
# /\
|
||||
# /__\
|
||||
# |
|
||||
# _____________________> Op <___________________
|
||||
# / | \
|
||||
# /\ /\ /\
|
||||
# /__\ /__\ ... /__\
|
||||
# | | |
|
||||
# | Op Op
|
||||
# | ____/__\_____ ____/__\____
|
||||
# | |_Emb_|____|__| ... |_Emb_|__|___|
|
||||
# input:
|
||||
# [ dense features ] [sparse indices] , ..., [sparse indices]
|
||||
#
|
||||
# More precise definition of model layers:
|
||||
# 1) fully connected layers of an mlp
|
||||
# z = f(y)
|
||||
# y = Wx + b
|
||||
#
|
||||
# 2) embedding lookup (for a list of sparse indices p=[p1,...,pk])
|
||||
# z = Op(e1,...,ek)
|
||||
# obtain vectors e1=E[:,p1], ..., ek=E[:,pk]
|
||||
#
|
||||
# 3) Operator Op can be one of the following
|
||||
# Sum(e1,...,ek) = e1 + ... + ek
|
||||
# Dot(e1,...,ek) = [e1'e1, ..., e1'ek, ..., ek'e1, ..., ek'ek]
|
||||
# Cat(e1,...,ek) = [e1', ..., ek']'
|
||||
# where ' denotes transpose operation
|
||||
#
|
||||
# References:
|
||||
# [1] Maxim Naumov, Dheevatsa Mudigere, Hao-Jun Michael Shi, Jianyu Huang,
|
||||
# Narayanan Sundaram, Jongsoo Park, Xiaodong Wang, Udit Gupta, Carole-Jean Wu,
|
||||
# Alisson G. Azzolini, Dmytro Dzhulgakov, Andrey Mallevich, Ilia Cherniavskii,
|
||||
# Yinghai Lu, Raghuraman Krishnamoorthi, Ansha Yu, Volodymyr Kondratenko,
|
||||
# Stephanie Pereira, Xianjie Chen, Wenlin Chen, Vijay Rao, Bill Jia, Liang Xiong,
|
||||
# Misha Smelyanskiy, "Deep Learning Recommendation Model for Personalization and
|
||||
# Recommendation Systems", CoRR, arXiv:1906.00091, 2019
|
||||
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
### define dlrm in PyTorch ###
|
||||
class DLRM_Net(nn.Module):
|
||||
def create_mlp(self, ln, sigmoid_layer):
|
||||
# build MLP layer by layer
|
||||
layers = nn.ModuleList()
|
||||
for i in range(0, ln.size - 1):
|
||||
n = ln[i]
|
||||
m = ln[i + 1]
|
||||
|
||||
# construct fully connected operator
|
||||
LL = nn.Linear(int(n), int(m), bias=True)
|
||||
|
||||
# initialize the weights
|
||||
# with torch.no_grad():
|
||||
# custom Xavier input, output or two-sided fill
|
||||
|
||||
mean = 0.0 # std_dev = np.sqrt(variance)
|
||||
std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n)
|
||||
W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
|
||||
std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1))
|
||||
bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
|
||||
LL.weight.data = torch.tensor(W, requires_grad=True)
|
||||
LL.bias.data = torch.tensor(bt, requires_grad=True)
|
||||
|
||||
# approach 2
|
||||
# LL.weight.data.copy_(torch.tensor(W))
|
||||
# LL.bias.data.copy_(torch.tensor(bt))
|
||||
# approach 3
|
||||
# LL.weight = Parameter(torch.tensor(W),requires_grad=True)
|
||||
# LL.bias = Parameter(torch.tensor(bt),requires_grad=True)
|
||||
layers.append(LL)
|
||||
|
||||
# construct sigmoid or relu operator
|
||||
if i == sigmoid_layer:
|
||||
layers.append(nn.Sigmoid())
|
||||
else:
|
||||
layers.append(nn.ReLU())
|
||||
|
||||
# approach 1: use ModuleList
|
||||
# return layers
|
||||
# approach 2: use Sequential container to wrap all layers
|
||||
return torch.nn.Sequential(*layers)
|
||||
|
||||
def create_emb(self, m, ln, weighted_pooling=None):
|
||||
emb_l = nn.ModuleList()
|
||||
v_W_l = []
|
||||
for i in range(0, ln.size):
|
||||
n = ln[i]
|
||||
|
||||
# construct embedding operator
|
||||
EE = nn.EmbeddingBag(n, m, mode="sum")
|
||||
# initialize embeddings
|
||||
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
|
||||
W = np.random.uniform(
|
||||
low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m)
|
||||
).astype(np.float32)
|
||||
# approach 1
|
||||
print(W)
|
||||
EE.weight.data = torch.tensor(W, requires_grad=True)
|
||||
# approach 2
|
||||
# EE.weight.data.copy_(torch.tensor(W))
|
||||
# approach 3
|
||||
# EE.weight = Parameter(torch.tensor(W),requires_grad=True)
|
||||
if weighted_pooling is None:
|
||||
v_W_l.append(None)
|
||||
else:
|
||||
v_W_l.append(torch.ones(n, dtype=torch.float32))
|
||||
emb_l.append(EE)
|
||||
return emb_l, v_W_l
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
m_spa=None,
|
||||
ln_emb=None,
|
||||
ln_bot=None,
|
||||
ln_top=None,
|
||||
arch_interaction_op=None,
|
||||
arch_interaction_itself=False,
|
||||
sigmoid_bot=-1,
|
||||
sigmoid_top=-1,
|
||||
weighted_pooling=None,
|
||||
):
|
||||
super(DLRM_Net, self).__init__()
|
||||
|
||||
if (
|
||||
(m_spa is not None)
|
||||
and (ln_emb is not None)
|
||||
and (ln_bot is not None)
|
||||
and (ln_top is not None)
|
||||
and (arch_interaction_op is not None)
|
||||
):
|
||||
|
||||
# save arguments
|
||||
self.output_d = 0
|
||||
self.arch_interaction_op = arch_interaction_op
|
||||
self.arch_interaction_itself = arch_interaction_itself
|
||||
if weighted_pooling is not None and weighted_pooling != "fixed":
|
||||
self.weighted_pooling = "learned"
|
||||
else:
|
||||
self.weighted_pooling = weighted_pooling
|
||||
|
||||
# create operators
|
||||
self.emb_l, w_list = self.create_emb(
|
||||
m_spa, ln_emb, weighted_pooling
|
||||
)
|
||||
if self.weighted_pooling == "learned":
|
||||
self.v_W_l = nn.ParameterList()
|
||||
for w in w_list:
|
||||
self.v_W_l.append(nn.Parameter(w))
|
||||
else:
|
||||
self.v_W_l = w_list
|
||||
self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
|
||||
self.top_l = self.create_mlp(ln_top, sigmoid_top)
|
||||
|
||||
def apply_mlp(self, x, layers):
|
||||
return layers(x)
|
||||
|
||||
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
|
||||
# WARNING: notice that we are processing the batch at once. We implicitly
|
||||
# assume that the data is laid out such that:
|
||||
# 1. each embedding is indexed with a group of sparse indices,
|
||||
# corresponding to a single lookup
|
||||
# 2. for each embedding the lookups are further organized into a batch
|
||||
# 3. for a list of embedding tables there is a list of batched lookups
|
||||
# TORCH-MLIR
|
||||
# We are passing all the embeddings as arguments for easy parsing.
|
||||
|
||||
ly = []
|
||||
for k, sparse_index_group_batch in enumerate(lS_i):
|
||||
sparse_offset_group_batch = lS_o[k]
|
||||
|
||||
# embedding lookup
|
||||
# We are using EmbeddingBag, which implicitly uses sum operator.
|
||||
# The embeddings are represented as tall matrices, with sum
|
||||
# happening vertically across 0 axis, resulting in a row vector
|
||||
# E = emb_l[k]
|
||||
|
||||
if v_W_l[k] is not None:
|
||||
per_sample_weights = v_W_l[k].gather(
|
||||
0, sparse_index_group_batch
|
||||
)
|
||||
else:
|
||||
per_sample_weights = None
|
||||
|
||||
E = emb_l[k]
|
||||
V = E(
|
||||
sparse_index_group_batch,
|
||||
sparse_offset_group_batch,
|
||||
per_sample_weights=per_sample_weights,
|
||||
)
|
||||
|
||||
ly.append(V)
|
||||
|
||||
return ly
|
||||
|
||||
def interact_features(self, x, ly):
|
||||
|
||||
if self.arch_interaction_op == "dot":
|
||||
# concatenate dense and sparse features
|
||||
(batch_size, d) = x.shape
|
||||
T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
|
||||
# perform a dot product
|
||||
Z = torch.bmm(T, torch.transpose(T, 1, 2))
|
||||
# append dense feature with the interactions (into a row vector)
|
||||
# approach 1: all
|
||||
# Zflat = Z.view((batch_size, -1))
|
||||
# approach 2: unique
|
||||
_, ni, nj = Z.shape
|
||||
# approach 1: tril_indices
|
||||
# offset = 0 if self.arch_interaction_itself else -1
|
||||
# li, lj = torch.tril_indices(ni, nj, offset=offset)
|
||||
# approach 2: custom
|
||||
offset = 1 if self.arch_interaction_itself else 0
|
||||
li = torch.tensor(
|
||||
[i for i in range(ni) for j in range(i + offset)]
|
||||
)
|
||||
lj = torch.tensor(
|
||||
[j for i in range(nj) for j in range(i + offset)]
|
||||
)
|
||||
Zflat = Z[:, li, lj]
|
||||
# concatenate dense features and interactions
|
||||
R = torch.cat([x] + [Zflat], dim=1)
|
||||
elif self.arch_interaction_op == "cat":
|
||||
# concatenation features (into a row vector)
|
||||
R = torch.cat([x] + ly, dim=1)
|
||||
else:
|
||||
sys.exit(
|
||||
"ERROR: --arch-interaction-op="
|
||||
+ self.arch_interaction_op
|
||||
+ " is not supported"
|
||||
)
|
||||
|
||||
return R
|
||||
|
||||
def forward(self, dense_x, lS_o, *lS_i):
|
||||
return self.sequential_forward(dense_x, lS_o, lS_i)
|
||||
|
||||
def sequential_forward(self, dense_x, lS_o, lS_i):
|
||||
# process dense features (using bottom mlp), resulting in a row vector
|
||||
x = self.apply_mlp(dense_x, self.bot_l)
|
||||
# debug prints
|
||||
# print("intermediate")
|
||||
# print(x.detach().cpu().numpy())
|
||||
|
||||
# process sparse features(using embeddings), resulting in a list of row vectors
|
||||
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
|
||||
# for y in ly:
|
||||
# print(y.detach().cpu().numpy())
|
||||
|
||||
# interact features (dense and sparse)
|
||||
z = self.interact_features(x, ly)
|
||||
# print(z.detach().cpu().numpy())
|
||||
|
||||
# obtain probability of a click (using top mlp)
|
||||
p = self.apply_mlp(z, self.top_l)
|
||||
|
||||
# # clamp output if needed
|
||||
# if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
|
||||
# z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
|
||||
# else:
|
||||
# z = p
|
||||
|
||||
return p
|
||||
|
||||
|
||||
def dash_separated_ints(value):
|
||||
vals = value.split("-")
|
||||
for val in vals:
|
||||
try:
|
||||
int(val)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"%s is not a valid dash separated list of ints" % value
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# model related parameters
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train Deep Learning Recommendation Model (DLRM)"
|
||||
)
|
||||
parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
|
||||
parser.add_argument(
|
||||
"--arch-embedding-size", type=dash_separated_ints, default="4-3-2"
|
||||
)
|
||||
# j will be replaced with the table number
|
||||
parser.add_argument(
|
||||
"--arch-mlp-bot", type=dash_separated_ints, default="4-3-2"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-mlp-top", type=dash_separated_ints, default="8-2-1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-interaction-itself", action="store_true", default=False
|
||||
)
|
||||
parser.add_argument("--weighted-pooling", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
|
||||
ln_top = np.fromstring(args.arch_mlp_top, dtype=int, sep="-")
|
||||
m_den = ln_bot[0]
|
||||
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
|
||||
m_spa = args.arch_sparse_feature_size
|
||||
ln_emb = np.asarray(ln_emb)
|
||||
num_fea = ln_emb.size + 1 # num sparse + num dense features
|
||||
|
||||
|
||||
# Initialize the model.
|
||||
dlrm_model = DLRM_Net(
|
||||
m_spa=m_spa,
|
||||
ln_emb=ln_emb,
|
||||
ln_bot=ln_bot,
|
||||
ln_top=ln_top,
|
||||
arch_interaction_op=args.arch_interaction_op,
|
||||
)
|
||||
|
||||
|
||||
# Inputs to the model.
|
||||
dense_inp = torch.tensor([[0.6965, 0.2861, 0.2269, 0.5513]])
|
||||
vs0 = torch.tensor([[0], [0], [0]], dtype=torch.int64)
|
||||
vsi = torch.tensor([1, 2, 3]), torch.tensor([1]), torch.tensor([1])
|
||||
|
||||
input_dlrm = (dense_inp, vs0, *vsi)
|
||||
|
||||
golden_output = dlrm_model(dense_inp, vs0, *vsi)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
dlrm_model,
|
||||
input_dlrm,
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(input_dlrm)
|
||||
np.testing.assert_allclose(
|
||||
golden_output.detach().numpy(), result, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
|
||||
|
||||
# Verified via torch-mlir.
|
||||
# import torch_mlir
|
||||
# from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
|
||||
# module = torch_mlir.compile(
|
||||
# dlrm_model, inputs, use_tracing=True, output_type="linalg-on-tensors"
|
||||
# )
|
||||
# backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
# compiled = backend.compile(module)
|
||||
# jit_module = backend.load(compiled)
|
||||
|
||||
# dense_numpy = dense_inp.numpy()
|
||||
# vs0_numpy = vs0.numpy()
|
||||
# vsi_numpy = [inp.numpy() for inp in vsi]
|
||||
|
||||
# numpy_inp = (dense_numpy, vs0_numpy, *vsi_numpy)
|
||||
|
||||
# print(jit_module.forward(*numpy_inp))
|
||||
@@ -1,314 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchrec.datasets.utils import Batch
|
||||
from torchrec.modules.crossnet import LowRankCrossNet
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from torchrec.models.dlrm import (
|
||||
choose,
|
||||
DenseArch,
|
||||
DLRM,
|
||||
InteractionArch,
|
||||
SparseArch,
|
||||
OverArch,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
def calculate_offsets(tensor_list, prev_values, prev_offsets):
|
||||
offset_init = 0
|
||||
offset_list = []
|
||||
values_list = []
|
||||
|
||||
if prev_offsets != None:
|
||||
offset_init = prev_values.shape[-1]
|
||||
for tensor in tensor_list:
|
||||
offset_list.append(offset_init)
|
||||
offset_init += tensor.shape[0]
|
||||
|
||||
concatendated_tensor_list = torch.cat(tensor_list)
|
||||
|
||||
if prev_values != None:
|
||||
concatendated_tensor_list = torch.cat(
|
||||
[prev_values, concatendated_tensor_list]
|
||||
)
|
||||
|
||||
concatenated_offsets = torch.tensor(offset_list)
|
||||
|
||||
if prev_offsets != None:
|
||||
concatenated_offsets = torch.cat([prev_offsets, concatenated_offsets])
|
||||
|
||||
return concatendated_tensor_list, concatenated_offsets
|
||||
|
||||
|
||||
# Have to make combined_keys as dict as to which embedding bags they
|
||||
# point to. {f1: 0, f3: 0, f2: 1}
|
||||
# The result will be a triple containing values, indices and pointer tensor.
|
||||
def to_list(key_jagged, combined_keys):
|
||||
key_jagged_dict = key_jagged.to_dict()
|
||||
combined_list = []
|
||||
|
||||
for key in combined_keys:
|
||||
prev_values, prev_offsets = calculate_offsets(
|
||||
key_jagged_dict[key].to_dense(), None, None
|
||||
)
|
||||
print(prev_values)
|
||||
print(prev_offsets)
|
||||
combined_list.append(prev_values)
|
||||
combined_list.append(prev_offsets)
|
||||
combined_list.append(torch.tensor(combined_keys[key]))
|
||||
|
||||
return combined_list
|
||||
|
||||
|
||||
class SparseArchShark(nn.Module):
|
||||
def create_emb(self, embedding_dim, num_embeddings_list):
|
||||
embedding_list = nn.ModuleList()
|
||||
for i in range(0, num_embeddings_list.size):
|
||||
num_embeddings = num_embeddings_list[i]
|
||||
EE = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
|
||||
W = np.random.uniform(
|
||||
low=-np.sqrt(1 / num_embeddings),
|
||||
high=np.sqrt(1 / num_embeddings),
|
||||
size=(num_embeddings, embedding_dim),
|
||||
).astype(np.float32)
|
||||
EE.weight.data = torch.tensor(W, requires_grad=True)
|
||||
embedding_list.append(EE)
|
||||
return embedding_list
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
total_features,
|
||||
num_embeddings_list,
|
||||
):
|
||||
super(SparseArchShark, self).__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_features = total_features
|
||||
self.embedding_list = self.create_emb(
|
||||
embedding_dim, num_embeddings_list
|
||||
)
|
||||
|
||||
def forward(self, *batched_inputs):
|
||||
|
||||
concatenated_list = []
|
||||
input_enum, embedding_enum = 0, 0
|
||||
|
||||
for k in range(len(batched_inputs) // 3):
|
||||
values = batched_inputs[input_enum]
|
||||
input_enum += 1
|
||||
offsets = batched_inputs[input_enum]
|
||||
input_enum += 1
|
||||
embedding_pointer = int(batched_inputs[input_enum])
|
||||
input_enum += 1
|
||||
|
||||
E = self.embedding_list[embedding_pointer]
|
||||
V = E(values, offsets)
|
||||
concatenated_list.append(V)
|
||||
|
||||
return torch.cat(concatenated_list, dim=1).reshape(
|
||||
-1, self.num_features, self.embedding_dim
|
||||
)
|
||||
|
||||
|
||||
def test_sparse_arch() -> None:
|
||||
|
||||
D = 3
|
||||
eb1_config = EmbeddingBagConfig(
|
||||
name="t1",
|
||||
embedding_dim=D,
|
||||
num_embeddings=10,
|
||||
feature_names=["f1", "f3"],
|
||||
)
|
||||
eb2_config = EmbeddingBagConfig(
|
||||
name="t2",
|
||||
embedding_dim=D,
|
||||
num_embeddings=10,
|
||||
feature_names=["f2"],
|
||||
)
|
||||
|
||||
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||
|
||||
w1 = ebc.embedding_bags["t1"].weight
|
||||
w2 = ebc.embedding_bags["t2"].weight
|
||||
|
||||
sparse_arch = SparseArch(ebc)
|
||||
|
||||
keys = ["f1", "f2", "f3", "f4", "f5"]
|
||||
offsets = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19])
|
||||
features = KeyedJaggedTensor.from_offsets_sync(
|
||||
keys=keys,
|
||||
values=torch.tensor(
|
||||
[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]
|
||||
),
|
||||
offsets=offsets,
|
||||
)
|
||||
sparse_archi = SparseArchShark(D, 3, np.array([10, 10]))
|
||||
sparse_archi.embedding_list[0].weight = w1
|
||||
sparse_archi.embedding_list[1].weight = w2
|
||||
inputs = to_list(features, {"f1": 0, "f3": 0, "f2": 1})
|
||||
|
||||
test_results = sparse_archi(*inputs)
|
||||
sparse_features = sparse_arch(features)
|
||||
|
||||
torch.allclose(
|
||||
sparse_features,
|
||||
test_results,
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
|
||||
test_sparse_arch()
|
||||
|
||||
|
||||
class DLRMShark(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
total_features,
|
||||
num_embeddings_list,
|
||||
dense_in_features: int,
|
||||
dense_arch_layer_sizes: List[int],
|
||||
over_arch_layer_sizes: List[int],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.sparse_arch: SparseArchShark = SparseArchShark(
|
||||
embedding_dim, total_features, num_embeddings_list
|
||||
)
|
||||
num_sparse_features: int = total_features
|
||||
|
||||
self.dense_arch = DenseArch(
|
||||
in_features=dense_in_features,
|
||||
layer_sizes=dense_arch_layer_sizes,
|
||||
)
|
||||
|
||||
self.inter_arch = InteractionArch(
|
||||
num_sparse_features=num_sparse_features,
|
||||
)
|
||||
|
||||
over_in_features: int = (
|
||||
embedding_dim
|
||||
+ choose(num_sparse_features, 2)
|
||||
+ num_sparse_features
|
||||
)
|
||||
|
||||
self.over_arch = OverArch(
|
||||
in_features=over_in_features,
|
||||
layer_sizes=over_arch_layer_sizes,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, dense_features: torch.Tensor, *sparse_features
|
||||
) -> torch.Tensor:
|
||||
|
||||
embedded_dense = self.dense_arch(dense_features)
|
||||
embedded_sparse = self.sparse_arch(*sparse_features)
|
||||
concatenated_dense = self.inter_arch(
|
||||
dense_features=embedded_dense, sparse_features=embedded_sparse
|
||||
)
|
||||
logits = self.over_arch(concatenated_dense)
|
||||
return logits
|
||||
|
||||
|
||||
def test_dlrm() -> None:
|
||||
B = 2
|
||||
D = 8
|
||||
dense_in_features = 100
|
||||
|
||||
eb1_config = EmbeddingBagConfig(
|
||||
name="t1",
|
||||
embedding_dim=D,
|
||||
num_embeddings=100,
|
||||
feature_names=["f1", "f3"],
|
||||
)
|
||||
eb2_config = EmbeddingBagConfig(
|
||||
name="t2",
|
||||
embedding_dim=D,
|
||||
num_embeddings=100,
|
||||
feature_names=["f2"],
|
||||
)
|
||||
|
||||
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||
|
||||
sparse_features = KeyedJaggedTensor.from_offsets_sync(
|
||||
keys=["f1", "f3", "f2"],
|
||||
values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
|
||||
offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
|
||||
)
|
||||
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||
sparse_nn = DLRM(
|
||||
embedding_bag_collection=ebc,
|
||||
dense_in_features=dense_in_features,
|
||||
dense_arch_layer_sizes=[20, D],
|
||||
over_arch_layer_sizes=[5, 1],
|
||||
)
|
||||
sparse_nn_nod = DLRMShark(
|
||||
embedding_dim=8,
|
||||
total_features=3,
|
||||
num_embeddings_list=np.array([100, 100]),
|
||||
dense_in_features=dense_in_features,
|
||||
dense_arch_layer_sizes=[20, D],
|
||||
over_arch_layer_sizes=[5, 1],
|
||||
)
|
||||
|
||||
dense_features = torch.rand((B, dense_in_features))
|
||||
|
||||
x = to_list(sparse_features, {"f1": 0, "f3": 0, "f2": 1})
|
||||
|
||||
w1 = ebc.embedding_bags["t1"].weight
|
||||
w2 = ebc.embedding_bags["t2"].weight
|
||||
|
||||
sparse_nn_nod.sparse_arch.embedding_list[0].weight = w1
|
||||
sparse_nn_nod.sparse_arch.embedding_list[1].weight = w2
|
||||
|
||||
sparse_nn_nod.dense_arch.load_state_dict(sparse_nn.dense_arch.state_dict())
|
||||
sparse_nn_nod.inter_arch.load_state_dict(sparse_nn.inter_arch.state_dict())
|
||||
sparse_nn_nod.over_arch.load_state_dict(sparse_nn.over_arch.state_dict())
|
||||
|
||||
logits = sparse_nn(
|
||||
dense_features=dense_features,
|
||||
sparse_features=sparse_features,
|
||||
)
|
||||
logits_nod = sparse_nn_nod(dense_features, *x)
|
||||
|
||||
# print(logits)
|
||||
# print(logits_nod)
|
||||
|
||||
# Import the module and print.
|
||||
mlir_importer = SharkImporter(
|
||||
sparse_nn_nod,
|
||||
(dense_features, *x),
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
|
||||
torch.allclose(
|
||||
logits,
|
||||
logits_nod,
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
|
||||
test_dlrm()
|
||||
@@ -1,272 +0,0 @@
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
from tqdm.auto import tqdm
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import numpy as np
|
||||
|
||||
# pip install diffusers
|
||||
# pip install scipy
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a photograph of an astronaut riding a horse",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument("--steps", type=int, default=10, help="the device to use")
|
||||
p.add_argument("--mlir_loc", type=str, default=None, help="the device to use")
|
||||
p.add_argument("--vae_loc", type=str, default=None, help="the device to use")
|
||||
args = p.parse_args()
|
||||
|
||||
#####################################################
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc == None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None, extra_args=[]):
|
||||
|
||||
module = load_mlir(mlir_loc)
|
||||
if mlir_loc == None:
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(*inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
inputs,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile(extra_args)
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
# 1. Load the autoencoder model which will be used to decode the latents into image space.
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.vae.decode(input, return_dict=False)[0]
|
||||
|
||||
vae = VaeModel()
|
||||
vae_input = torch.rand(1, 4, 64, 64)
|
||||
shark_vae = compile_through_fx(vae, (vae_input,), args.vae_loc)
|
||||
|
||||
# Wrap the unet model to return tuples.
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
# 3. The UNet model for generating the latents.
|
||||
unet = UnetModel()
|
||||
latent_model_input = torch.rand([2, 4, 64, 64])
|
||||
text_embeddings = torch.rand([2, 77, 768])
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings),
|
||||
args.mlir_loc,
|
||||
["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
|
||||
)
|
||||
|
||||
# torch.jit.script(unet)
|
||||
|
||||
scheduler = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
prompt = [args.prompt]
|
||||
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
|
||||
num_inference_steps = args.steps # Number of denoising steps
|
||||
|
||||
guidance_scale = 7.5 # Scale for classifier-free guidance
|
||||
|
||||
generator = torch.manual_seed(
|
||||
42
|
||||
) # Seed generator to create the inital latent noise
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_embeddings = text_encoder(text_input.input_ids)[0]
|
||||
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
[""] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = text_encoder(uncond_input.input_ids)[0]
|
||||
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, unet.in_channels, height // 8, width // 8),
|
||||
generator=generator,
|
||||
)
|
||||
# latents = latents.to(torch_device)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
latents = latents * scheduler.sigmas[0]
|
||||
# print(latents, latents.shape)
|
||||
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
|
||||
print(f"i = {i} t = {t}")
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
sigma = scheduler.sigmas[i]
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
# with torch.no_grad():
|
||||
# noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
|
||||
|
||||
latent_model_input_numpy = latent_model_input.detach().numpy()
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
|
||||
noise_pred = shark_unet.forward(
|
||||
(
|
||||
latent_model_input_numpy,
|
||||
np.array([t]).astype(np.float32),
|
||||
text_embeddings_numpy,
|
||||
)
|
||||
)
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
|
||||
# print("Latents shape : ", latents.shape)
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents.detach().numpy()
|
||||
image = shark_vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
pil_images[0].save("astro.jpg")
|
||||
@@ -1,278 +0,0 @@
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
from tqdm.auto import tqdm
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import numpy as np
|
||||
|
||||
# pip install diffusers
|
||||
# pip install scipy
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a photograph of an astronaut riding a horse",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument("--steps", type=int, default=50, help="the device to use")
|
||||
p.add_argument("--mlir_loc", type=str, default=None, help="the device to use")
|
||||
p.add_argument("--vae_loc", type=str, default=None, help="the device to use")
|
||||
args = p.parse_args()
|
||||
|
||||
#####################################################
|
||||
|
||||
|
||||
def fp16_unet():
|
||||
from shark.shark_downloader import download_torch_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"stable_diff_f16_18_OCT", tank_url="gs://shark_tank/prashant_nod"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
return shark_module
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc is None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
|
||||
module = load_mlir(mlir_loc)
|
||||
if mlir_loc is None:
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(*inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
inputs,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
# 1. Load the autoencoder model which will be used to decode the latents into image space.
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.vae.decode(input, return_dict=False)[0]
|
||||
|
||||
vae = VaeModel()
|
||||
vae_input = torch.rand(1, 4, 64, 64)
|
||||
shark_vae = compile_through_fx(vae, (vae_input,), args.vae_loc)
|
||||
|
||||
# Wrap the unet model to return tuples.
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
# # 3. The UNet model for generating the latents.
|
||||
unet = UnetModel()
|
||||
|
||||
shark_unet = fp16_unet()
|
||||
|
||||
scheduler = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
prompt = [args.prompt]
|
||||
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
|
||||
num_inference_steps = args.steps # Number of denoising steps
|
||||
|
||||
guidance_scale = 7.5 # Scale for classifier-free guidance
|
||||
|
||||
generator = torch.manual_seed(
|
||||
42
|
||||
) # Seed generator to create the initial latent noise
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_embeddings = text_encoder(text_input.input_ids)[0]
|
||||
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
[""] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = text_encoder(uncond_input.input_ids)[0]
|
||||
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, unet.in_channels, height // 8, width // 8),
|
||||
generator=generator,
|
||||
)
|
||||
# latents = latents.to(torch_device)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
latents = latents * scheduler.sigmas[0]
|
||||
# print(latents, latents.shape)
|
||||
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
|
||||
print(f"i = {i} t = {t}")
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
sigma = scheduler.sigmas[i]
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
# with torch.no_grad():
|
||||
# noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
|
||||
|
||||
latent_model_input_numpy = (
|
||||
latent_model_input.detach().numpy().astype(np.half)
|
||||
)
|
||||
text_embeddings_numpy = (
|
||||
text_embeddings.detach().numpy().astype(np.half)
|
||||
)
|
||||
|
||||
noise_pred = shark_unet.forward(
|
||||
(
|
||||
latent_model_input_numpy,
|
||||
np.array([t]).astype(np.half),
|
||||
text_embeddings_numpy,
|
||||
)
|
||||
)
|
||||
noise_pred = torch.from_numpy(noise_pred).to(torch.float32)
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
|
||||
# print("Latents shape : ", latents.shape)
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents.detach().numpy()
|
||||
image = shark_vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
pil_images[0].save("astro.jpg")
|
||||
@@ -1,313 +0,0 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from keras_cv.models.generative.stable_diffusion.clip_tokenizer import (
|
||||
SimpleTokenizer,
|
||||
)
|
||||
from keras_cv.models.generative.stable_diffusion.constants import (
|
||||
_ALPHAS_CUMPROD,
|
||||
)
|
||||
from keras_cv.models.generative.stable_diffusion.constants import (
|
||||
_UNCONDITIONAL_TOKENS,
|
||||
)
|
||||
from keras_cv.models.generative.stable_diffusion.decoder import Decoder
|
||||
from keras_cv.models.generative.stable_diffusion.text_encoder import (
|
||||
TextEncoder,
|
||||
)
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
from PIL import Image
|
||||
|
||||
# pip install "git+https://github.com/keras-team/keras-cv.git"
|
||||
# pip install tensorflow_dataset
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a photograph of an astronaut riding a horse",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument(
|
||||
"--steps", type=int, default=10, help="the number of steps to use"
|
||||
)
|
||||
p.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="the file to save the resulting image to. (default to <input prompt>.jpg)",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
#####################################################
|
||||
|
||||
MAX_PROMPT_LENGTH = 77
|
||||
|
||||
|
||||
class SharkStableDiffusion:
|
||||
"""Shark implementation of Stable Diffusion based on model from keras_cv.
|
||||
Stable Diffusion is a powerful image generation model that can be used,
|
||||
among other things, to generate pictures according to a short text description
|
||||
(called a "prompt").
|
||||
Arguments:
|
||||
device: Device to use with SHARK. Default: cpu
|
||||
jit_compile: Whether to compile the underlying models to XLA.
|
||||
This can lead to a significant speedup on some systems. Default: False.
|
||||
References:
|
||||
- [About Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement)
|
||||
- [Original implementation](https://github.com/CompVis/stable-diffusion)
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", jit_compile=True):
|
||||
self.img_height = 512
|
||||
self.img_width = 512
|
||||
self.tokenizer = SimpleTokenizer()
|
||||
|
||||
# Create models
|
||||
self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_tf_model(
|
||||
"stable_diff", tank_url="gs://shark_tank/quinn"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
self.diffusion_model = shark_module
|
||||
self.decoder = Decoder(self.img_height, self.img_width)
|
||||
if jit_compile:
|
||||
self.text_encoder.compile(jit_compile=True)
|
||||
self.decoder.compile(jit_compile=True)
|
||||
|
||||
print(
|
||||
"By using this model checkpoint, you acknowledge that its usage is "
|
||||
"subject to the terms of the CreativeML Open RAIL-M license at "
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE"
|
||||
)
|
||||
# Load weights
|
||||
text_encoder_weights_fpath = keras.utils.get_file(
|
||||
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
|
||||
file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
|
||||
)
|
||||
decoder_weights_fpath = keras.utils.get_file(
|
||||
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
|
||||
file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
|
||||
)
|
||||
self.text_encoder.load_weights(text_encoder_weights_fpath)
|
||||
self.decoder.load_weights(decoder_weights_fpath)
|
||||
|
||||
def text_to_image(
|
||||
self,
|
||||
prompt,
|
||||
batch_size=1,
|
||||
num_steps=25,
|
||||
unconditional_guidance_scale=7.5,
|
||||
seed=None,
|
||||
):
|
||||
encoded_text = self.encode_text(prompt)
|
||||
|
||||
return self.generate_image(
|
||||
encoded_text,
|
||||
batch_size=batch_size,
|
||||
num_steps=num_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
def encode_text(self, prompt):
|
||||
"""Encodes a prompt into a latent text encoding.
|
||||
The encoding produced by this method should be used as the
|
||||
`encoded_text` parameter of `StableDiffusion.generate_image`. Encoding
|
||||
text separately from generating an image can be used to arbitrarily
|
||||
modify the text encoding prior to image generation, e.g. for walking
|
||||
between two prompts.
|
||||
Args:
|
||||
prompt: a string to encode, must be 77 tokens or shorter.
|
||||
Example:
|
||||
```python
|
||||
from keras_cv.models import StableDiffusion
|
||||
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
|
||||
encoded_text = model.encode_text("Tacos at dawn")
|
||||
img = model.generate_image(encoded_text)
|
||||
```
|
||||
"""
|
||||
# Tokenize prompt (i.e. starting context)
|
||||
inputs = self.tokenizer.encode(prompt)
|
||||
if len(inputs) > MAX_PROMPT_LENGTH:
|
||||
raise ValueError(
|
||||
f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)"
|
||||
)
|
||||
phrase = inputs + [49407] * (MAX_PROMPT_LENGTH - len(inputs))
|
||||
phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
|
||||
|
||||
context = self.text_encoder.predict_on_batch(
|
||||
[phrase, self._get_pos_ids()]
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
encoded_text,
|
||||
batch_size=1,
|
||||
num_steps=25,
|
||||
unconditional_guidance_scale=7.5,
|
||||
diffusion_noise=None,
|
||||
seed=None,
|
||||
):
|
||||
"""Generates an image based on encoded text.
|
||||
The encoding passed to this method should be derived from
|
||||
`StableDiffusion.encode_text`.
|
||||
Args:
|
||||
encoded_text: Tensor of shape (`batch_size`, 77, 768), or a Tensor
|
||||
of shape (77, 768). When the batch axis is omitted, the same encoded
|
||||
text will be used to produce every generated image.
|
||||
batch_size: number of images to generate. Default: 1.
|
||||
num_steps: number of diffusion steps (controls image quality).
|
||||
Default: 25.
|
||||
unconditional_guidance_scale: float controlling how closely the image
|
||||
should adhere to the prompt. Larger values result in more
|
||||
closely adhering to the prompt, but will make the image noisier.
|
||||
Default: 7.5.
|
||||
diffusion_noise: Tensor of shape (`batch_size`, img_height // 8,
|
||||
img_width // 8, 4), or a Tensor of shape (img_height // 8,
|
||||
img_width // 8, 4). Optional custom noise to seed the diffusion
|
||||
process. When the batch axis is omitted, the same noise will be
|
||||
used to seed diffusion for every generated image.
|
||||
seed: integer which is used to seed the random generation of
|
||||
diffusion noise, only to be specified if `diffusion_noise` is
|
||||
None.
|
||||
Example:
|
||||
```python
|
||||
from keras_cv.models import StableDiffusion
|
||||
batch_size = 8
|
||||
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
|
||||
e_tacos = model.encode_text("Tacos at dawn")
|
||||
e_watermelons = model.encode_text("Watermelons at dusk")
|
||||
e_interpolated = tf.linspace(e_tacos, e_watermelons, batch_size)
|
||||
images = model.generate_image(e_interpolated, batch_size=batch_size)
|
||||
```
|
||||
"""
|
||||
if diffusion_noise is not None and seed is not None:
|
||||
raise ValueError(
|
||||
"`diffusion_noise` and `seed` should not both be passed to "
|
||||
"`generate_image`. `seed` is only used to generate diffusion "
|
||||
"noise when it's not already user-specified."
|
||||
)
|
||||
|
||||
encoded_text = tf.squeeze(encoded_text)
|
||||
if encoded_text.shape.rank == 2:
|
||||
encoded_text = tf.repeat(
|
||||
tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
|
||||
)
|
||||
|
||||
context = encoded_text
|
||||
unconditional_context = tf.repeat(
|
||||
self._get_unconditional_context(), batch_size, axis=0
|
||||
)
|
||||
context = tf.concat([context, unconditional_context], 0)
|
||||
|
||||
if diffusion_noise is not None:
|
||||
diffusion_noise = tf.squeeze(diffusion_noise)
|
||||
if diffusion_noise.shape.rank == 3:
|
||||
diffusion_noise = tf.repeat(
|
||||
tf.expand_dims(diffusion_noise, axis=0), batch_size, axis=0
|
||||
)
|
||||
latent = diffusion_noise
|
||||
else:
|
||||
latent = self._get_initial_diffusion_noise(batch_size, seed)
|
||||
|
||||
# Iterative reverse diffusion stage
|
||||
timesteps = tf.range(1, 1000, 1000 // num_steps)
|
||||
alphas, alphas_prev = self._get_initial_alphas(timesteps)
|
||||
progbar = keras.utils.Progbar(len(timesteps))
|
||||
iteration = 0
|
||||
for index, timestep in list(enumerate(timesteps))[::-1]:
|
||||
latent_prev = latent # Set aside the previous latent vector
|
||||
t_emb = self._get_timestep_embedding(timestep, batch_size)
|
||||
|
||||
# Prepare the latent and unconditional latent to be run with a single forward call
|
||||
latent = tf.concat([latent, latent], 0)
|
||||
t_emb = tf.concat([t_emb, t_emb], 0)
|
||||
latent_numpy = self.diffusion_model.forward(
|
||||
[latent.numpy(), t_emb.numpy(), context.numpy()]
|
||||
)
|
||||
latent = tf.convert_to_tensor(latent_numpy, dtype=tf.float32)
|
||||
latent, unconditional_latent = tf.split(latent, 2)
|
||||
|
||||
latent = unconditional_latent + unconditional_guidance_scale * (
|
||||
latent - unconditional_latent
|
||||
)
|
||||
a_t, a_prev = alphas[index], alphas_prev[index]
|
||||
pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(
|
||||
a_t
|
||||
)
|
||||
latent = (
|
||||
latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
|
||||
)
|
||||
iteration += 1
|
||||
progbar.update(iteration)
|
||||
|
||||
# Decoding stage
|
||||
decoded = self.decoder.predict_on_batch(latent)
|
||||
decoded = ((decoded + 1) / 2) * 255
|
||||
return np.clip(decoded, 0, 255).astype("uint8")
|
||||
|
||||
def _get_unconditional_context(self):
|
||||
unconditional_tokens = tf.convert_to_tensor(
|
||||
[_UNCONDITIONAL_TOKENS], dtype=tf.int32
|
||||
)
|
||||
unconditional_context = self.text_encoder.predict_on_batch(
|
||||
[unconditional_tokens, self._get_pos_ids()]
|
||||
)
|
||||
|
||||
return unconditional_context
|
||||
|
||||
def _get_timestep_embedding(
|
||||
self, timestep, batch_size, dim=320, max_period=10000
|
||||
):
|
||||
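# Sinusoidal timestep embedding: half the channels get cos(t * freq) and the
# other half sin(t * freq), with frequencies spaced geometrically from 1 down
# to roughly 1/max_period (the usual DDPM / Transformer positional encoding).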
half = dim // 2
|
||||
freqs = tf.math.exp(
|
||||
-math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
|
||||
)
|
||||
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
|
||||
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
|
||||
embedding = tf.reshape(embedding, [1, -1])
|
||||
return tf.repeat(embedding, batch_size, axis=0)
|
||||
|
||||
def _get_initial_alphas(self, timesteps):
|
||||
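# Look up the cumulative alpha products for the chosen timesteps; each step
# also needs the value from the preceding timestep (1.0 for the first step).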
alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
|
||||
alphas_prev = [1.0] + alphas[:-1]
|
||||
|
||||
return alphas, alphas_prev
|
||||
|
||||
def _get_initial_diffusion_noise(self, batch_size, seed):
|
||||
return tf.random.normal(
|
||||
(batch_size, self.img_height // 8, self.img_width // 8, 4),
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_pos_ids():
|
||||
return tf.convert_to_tensor(
|
||||
[list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SD = SharkStableDiffusion(device=args.device)
|
||||
images = SD.text_to_image(args.prompt, num_steps=args.steps)
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
save_fname = args.prompt + ".jpg"
|
||||
if args.save_path is not None:
|
||||
save_fname = args.save_path
|
||||
pil_images[0].save(save_fname)
|
||||
@@ -1,2 +0,0 @@
|
||||
*.vmfb
|
||||
*.jpg
|
||||
@@ -1,15 +0,0 @@
|
||||
# STABLE DIFFUSION
|
||||
|
||||
## Installation
|
||||
|
||||
```shell
|
||||
pip install diffusers
|
||||
pip install scipy
|
||||
```
|
||||
|
||||
## RUN
|
||||
|
||||
```shell
|
||||
python main.py --precision="fp32"|"fp16" --prompt="enter the text" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir
|
||||
|
||||
```
|
||||
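For example, an fp16 run on Vulkan might look like the following (a sketch based on the usage line above; the exact flag names are defined in `stable_args.py`, which spells the prompt flag `--prompts`):

```shell
python main.py --precision="fp16" --prompts="a photograph of an astronaut riding a horse" --device="vulkan" --no-import_mlir
```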
@@ -1,25 +0,0 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
text=["a photo of a cat", "a photo of a dog"],
|
||||
images=image,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
|
||||
outputs = model(**inputs)
|
||||
logits_per_image = (
|
||||
outputs.logits_per_image
|
||||
) # this is the image-text similarity score
|
||||
probs = logits_per_image.softmax(
|
||||
dim=1
|
||||
) # we can take the softmax to get the label probabilities
|
||||
@@ -1,241 +0,0 @@
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from stable_args import args
|
||||
from model_wrappers import (
|
||||
get_vae32,
|
||||
get_vae16,
|
||||
get_unet16_wrapped,
|
||||
get_unet32_wrapped,
|
||||
get_clipped_text,
|
||||
)
|
||||
from utils import get_shark_model
|
||||
import sys  # needed for the batch-size checks below
import time
|
||||
|
||||
GCLOUD_BUCKET = "gs://shark_tank/prashant_nod"
|
||||
VAE_FP16 = "vae_fp16"
|
||||
VAE_FP32 = "vae_fp32"
|
||||
UNET_FP16 = "unet_fp16"
|
||||
UNET_FP32 = "unet_fp32"
|
||||
IREE_EXTRA_ARGS = []
|
||||
|
||||
TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
|
||||
UNET_FP16_TUNED = "unet_fp16_tunedv2"
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
|
||||
if BATCH_SIZE not in [1, 2]:
|
||||
import sys
|
||||
|
||||
sys.exit("Only batch size 1 and 2 are supported.")
|
||||
|
||||
if BATCH_SIZE > 1 and args.precision != "fp16":
|
||||
sys.exit("batch size > 1 is supported for fp16 model.")
|
||||
|
||||
|
||||
if BATCH_SIZE != 1:
|
||||
TUNED_GCLOUD_BUCKET = "gs://shark_tank/prashant_nod"
|
||||
UNET_FP16_TUNED = f"unet_fp16_{BATCH_SIZE}"
|
||||
VAE_FP16 = f"vae_fp16_{BATCH_SIZE}"
|
||||
|
||||
# Helper function to profile the vulkan device.
|
||||
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
|
||||
if args.vulkan_debug_utils and "vulkan" in args.device:
|
||||
import iree
|
||||
|
||||
print(f"Profiling and saving to {file_path}.")
|
||||
vulkan_device = iree.runtime.get_device(args.device)
|
||||
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
|
||||
return vulkan_device
|
||||
return None
|
||||
|
||||
|
||||
def end_profiling(device):
|
||||
if device:
|
||||
return device.end_profiling()
|
||||
|
||||
|
||||
def get_models():
|
||||
global IREE_EXTRA_ARGS
|
||||
if args.precision == "fp16":
|
||||
IREE_EXTRA_ARGS += [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
]
|
||||
if args.use_tuned:
|
||||
unet_gcloud_bucket = TUNED_GCLOUD_BUCKET
|
||||
vae_gcloud_bucket = GCLOUD_BUCKET
|
||||
unet_args = IREE_EXTRA_ARGS
|
||||
vae_args = IREE_EXTRA_ARGS + [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
|
||||
]
|
||||
unet_name = UNET_FP16_TUNED
|
||||
vae_name = VAE_FP16
|
||||
else:
|
||||
unet_gcloud_bucket = GCLOUD_BUCKET
|
||||
vae_gcloud_bucket = GCLOUD_BUCKET
|
||||
IREE_EXTRA_ARGS += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
|
||||
]
|
||||
unet_args = IREE_EXTRA_ARGS
|
||||
vae_args = IREE_EXTRA_ARGS
|
||||
unet_name = UNET_FP16
|
||||
vae_name = VAE_FP16
|
||||
|
||||
if BATCH_SIZE > 1:
|
||||
vae_args = []
|
||||
|
||||
if args.import_mlir:
|
||||
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
|
||||
model_name=UNET_FP16
|
||||
)
|
||||
else:
|
||||
return get_shark_model(
|
||||
vae_gcloud_bucket,
|
||||
vae_name,
|
||||
vae_args,
|
||||
), get_shark_model(
|
||||
unet_gcloud_bucket,
|
||||
unet_name,
|
||||
unet_args,
|
||||
)
|
||||
|
||||
elif args.precision == "fp32":
|
||||
IREE_EXTRA_ARGS += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_vae32(model_name=VAE_FP32), get_unet32_wrapped(
|
||||
model_name=UNET_FP32
|
||||
)
|
||||
else:
|
||||
return get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
VAE_FP32,
|
||||
IREE_EXTRA_ARGS,
|
||||
), get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
UNET_FP32,
|
||||
IREE_EXTRA_ARGS,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
IREE_EXTRA_ARGS.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
clip_model = "clip_text"
|
||||
clip_extra_args = [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
]
|
||||
clip = get_shark_model(GCLOUD_BUCKET, clip_model, clip_extra_args)
|
||||
|
||||
prompt = args.prompts
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
|
||||
num_inference_steps = args.steps # Number of denoising steps
|
||||
|
||||
guidance_scale = args.guidance_scale # Scale for classifier-free guidance
|
||||
|
||||
generator = torch.manual_seed(
|
||||
args.seed
|
||||
) # Seed generator to create the initial latent noise
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
vae, unet = get_models()
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
scheduler = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=args.max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_embeddings = clip.forward((text_input.input_ids,))
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
[""] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = clip.forward((uncond_input.input_ids,))
|
||||
uncond_embeddings = torch.from_numpy(uncond_embeddings).to(dtype)
|
||||
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, 4, height // 8, width // 8),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler.is_scale_input_called = True
|
||||
|
||||
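# The LMS scheduler expects the initial Gaussian noise to be scaled by the
# largest sigma before the denoising loop starts.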
latents = latents * scheduler.sigmas[0]
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
avg_ms = 0
|
||||
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
step_start = time.time()
|
||||
print(f"i = {i} t = {t}", end="")
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
latents_numpy = latents.detach().numpy()
|
||||
sigma_numpy = np.array(scheduler.sigmas[i]).astype(np.float32)
|
||||
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
noise_pred = unet.forward(
|
||||
(latents_numpy, timestep, text_embeddings_numpy, sigma_numpy)
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
step_time = time.time() - step_start
|
||||
avg_ms += step_time
|
||||
step_ms = int((step_time) * 1000)
|
||||
print(f" ({step_ms}ms)")
|
||||
|
||||
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
avg_ms = 1000 * avg_ms / args.steps
|
||||
print(f"Average step time: {avg_ms}ms/it")
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents.detach().numpy()
|
||||
profile_device = start_profiling(file_path="vae.rdc")
|
||||
image = vae.forward((latents_numpy,))
|
||||
end_profiling(profile_device)
|
||||
image = torch.from_numpy(image)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
|
||||
print("Total image generation runtime (s): {}".format(time.time() - start))
|
||||
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
for i in range(batch_size):
|
||||
pil_images[i].save(f"{args.prompts[i]}_{i}.jpg")
|
||||
@@ -1,223 +0,0 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
|
||||
from transformers import CLIPTextModel
|
||||
from utils import compile_through_fx
|
||||
from stable_args import args
|
||||
import torch
|
||||
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
|
||||
|
||||
def get_clipped_text(model_name="clip_text"):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
clip_input = torch.randint(1, 2, (BATCH_SIZE, 77))
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
(clip_input,),
|
||||
model_name=model_name,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
|
||||
def get_vae32(model_name="vae_fp32"):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
return (x / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
vae = VaeModel()
|
||||
vae_input = torch.rand(BATCH_SIZE, 4, 64, 64)
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
(vae_input,),
|
||||
model_name=model_name,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_vae16(model_name="vae_fp16"):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
revision="fp16",
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
return (x / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
vae = VaeModel()
|
||||
vae = vae.half().cuda()
|
||||
vae_input = torch.rand(BATCH_SIZE, 4, 64, 64, dtype=torch.half).cuda()
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
(vae_input,),
|
||||
model_name=model_name,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_unet32(model_name="unet_fp32"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
unet = UnetModel()
|
||||
latent_model_input = torch.rand([2, 4, 64, 64])
|
||||
text_embeddings = torch.rand([2, args.max_length, 768])
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings),
|
||||
model_name=model_name,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
|
||||
def get_unet16(model_name="unet_fp16"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
revision="fp16",
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
unet = UnetModel()
|
||||
unet = unet.half().cuda()
|
||||
latent_model_input = torch.rand([2, 4, 64, 64]).half().cuda()
|
||||
text_embeddings = torch.rand([2, args.max_length, 768]).half().cuda()
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
(
|
||||
latent_model_input,
|
||||
torch.tensor([1.0]).half().cuda(),
|
||||
text_embeddings,
|
||||
),
|
||||
model_name=model_name,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
|
||||
def get_unet16_wrapped(guidance_scale=7.5, model_name="unet_fp16_wrapped"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, guidance_scale=guidance_scale):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
revision="fp16",
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.guidance_scale = guidance_scale
|
||||
self.train(False)
|
||||
|
||||
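# This wrapper folds the sigma input-scaling and classifier-free guidance
# into the compiled UNet graph, so the Python loop only has to run the
# scheduler step on the returned noise prediction.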
def forward(self, latent, timestep, text_embedding, sigma):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latents = torch.cat([latent] * 2)
|
||||
latents = latents / (torch.pow((torch.pow(sigma, 2) + 1), 0.5))
|
||||
unet_out = self.unet.forward(
|
||||
latents, timestep, text_embedding, return_dict=False
|
||||
)[0]
|
||||
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel()
|
||||
unet = unet.half().cuda()
|
||||
latent_model_input = torch.rand([BATCH_SIZE, 4, 64, 64]).half().cuda()
|
||||
text_embeddings = (
|
||||
torch.rand([2 * BATCH_SIZE, args.max_length, 768]).half().cuda()
|
||||
)
|
||||
sigma = torch.tensor(1).to(torch.float32)
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
(
|
||||
latent_model_input,
|
||||
torch.tensor([1.0]).half().cuda(),
|
||||
text_embeddings,
|
||||
sigma,
|
||||
),
|
||||
model_name=model_name,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
|
||||
def get_unet32_wrapped(guidance_scale=7.5, model_name="unet_fp32_wrapped"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, guidance_scale=guidance_scale):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.guidance_scale = guidance_scale
|
||||
self.train(False)
|
||||
|
||||
def forward(self, latent, timestep, text_embedding, sigma):
|
||||
latents = torch.cat([latent] * 2)
|
||||
latents = latents / (torch.pow((torch.pow(sigma, 2) + 1), 0.5))
|
||||
unet_out = self.unet.forward(
|
||||
latents, timestep, text_embedding, return_dict=False
|
||||
)[0]
|
||||
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel()
|
||||
latent_model_input = torch.rand([BATCH_SIZE, 4, 64, 64])
|
||||
text_embeddings = torch.rand([2 * BATCH_SIZE, args.max_length, 768])
|
||||
sigma = torch.tensor(1).to(torch.float32)
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings, sigma),
|
||||
model_name=model_name,
|
||||
)
|
||||
return shark_unet
|
||||
@@ -1,88 +0,0 @@
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=["a photograph of an astronaut riding a horse"],
|
||||
help="text of which images to be generated.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--device", type=str, default="cpu", help="device to run the model."
|
||||
)
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=10,
|
||||
help="the no. of steps to do the sampling.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed to use.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="the value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp32", help="precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=77,
|
||||
help="max length of the tokenizer output.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="saves the compiled flatbuffer to the local directory",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree-vulkan-target-triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available",
|
||||
)
|
||||
|
||||
args = p.parse_args()
|
||||
@@ -1,103 +0,0 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from stable_args import args
|
||||
from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
|
||||
|
||||
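# Compile (or load) a SharkInference module, optionally caching the compiled
# flatbuffer on disk as "<model_name>_<device>.vmfb": with --load_vmfb an
# existing file in the working directory is reused; otherwise the module is
# compiled, saved there, and then loaded.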
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
extended_name = "{}_{}".format(model_name, args.device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print("Loading flatbuffer from {}".format(vmfb_path))
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
if args.save_vmfb:
|
||||
print("Saving to {}".format(vmfb_path))
|
||||
else:
|
||||
print(
|
||||
"No vmfb found. Compiling and saving to {}".format(
|
||||
vmfb_path
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path)
|
||||
else:
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_torch_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
model_name, tank_url=tank_url
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into shark_module.
|
||||
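# Pipeline: trace the model with make_fx (applying the listed decompositions),
# strip op overloads, TorchScript-trace the resulting graph, import it to
# linalg MLIR via SharkImporter, and hand the module to SharkInference.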
def compile_through_fx(model, inputs, model_name, extra_args=[]):
|
||||
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(*inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.trace(fx_g, inputs)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
ts_g,
|
||||
inputs,
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
@@ -11,12 +11,12 @@ t5_inputs = [
|
||||
tf.TensorSpec(shape=[1, 10], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class T5Module(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(T5Module, self).__init__()
|
||||
self.m = TFT5Model.from_pretrained("t5-small")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, decoder_input_ids=y)
|
||||
self.m.predict = lambda x,y: self.m(input_ids=x, decoder_input_ids=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs)
|
||||
def forward(self, input_ids, decoder_input_ids):
|
||||
@@ -27,9 +27,12 @@ if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
text = "I love the distilled version of models."
|
||||
inputs = tokenizer(text, return_tensors="tf").input_ids
|
||||
inputs = tokenizer(
|
||||
text, return_tensors="tf"
|
||||
).input_ids
|
||||
|
||||
shark_module = SharkInference(T5Module(), (inputs, inputs))
|
||||
shark_module = SharkInference(
|
||||
T5Module(), (inputs, inputs))
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward((inputs, inputs)))
|
||||
print(shark_module.forward((inputs,inputs)))
|
||||
|
||||
@@ -4,6 +4,7 @@ from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark_runner import SharkInference
|
||||
|
||||
|
||||
# Currently not supported aten.transpose_conv2d missing.
|
||||
class UnetModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = torch.hub.load(
|
||||
@@ -15,7 +15,7 @@ class UnetModule(torch.nn.Module):
|
||||
init_features=32,
|
||||
pretrained=True,
|
||||
)
|
||||
self.model.eval()
|
||||
self.train(False)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model(input)
|
||||
@@ -23,17 +23,10 @@ class UnetModule(torch.nn.Module):
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
print(input)
|
||||
shark_module = SharkInference(
|
||||
UnetModule(),
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=False
|
||||
)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
shark_module.benchmark_forward((input,))
|
||||
print(input)
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
print("The obtained result via shark is: ", result)
|
||||
print("The golden result is:", golden_out)
|
||||
@@ -5,13 +5,17 @@ from shark.shark_runner import SharkTrainer
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
num_labels=
|
||||
2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=
|
||||
False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=
|
||||
False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
@@ -33,9 +37,8 @@ inp = (torch.randint(2, (1, 128)),)
|
||||
|
||||
def forward(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
_stateless.functional_call(
|
||||
mod, params_and_buffers, args, {}
|
||||
).sum().backward()
|
||||
_stateless.functional_call(mod, params_and_buffers, args,
|
||||
{}).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
# optim.load_state_dict(optim_state)
|
||||
optim.step()
|
||||
|
||||
@@ -5,14 +5,13 @@ import tensorflow as tf
|
||||
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
from shark.parser import parser
|
||||
from urllib import request
|
||||
from shark.shark_importer import shark_load
|
||||
|
||||
parser.add_argument(
|
||||
"--download_mlir_path",
|
||||
type=str,
|
||||
default="bert_tf_training.mlir",
|
||||
help="Specifies path to target mlir file that will be loaded.",
|
||||
)
|
||||
help="Specifies path to target mlir file that will be loaded.")
|
||||
load_args, unknown = parser.parse_known_args()
|
||||
|
||||
tf.random.set_seed(0)
|
||||
@@ -26,30 +25,16 @@ if __name__ == "__main__":
|
||||
predict_sample_input = [
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
]
|
||||
file_link = "https://storage.googleapis.com/shark_tank/users/stanley/bert_tf_training.mlir"
|
||||
response = request.urlretrieve(file_link, load_args.download_mlir_path)
|
||||
sample_input_tensors = [
|
||||
tf.convert_to_tensor(val, dtype=tf.int32)
|
||||
for val in predict_sample_input
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
|
||||
]
|
||||
model_name = "bert_tf_training"
|
||||
bert_mlir = shark_load(model_name, load_args.download_mlir_path)
|
||||
sample_input_tensors = [tf.convert_to_tensor(val, dtype=tf.int32) for val in predict_sample_input]
|
||||
num_iter = 10
|
||||
if not os.path.isfile(load_args.download_mlir_path):
|
||||
raise ValueError(
|
||||
f"Tried looking for target mlir in {load_args.download_mlir_path}, but cannot be found."
|
||||
)
|
||||
with open(load_args.download_mlir_path, "rb") as input_file:
|
||||
bert_mlir = input_file.read()
|
||||
shark_module = SharkTrainer(
|
||||
bert_mlir,
|
||||
(
|
||||
sample_input_tensors,
|
||||
tf.convert_to_tensor(
|
||||
np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32
|
||||
),
|
||||
),
|
||||
)
|
||||
(sample_input_tensors,
|
||||
tf.convert_to_tensor(np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32)))
|
||||
shark_module.set_frontend("mhlo")
|
||||
shark_module.compile()
|
||||
start = time.time()
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import sys
|
||||
from absl import app
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import tempfile
|
||||
import tensorflow as tf
|
||||
|
||||
from official.nlp.modeling import layers
|
||||
@@ -25,35 +28,31 @@ bert_input = [
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
dict_outputs = False
|
||||
test_network = networks.BertEncoder(
|
||||
vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs
|
||||
)
|
||||
test_network = networks.BertEncoder(vocab_size=vocab_size,
|
||||
num_layers=2,
|
||||
dict_outputs=dict_outputs)
|
||||
|
||||
# Create a BERT trainer with the created network.
|
||||
bert_trainer_model = bert_classifier.BertClassifier(
|
||||
test_network, num_classes=NUM_CLASSES
|
||||
)
|
||||
test_network, num_classes=NUM_CLASSES)
|
||||
bert_trainer_model.summary()
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m = bert_trainer_model
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
self.predict = tf.function(input_signature=[bert_input])(
|
||||
self.m.predict
|
||||
)
|
||||
self.predict = tf.function(input_signature=[bert_input])(self.m.predict)
|
||||
self.m.learn = lambda x, y: self.m.call(x, training=False)
|
||||
self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
)
|
||||
@tf.function(input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32) # labels
|
||||
])
|
||||
def forward(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
# Capture the gradients from forward prop...
|
||||
@@ -71,22 +70,14 @@ if __name__ == "__main__":
|
||||
predict_sample_input = [
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
]
|
||||
sample_input_tensors = [
|
||||
tf.convert_to_tensor(val, dtype=tf.int32)
|
||||
for val in predict_sample_input
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
|
||||
]
|
||||
sample_input_tensors = [tf.convert_to_tensor(val, dtype=tf.int32) for val in predict_sample_input]
|
||||
num_iter = 10
|
||||
shark_module = SharkTrainer(
|
||||
BertModule(),
|
||||
(
|
||||
sample_input_tensors,
|
||||
tf.convert_to_tensor(
|
||||
np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32
|
||||
),
|
||||
),
|
||||
)
|
||||
(sample_input_tensors,
|
||||
tf.convert_to_tensor(np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32)))
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
start = time.time()
|
||||
|
||||
@@ -4,6 +4,7 @@ from shark.shark_trainer import SharkTrainer
|
||||
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Foo, self).__init__()
|
||||
self.l1 = torch.nn.Linear(10, 16)
|
||||
@@ -27,9 +28,8 @@ def get_sorted_params(named_params):
|
||||
|
||||
def forward(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
_stateless.functional_call(
|
||||
mod, params_and_buffers, args, {}
|
||||
).sum().backward()
|
||||
_stateless.functional_call(mod, params_and_buffers, args,
|
||||
{}).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
optim.step()
|
||||
return params, buffers
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
# Stable Diffusion Img2Img model
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary>Installation (Linux)</summary>
|
||||
|
||||
### Activate shark.venv Virtual Environment
|
||||
|
||||
```shell
|
||||
source shark.venv/bin/activate
|
||||
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
python -m pip install --upgrade pip
|
||||
```
|
||||
|
||||
### Install dependencies
|
||||
|
||||
# Run the setup.sh script
|
||||
|
||||
```shell
|
||||
./setup.sh
|
||||
```
|
||||
|
||||
### Run the Stable diffusion Img2Img model
|
||||
|
||||
To run the model with the default set of images and params, run:
|
||||
```shell
|
||||
python stable_diffusion_img2img.py
|
||||
```
|
||||
To run the model with your own set of images and parameters, you need to specify the following params:
|
||||
1.) Input images directory with the arg `--input_dir` containing 3-5 images.
|
||||
2.) What to teach the model, using the arg `--what_to_teach`; allowed values are `object` and `style`.
|
||||
3.) Placeholder token using the arg `--placeholder_token`, which represents your new concept. It should be passed with the opening and closing angle brackets; for example, the token `cat-toy` should be passed as `<cat-toy>`.
|
||||
4.) Initializer token using the arg `--initializer_token`, which summarises your new concept.
|
||||
|
||||
For the result, you need to pass the text prompt with the arg `--prompt`. The prompt string should contain a "*s", which will be replaced by the placeholder token during inference.
|
||||
|
||||
By default the result images will go into the `sd_result` dir. To specify your output dir use the arg: `--output_dir`.
|
||||
|
||||
The default value of max_training_steps is `3000`, which takes several hours to complete. You can pass a smaller value with the arg `--training_steps`. Specify the number of images to be sampled for the result with the `--num_inference_samples` arg.
|
||||
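For example, a run that teaches the model a new object from a handful of images might look like this (the values are illustrative; the defaults are listed in the script's argument parser):

```shell
python stable_diffusion_img2img.py \
  --input_dir=input_images/ \
  --what_to_teach=object \
  --placeholder_token="<cat-toy>" \
  --initializer_token=toy \
  --training_steps=1000 \
  --num_inference_samples=4 \
  --prompt="a photo of a *s on the beach" \
  --output_dir=sd_result
```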
@@ -1,25 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
TD="$(cd $(dirname $0) && pwd)"
|
||||
if [ -z "$PYTHON" ]; then
|
||||
PYTHON="$(which python3)"
|
||||
fi
|
||||
|
||||
function die() {
|
||||
echo "Error executing command: $*"
|
||||
exit 1
|
||||
}
|
||||
|
||||
PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; print("{0}.{1}".format(*version))'`
|
||||
|
||||
echo "Python: $PYTHON"
|
||||
echo "Python version: $PYTHON_VERSION_X_Y"
|
||||
|
||||
mkdir input_images
|
||||
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg -P input_images/
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg -P input_images/
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg -P input_images/
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg -P input_images/
|
||||
|
||||
pip install diffusers["training"]==0.4.1 transformers ftfy opencv-python
|
||||
@@ -1,597 +0,0 @@
|
||||
# Textual-inversion fine-tuning for Stable Diffusion using diffusers
|
||||
# This script shows how to "teach" Stable Diffusion a new concept via
|
||||
# textual-inversion using 🤗 Hugging Face [🧨 Diffusers library](https://github.com/huggingface/diffusers).
|
||||
# By using just 3-5 images you can teach new concepts to Stable Diffusion
|
||||
# and personalize the model on your own images.
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.hub_utils import init_git_repo, push_to_hub
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
p.add_argument(
|
||||
"--input_dir",
|
||||
type=str,
|
||||
default="input_images/",
|
||||
help="the directory contains the images used for fine tuning",
|
||||
)
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="sd_result",
|
||||
help="the directory contains the images used for fine tuning",
|
||||
)
|
||||
p.add_argument(
|
||||
"--training_steps",
|
||||
type=int,
|
||||
default=3000,
|
||||
help="the maximum number of training steps",
|
||||
)
|
||||
p.add_argument("--seed", type=int, default=42, help="the random seed")
|
||||
p.add_argument(
|
||||
"--what_to_teach",
|
||||
type=str,
|
||||
choices=["object", "style"],
|
||||
default="object",
|
||||
help="what is it that you are teaching?",
|
||||
)
|
||||
p.add_argument(
|
||||
"--placeholder_token",
|
||||
type=str,
|
||||
default="<cat-toy>",
|
||||
help="It is the token you are going to use to represent your new concept",
|
||||
)
|
||||
p.add_argument(
|
||||
"--initializer_token",
|
||||
type=str,
|
||||
default="toy",
|
||||
help="It is a word that can summarise what is your new concept",
|
||||
)
|
||||
p.add_argument(
|
||||
"--inference_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="the number of steps for inference",
|
||||
)
|
||||
p.add_argument(
|
||||
"--num_inference_samples",
|
||||
type=int,
|
||||
default=4,
|
||||
help="the number of samples for inference",
|
||||
)
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a grafitti in a wall with a *s on it",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
if "*s" not in args.prompt:
|
||||
raise ValueError(
|
||||
f'The prompt should have a "*s" which will be replaced by a placeholder token.'
|
||||
)
|
||||
|
||||
prompt1, prompt2 = args.prompt.split("*s")
|
||||
args.prompt = prompt1 + args.placeholder_token + prompt2
|
||||
|
||||
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
# Load input images.
|
||||
images = []
|
||||
for filename in os.listdir(args.input_dir):
|
||||
img = cv2.imread(os.path.join(args.input_dir, filename))
|
||||
if img is not None:
|
||||
images.append(img)
|
||||
|
||||
# Setup the prompt templates for training
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
# Setup the dataset
|
||||
class TextualInversionDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
learnable_property="object", # [object, style]
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5,
|
||||
set="train",
|
||||
placeholder_token="*",
|
||||
center_crop=False,
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.learnable_property = learnable_property
|
||||
self.size = size
|
||||
self.placeholder_token = placeholder_token
|
||||
self.center_crop = center_crop
|
||||
self.flip_p = flip_p
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
|
||||
self.templates = (
|
||||
imagenet_style_templates_small
|
||||
if learnable_property == "style"
|
||||
else imagenet_templates_small
|
||||
)
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
placeholder_string = self.placeholder_token
|
||||
text = random.choice(self.templates).format(placeholder_string)
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2,
|
||||
(w - crop) // 2 : (w + crop) // 2,
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize(
|
||||
(self.size, self.size), resample=self.interpolation
|
||||
)
|
||||
|
||||
image = self.flip_transform(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
# Setting up the model
|
||||
# Load the tokenizer and add the placeholder token as an additional special token.
|
||||
# Please read and if you agree accept the LICENSE
|
||||
# [here](https://huggingface.co/CompVis/stable-diffusion-v1-4) if you see an error
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Get token ids for our placeholder and initializer token.
|
||||
# This code block will complain if initializer string is not a single token
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# We have added the `placeholder_token` in the `tokenizer` so we resize the token embeddings here,
|
||||
# this will create a new embedding vector in the token embeddings for our `placeholder_token`
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
# In Textual-Inversion we only train the newly added embedding vector,
|
||||
# so let's freeze the rest of the model parameters here.
|
||||
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
# Freeze vae and unet
|
||||
freeze_params(vae.parameters())
|
||||
freeze_params(unet.parameters())
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
params_to_freeze = itertools.chain(
|
||||
text_encoder.text_model.encoder.parameters(),
|
||||
text_encoder.text_model.final_layer_norm.parameters(),
|
||||
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
||||
)
|
||||
freeze_params(params_to_freeze)
|
||||
|
||||
# Creating our training data
|
||||
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=args.input_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=512,
|
||||
placeholder_token=args.placeholder_token,
|
||||
repeats=100,
|
||||
learnable_property=args.what_to_teach, # Option selected above between object and style
|
||||
center_crop=False,
|
||||
set="train",
|
||||
)
|
||||
|
||||
|
||||
def create_dataloader(train_batch_size=1):
|
||||
return torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=train_batch_size, shuffle=True
|
||||
)
|
||||
|
||||
|
||||
# Create noise_scheduler for training.
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
tensor_format="pt",
|
||||
)
|
||||
|
||||
# Define hyperparameters for our training
|
||||
hyperparameters = {
|
||||
"learning_rate": 5e-04,
|
||||
"scale_lr": True,
|
||||
"max_train_steps": args.training_steps,
|
||||
"train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"seed": args.seed,
|
||||
"output_dir": "sd-concept-output",
|
||||
}
|
||||
|
||||
|
||||
def training_function(text_encoder, vae, unet):
|
||||
logger = get_logger(__name__)
|
||||
|
||||
train_batch_size = hyperparameters["train_batch_size"]
|
||||
gradient_accumulation_steps = hyperparameters[
|
||||
"gradient_accumulation_steps"
|
||||
]
|
||||
learning_rate = hyperparameters["learning_rate"]
|
||||
max_train_steps = hyperparameters["max_train_steps"]
|
||||
output_dir = hyperparameters["output_dir"]
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
train_dataloader = create_dataloader(train_batch_size)
|
||||
|
||||
if hyperparameters["scale_lr"]:
|
||||
learning_rate = (
|
||||
learning_rate
|
||||
* gradient_accumulation_steps
|
||||
* train_batch_size
|
||||
* accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=learning_rate,
|
||||
)
|
||||
|
||||
text_encoder, optimizer, train_dataloader = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader
|
||||
)
|
||||
|
||||
# Move vae and unet to device
|
||||
vae.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
|
||||
# Keep vae and unet in eval mode as we don't train these
|
||||
vae.eval()
|
||||
unet.eval()
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Train!
|
||||
total_batch_size = (
|
||||
train_batch_size
|
||||
* accelerator.num_processes
|
||||
* gradient_accumulation_steps
|
||||
)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
|
||||
)
|
||||
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(
|
||||
range(max_train_steps), disable=not accelerator.is_local_main_process
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = (
|
||||
vae.encode(batch["pixel_values"])
|
||||
.latent_dist.sample()
|
||||
.detach()
|
||||
)
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(latents.shape).to(latents.device)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(
|
||||
latents, noise, timesteps
|
||||
)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(
|
||||
noisy_latents, timesteps, encoder_hidden_states
|
||||
).sample
|
||||
|
||||
loss = (
|
||||
F.mse_loss(noise_pred, noise, reduction="none")
|
||||
.mean([1, 2, 3])
|
||||
.mean()
|
||||
)
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Zero out the gradients for all token embeddings except the newly added
|
||||
# embeddings for the concept, as we only want to optimize the concept embeddings
|
||||
if accelerator.num_processes > 1:
|
||||
grads = (
|
||||
text_encoder.module.get_input_embeddings().weight.grad
|
||||
)
|
||||
else:
|
||||
grads = text_encoder.get_input_embeddings().weight.grad
|
||||
# Get the index for tokens that we want to zero the grads for
|
||||
index_grads_to_zero = (
|
||||
torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
)
|
||||
grads.data[index_grads_to_zero, :] = grads.data[
|
||||
index_grads_to_zero, :
|
||||
].fill_(0)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {"loss": loss.detach().item()}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
skip_prk_steps=True,
|
||||
),
|
||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker"
|
||||
),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained(
|
||||
"openai/clip-vit-base-patch32"
|
||||
),
|
||||
)
|
||||
pipeline.save_pretrained(output_dir)
|
||||
# Also save the newly trained embeddings
|
||||
learned_embeds = (
|
||||
accelerator.unwrap_model(text_encoder)
|
||||
.get_input_embeddings()
|
||||
.weight[placeholder_token_id]
|
||||
)
|
||||
learned_embeds_dict = {
|
||||
args.placeholder_token: learned_embeds.detach().cpu()
|
||||
}
|
||||
torch.save(
|
||||
learned_embeds_dict, os.path.join(output_dir, "learned_embeds.bin")
|
||||
)
|
||||
|
||||
|
||||
import accelerate
|
||||
|
||||
accelerate.notebook_launcher(
|
||||
training_function, args=(text_encoder, vae, unet), num_processes=1
|
||||
)
|
||||
|
||||
# Set up the pipeline
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
hyperparameters["output_dir"],
|
||||
# torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
all_images = []
|
||||
for _ in range(args.num_inference_samples):
|
||||
images = pipe(
|
||||
[args.prompt],
|
||||
num_inference_steps=args.inference_steps,
|
||||
guidance_scale=7.5,
|
||||
).images
|
||||
all_images.extend(images)
|
||||
|
||||
# output_path = os.path.abspath(os.path.join(os.getcwd(), args.output_dir))
|
||||
if not os.path.isdir(args.output_dir):
|
||||
os.mkdir(args.output_dir)
|
||||
|
||||
[
|
||||
image.save(f"{args.output_dir}/{i}.jpeg")
|
||||
for i, image in enumerate(all_images)
|
||||
]
|
||||
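The training function above writes the trained vector to `learned_embeds.bin` inside the output directory. A minimal, hedged sketch of loading that file back into a fresh tokenizer and text encoder for later inference follows; the paths and variable names here are assumptions, not part of the script.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

learned_embeds_path = "sd-concept-output/learned_embeds.bin"  # assumed output location
model_id = "CompVis/stable-diffusion-v1-4"

tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# The saved dict maps the placeholder token to its trained embedding vector.
loaded = torch.load(learned_embeds_path)
placeholder_token, embedding = next(iter(loaded.items()))

# Register the token and copy the trained vector into the embedding table.
tokenizer.add_tokens(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))
token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
text_encoder.get_input_embeddings().weight.data[token_id] = embedding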
@@ -28,14 +28,9 @@ from torch_mlir.eager_mode.torch_mlir_eager_backend import (
|
||||
TorchMLIREagerBackend,
|
||||
TensorMetaData,
|
||||
)
|
||||
from torch_mlir_e2e_test.eager_backends.refbackend import (
|
||||
NUMPY_TO_TORCH_DTYPE_DICT,
|
||||
)
|
||||
from torch_mlir_e2e_test.eager_backends.refbackend import NUMPY_TO_TORCH_DTYPE_DICT
|
||||
|
||||
from shark.iree_utils.compile_utils import (
|
||||
get_iree_compiled_module,
|
||||
IREE_DEVICE_MAP,
|
||||
)
|
||||
from shark.iree_utils import get_iree_compiled_module, IREE_DEVICE_MAP
|
||||
|
||||
|
||||
class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
|
||||
@@ -48,19 +43,18 @@ class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
|
||||
|
||||
def __init__(self, device: str):
|
||||
self.torch_device_str = device
|
||||
self.config = ireert.Config(IREE_DEVICE_MAP[device])
|
||||
self.raw_device_str = device
|
||||
self.iree_device_str = IREE_DEVICE_MAP[device]
|
||||
self.config = ireert.Config(self.iree_device_str)
|
||||
|
||||
def get_torch_metadata(
|
||||
self, tensor: DeviceArray, kwargs: Dict[str, Any]
|
||||
) -> TensorMetaData:
|
||||
def get_torch_metadata(self, tensor: DeviceArray,
|
||||
kwargs: Dict[str, Any]) -> TensorMetaData:
|
||||
return TensorMetaData(
|
||||
size=tensor.shape,
|
||||
dtype=NUMPY_TO_TORCH_DTYPE_DICT[tensor.dtype.type],
|
||||
device=torch.device(self.torch_device_str),
|
||||
requires_grad=tensor.dtype.type
|
||||
in {np.float, np.float32, np.float64}
|
||||
and kwargs.get("requires_grad", False),
|
||||
in {np.float, np.float32, np.float64} and
|
||||
kwargs.get("requires_grad", False),
|
||||
)
|
||||
|
||||
def compile(self, imported_module: Module):
|
||||
@@ -70,9 +64,9 @@ class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
|
||||
"torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||
"EagerMode",
|
||||
)
|
||||
callable, _ = get_iree_compiled_module(
|
||||
imported_module, self.raw_device_str, func_name=fn_name
|
||||
)
|
||||
callable, _ = get_iree_compiled_module(imported_module,
|
||||
self.iree_device_str,
|
||||
func_name=fn_name)
|
||||
return callable
|
||||
|
||||
def copy_into(self, dst, src):
|
||||
@@ -82,7 +76,6 @@ class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
|
||||
def transfer_from_device_to_torch(self, e):
|
||||
return torch.from_numpy(e.to_host())
|
||||
|
||||
def transfer_from_torch_to_device(
|
||||
self, tensor: torch.Tensor
|
||||
) -> DeviceArray:
|
||||
def transfer_from_torch_to_device(self,
|
||||
tensor: torch.Tensor) -> DeviceArray:
|
||||
return iree.runtime.asdevicearray(self.config.device, tensor.numpy())
|
||||
|
||||
359
shark/iree_utils.py
Normal file
359
shark/iree_utils.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import iree.runtime as ireert
|
||||
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
|
||||
import iree.compiler as ireec
|
||||
from shark.torch_mlir_utils import get_module_name_for_asm_dump
|
||||
from shark.cuda_utils import get_cuda_sm_cc
|
||||
from shark.model_annotation import *
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"gpu": "cuda",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "vulkan",
|
||||
"rocm": "rocm"
|
||||
}
|
||||
|
||||
IREE_TARGET_MAP = {
|
||||
"cpu": "dylib",
|
||||
"gpu": "cuda",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "vulkan",
|
||||
"rocm": "rocm"
|
||||
}
|
||||
|
||||
UNIT_TO_SECOND_MAP = {"ms": 0.001, "s": 1}
|
||||
|
||||
|
||||
def check_device_drivers(device):
|
||||
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||
if (device in ["gpu", "cuda"]):
|
||||
try:
|
||||
subprocess.check_output('nvidia-smi')
|
||||
except Exception:
|
||||
return True
|
||||
elif (device in ["metal", "vulkan"]):
|
||||
try:
|
||||
subprocess.check_output('vulkaninfo')
|
||||
except Exception:
|
||||
return True
|
||||
elif (device == "cpu"):
|
||||
return False
|
||||
# Unknown device.
|
||||
else:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_iree_cpu_args():
|
||||
find_triple_cmd = "uname -s -m"
|
||||
os_name, proc_name = subprocess.run(
|
||||
find_triple_cmd, shell=True, stdout=subprocess.PIPE,
|
||||
check=True).stdout.decode('utf-8').split()
|
||||
if os_name == "Darwin":
|
||||
find_kernel_version_cmd = "uname -r"
|
||||
kernel_version = subprocess.run(find_kernel_version_cmd,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
check=True).stdout.decode('utf-8')
|
||||
target_triple = f"{proc_name}-apple-darwin{kernel_version}"
|
||||
elif os_name == "Linux":
|
||||
target_triple = f"{proc_name}-linux-gnu"
|
||||
else:
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
print(f"Target triple found:{target_triple}")
|
||||
return [f"-iree-llvm-target-triple={target_triple}"]
|
||||
|
||||
|
||||
def get_iree_gpu_args():
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
ireert.flags.parse_flags("--cuda_allow_inline_execution")
|
||||
sm_arch = get_cuda_sm_cc()
|
||||
if sm_arch in ['sm_70', 'sm_72', 'sm_75', 'sm_80', 'sm_84', 'sm_86']:
|
||||
return [
|
||||
"--iree-hal-cuda-disable-loop-nounroll-wa",
|
||||
f"--iree-hal-cuda-llvm-target-arch={sm_arch}"
|
||||
]
|
||||
else:
|
||||
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
|
||||
|
||||
|
||||
def get_vulkan_triple_flag():
|
||||
vulkan_device_cmd = "vulkaninfo | grep deviceName | awk \'END{{print $NF}}\'"
|
||||
vulkan_device = run_cmd(vulkan_device_cmd).strip()
|
||||
if vulkan_device == "M1":
|
||||
print("Found Apple Device. Using m1-moltenvk-macos")
|
||||
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
elif vulkan_device == "A100-SXM4-40GB":
|
||||
print("Found Nvidia Device. Using ampere-rtx3080-linux")
|
||||
return "-iree-vulkan-target-triple=ampere-rtx3080-linux"
|
||||
else:
|
||||
print(
|
||||
"Optimized kernel for your target device is not added yet. Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u] or pull up an issue."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_iree_vulkan_args():
|
||||
#vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
vulkan_flag = []
|
||||
vulkan_triple_flag = get_vulkan_triple_flag()
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_flag.append(vulkan_triple_flag)
|
||||
return vulkan_flag
|
||||
|
||||
|
||||
def get_iree_device_args(device):
|
||||
if device == "cpu":
|
||||
return get_iree_cpu_args()
|
||||
if device in ["gpu", "cuda"]:
|
||||
return get_iree_gpu_args()
|
||||
if device in ["metal", "vulkan"]:
|
||||
return get_iree_vulkan_args()
|
||||
return []
|
||||
|
||||
|
||||
def get_iree_frontend_args(frontend):
|
||||
if frontend in ["torch", "pytorch", "linalg"]:
|
||||
return ["--iree-llvm-target-cpu-features=host"]
|
||||
elif frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
return [
|
||||
"--iree-llvm-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-flow-demote-i64-to-i32"
|
||||
]
|
||||
else:
|
||||
# Frontend not found.
|
||||
return []
|
||||
|
||||
|
||||
def compile_module_to_flatbuffer(module, device, frontend, func_name,
|
||||
model_config_path):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device)
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "mhlo"
|
||||
elif frontend in ["mhlo", "tosa"]:
|
||||
input_type = frontend
|
||||
elif frontend in ["tflite"]:
|
||||
input_type = "tosa"
|
||||
|
||||
# Annotate the input module with the configs
|
||||
if model_config_path != None:
|
||||
# Currently tuned model only works on tf frontend
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_module = module.decode('utf-8')
|
||||
elif frontend in ["pytorch", "torch"]:
|
||||
input_module = module.operation.get_asm()
|
||||
with create_context() as ctx:
|
||||
module = model_annotation(ctx,
|
||||
input_contents=input_module,
|
||||
config_path=model_config_path)
|
||||
module = str(module)
|
||||
|
||||
# Compile according to the input type, else just try compiling.
|
||||
if input_type not in ["mhlo", "tosa"]:
|
||||
module = str(module)
|
||||
if input_type != "":
|
||||
# Currently for MHLO/TOSA.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[IREE_TARGET_MAP[device]],
|
||||
extra_args=args,
|
||||
input_type=input_type)
|
||||
else:
|
||||
# Currently for Torch.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
str(module),
|
||||
target_backends=[IREE_TARGET_MAP[device]],
|
||||
extra_args=args)
|
||||
return flatbuffer_blob
|
||||
|
||||
|
||||
def get_iree_module(flatbuffer_blob, device, func_name):
|
||||
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
|
||||
config = ireert.Config(IREE_DEVICE_MAP[device])
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
ModuleCompiled = ctx.modules.module[func_name]
|
||||
return ModuleCompiled, config
|
||||
|
||||
|
||||
def get_iree_compiled_module(module,
|
||||
device: str,
|
||||
frontend: str = "torch",
|
||||
func_name: str = "forward",
|
||||
model_config_path: str = None):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(module, device, frontend,
|
||||
func_name, model_config_path)
|
||||
return get_iree_module(flatbuffer_blob, device, func_name)
|
||||
|
||||
|
||||
def export_iree_module_to_vmfb(module,
|
||||
device: str,
|
||||
directory: str,
|
||||
frontend: str = "torch",
|
||||
func_name: str = "forward",
|
||||
model_config_path: str = None):
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(module, device, frontend,
|
||||
func_name, model_config_path)
|
||||
module_name = f"{frontend}_{func_name}_{device}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(flatbuffer_blob)
|
||||
return filename
|
||||
|
||||
|
||||
def export_module_to_mlir_file(module, frontend, directory: str):
|
||||
mlir_str = module
|
||||
if frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
mlir_str = module.decode('utf-8')
|
||||
elif frontend in ["pytorch", "torch"]:
|
||||
mlir_str = module.operation.get_asm()
|
||||
filename = os.path.join(directory, "model.mlir")
|
||||
with open(filename, 'w') as f:
|
||||
f.write(mlir_str)
|
||||
print(f"Saved mlir in {filename}.")
|
||||
return filename
|
||||
|
||||
|
||||
def get_results(compiled_vm, input, config, frontend="torch"):
|
||||
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||
device_inputs = input
|
||||
if frontend in ["torch", "pytorch"]:
|
||||
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
||||
if frontend in ["tensorflow", "tf", "tflite"]:
|
||||
device_inputs = []
|
||||
for a in input:
|
||||
if (isinstance(a, list)):
|
||||
device_inputs.append([
|
||||
ireert.asdevicearray(config.device, val, dtype=np.int32)
|
||||
for val in a
|
||||
])
|
||||
else:
|
||||
device_inputs.append(ireert.asdevicearray(config.device, a))
|
||||
result = compiled_vm(*device_inputs)
|
||||
result_tensors = []
|
||||
if (isinstance(result, tuple)):
|
||||
for val in result:
|
||||
result_tensors.append(np.copy(np.asarray(val, val.dtype)))
|
||||
return result_tensors
|
||||
elif (isinstance(result, dict)):
|
||||
data = list(result.items())
|
||||
res = np.array(data, dtype=object)
|
||||
return np.copy(res)
|
||||
else:
|
||||
return np.copy(np.asarray(result, dtype=result.dtype))
|
||||
|
||||
|
||||
######### Benchmark Related Tools ###########
|
||||
|
||||
|
||||
def tensor_to_type_str(input_tensors: tuple, frontend: str):
|
||||
"""
|
||||
Input: A tuple of input tensors i.e tuple(torch.tensor)
|
||||
Output: list of strings that represent MLIR types (e.g. 1x24xf64)
|
||||
# TODO: Support more than floats, and ints
|
||||
"""
|
||||
list_of_type = []
|
||||
for input_tensor in input_tensors:
|
||||
type_string = "x".join([str(dim) for dim in input_tensor.shape])
|
||||
if frontend in ["torch", "pytorch"]:
|
||||
dtype_string = str(input_tensor.dtype).replace("torch.", "")
|
||||
elif frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
dtype = input_tensor.dtype
|
||||
dtype_string = re.findall('\'[^"]*\'',
|
||||
str(dtype))[0].replace("\'", "")
|
||||
regex_split = re.compile("([a-zA-Z]+)([0-9]+)")
|
||||
match = regex_split.match(dtype_string)
|
||||
mlir_type_string = str(match.group(1)[0]) + str(match.group(2))
|
||||
type_string += f"x{mlir_type_string}"
|
||||
list_of_type.append(type_string)
|
||||
return list_of_type
|
||||
|
||||
|
||||
def build_benchmark_args(input_file: str,
|
||||
device: str,
|
||||
input_tensors: tuple,
|
||||
frontend: str,
|
||||
training=False):
|
||||
"""
|
||||
Inputs: input_file leading to vmfb, input_tensor to function, target device, and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
benchmarker_path = os.path.join(path, "..", "..", "iree-benchmark-module")
|
||||
benchmark_cl = [benchmarker_path, f"--module_file={input_file}"]
|
||||
fn_name = "forward"
|
||||
if training == True:
|
||||
# TODO: Replace name of train with actual train fn name.
|
||||
fn_name = "train"
|
||||
benchmark_cl.append(f"--entry_function={fn_name}")
|
||||
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
|
||||
mlir_input_types = tensor_to_type_str(input_tensors, frontend)
|
||||
for mlir_input in mlir_input_types:
|
||||
benchmark_cl.append(f"--function_input={mlir_input}")
|
||||
time_extractor = "| awk \'END{{print $2 $3}}\'"
|
||||
benchmark_cl.append(time_extractor)
|
||||
return benchmark_cl
|
||||
|
||||
|
||||
def run_cmd(cmd):
|
||||
"""
|
||||
Inputs: cli command string.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(cmd,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
check=True)
|
||||
result_str = result.stdout.decode()
|
||||
return result_str
|
||||
except Exception:
|
||||
sys.exit("Exiting program due to error running:", cmd)
|
||||
|
||||
|
||||
def run_benchmark_module(benchmark_cl):
|
||||
"""
|
||||
Run benchmark command, extract result and return iteration/seconds.
|
||||
|
||||
Input: benchmark command.
|
||||
"""
|
||||
benchmark_path = benchmark_cl[0]
|
||||
assert os.path.exists(
|
||||
benchmark_path
|
||||
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
|
||||
bench_result = run_cmd(' '.join(benchmark_cl))
|
||||
regex_split = re.compile("([0-9]+[.]*[0-9]*)([a-zA-Z]+)")
|
||||
match = regex_split.match(bench_result)
|
||||
time = float(match.group(1))
|
||||
unit = match.group(2)
|
||||
return 1.0 / (time * UNIT_TO_SECOND_MAP[unit])
|
||||
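Taken together, the helpers in this new file compile an imported MLIR module for a target device and execute it. A hedged end-to-end sketch on CPU follows; the toy model and the `torch_mlir.compile` call are assumptions about the torch-mlir version in use, not part of this file.

import torch
import torch_mlir
from shark.iree_utils import get_iree_compiled_module, get_results

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

example_input = torch.randn(4)
# Lower the toy module to linalg-on-tensors so the "torch" frontend path applies.
mlir_module = torch_mlir.compile(AddOne(), example_input, output_type="linalg-on-tensors")

compiled_module, config = get_iree_compiled_module(mlir_module, device="cpu", func_name="forward")

# get_results converts the numpy inputs to IREE device arrays for torch frontends.
outputs = get_results(compiled_module, (example_input.numpy(),), config, frontend="torch")
print(outputs)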
@@ -1,100 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
## Common utilities to be shared by iree utilities.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
|
||||
def run_cmd(cmd):
|
||||
"""
|
||||
Inputs: cli command string.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
check=True,
|
||||
)
|
||||
result_str = result.stdout.decode()
|
||||
return result_str
|
||||
except Exception:
|
||||
sys.exit("Exiting program due to error running:", cmd)
|
||||
|
||||
|
||||
IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "vulkan",
|
||||
"rocm": "rocm",
|
||||
"intel-gpu": "level_zero",
|
||||
}
|
||||
|
||||
IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "vulkan",
|
||||
"rocm": "rocm",
|
||||
"intel-gpu": "opencl-spirv",
|
||||
}
|
||||
|
||||
# Finds whether the required drivers are installed for the given device.
|
||||
def check_device_drivers(device):
|
||||
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||
if device == "cuda":
|
||||
try:
|
||||
subprocess.check_output("nvidia-smi")
|
||||
except Exception:
|
||||
return True
|
||||
elif device in ["metal", "vulkan"]:
|
||||
try:
|
||||
subprocess.check_output("vulkaninfo")
|
||||
except Exception:
|
||||
return True
|
||||
elif device in ["intel-gpu"]:
|
||||
try:
|
||||
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
|
||||
return False
|
||||
except Exception:
|
||||
return True
|
||||
elif device == "cpu":
|
||||
return False
|
||||
elif device == "rocm":
|
||||
try:
|
||||
subprocess.check_output("rocminfo")
|
||||
except Exception:
|
||||
return True
|
||||
# Unknown device.
|
||||
else:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Installation info for the missing device drivers.
|
||||
def device_driver_info(device):
|
||||
if device == "cuda":
|
||||
return "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
|
||||
elif device in ["metal", "vulkan"]:
|
||||
return "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
|
||||
elif device == "rocm":
|
||||
return "rocm info not found. Please install rocm"
|
||||
else:
|
||||
return f"{device} is not supported."
|
||||
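A hedged sketch of how these two helpers are intended to be used together as a guard before compiling for a device; the early exit is illustrative.

import sys
from shark.iree_utils._common import check_device_drivers, device_driver_info

device = "cuda"

# check_device_drivers returns True when the required driver is missing.
if check_device_drivers(device):
    sys.exit(device_driver_info(device))
print(f"Drivers for {device} found, continuing.")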
@@ -1,122 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
|
||||
from shark.iree_utils._common import run_cmd, IREE_DEVICE_MAP
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
|
||||
UNIT_TO_SECOND_MAP = {"ms": 0.001, "s": 1}
|
||||
|
||||
|
||||
def tensor_to_type_str(input_tensors: tuple, mlir_dialect: str):
|
||||
"""
|
||||
Input: A tuple of input tensors i.e tuple(torch.tensor)
|
||||
Output: list of strings that represent MLIR types (e.g. 1x24xf64)
|
||||
# TODO: Support more than floats, and ints
|
||||
"""
|
||||
list_of_type = []
|
||||
for input_tensor in input_tensors:
|
||||
type_string = "x".join([str(dim) for dim in input_tensor.shape])
|
||||
if mlir_dialect in ["linalg", "tosa"]:
|
||||
dtype_string = str(input_tensor.dtype).replace("torch.", "")
|
||||
elif mlir_dialect in ["mhlo", "tflite"]:
|
||||
dtype = input_tensor.dtype
|
||||
try:
|
||||
dtype_string = re.findall("'[^\"]*'", str(dtype))[0].replace(
|
||||
"'", ""
|
||||
)
|
||||
except IndexError:
|
||||
dtype_string = str(dtype)
|
||||
regex_split = re.compile("([a-zA-Z]+)([0-9]+)")
|
||||
match = regex_split.match(dtype_string)
|
||||
mlir_type_string = str(match.group(1)[0]) + str(match.group(2))
|
||||
type_string += f"x{mlir_type_string}"
|
||||
list_of_type.append(type_string)
|
||||
return list_of_type
|
||||
|
||||
|
||||
def build_benchmark_args(
|
||||
input_file: str,
|
||||
device: str,
|
||||
input_tensors: tuple,
|
||||
mlir_dialect: str,
|
||||
training=False,
|
||||
):
|
||||
"""
|
||||
Inputs: input_file leading to vmfb, input_tensor to function, target device,
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
benchmarker_path = os.path.join(path, "..", "..", "iree-benchmark-module")
|
||||
benchmark_cl = [benchmarker_path, f"--module_file={input_file}"]
|
||||
# TODO: The function name can be passed as one of the args.
|
||||
fn_name = "forward"
|
||||
if training == True:
|
||||
# TODO: Replace name of train with actual train fn name.
|
||||
fn_name = "train"
|
||||
benchmark_cl.append(f"--entry_function={fn_name}")
|
||||
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
|
||||
mlir_input_types = tensor_to_type_str(input_tensors, mlir_dialect)
|
||||
for mlir_input in mlir_input_types:
|
||||
benchmark_cl.append(f"--function_input={mlir_input}")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl.append(time_extractor)
|
||||
return benchmark_cl
|
||||
|
||||
|
||||
def build_benchmark_args_non_tensor_input(
|
||||
input_file: str,
|
||||
device: str,
|
||||
inputs: tuple,
|
||||
mlir_dialect: str,
|
||||
function_name: str,
|
||||
):
|
||||
"""
|
||||
Inputs: input_file leading to vmfb, input_tensor to function, target device,
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
benchmarker_path = os.path.join(path, "..", "..", "iree-benchmark-module")
|
||||
benchmark_cl = [benchmarker_path, f"--module_file={input_file}"]
|
||||
# TODO: The function name can be passed as one of the args.
|
||||
benchmark_cl.append(f"--entry_function={function_name}")
|
||||
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
|
||||
for input in inputs:
|
||||
benchmark_cl.append(f"--function_input={input}")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl.append(time_extractor)
|
||||
return benchmark_cl
|
||||
|
||||
|
||||
def run_benchmark_module(benchmark_cl):
|
||||
"""
|
||||
Run benchmark command, extract result and return iteration/seconds.
|
||||
|
||||
# TODO: Add an example of the benchmark command.
|
||||
Input: benchmark command.
|
||||
"""
|
||||
benchmark_path = benchmark_cl[0]
|
||||
assert os.path.exists(
|
||||
benchmark_path
|
||||
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
|
||||
bench_result = run_cmd(" ".join(benchmark_cl))
|
||||
regex_split = re.compile("([0-9]+[.]*[0-9]*)([a-zA-Z]+)")
|
||||
match = regex_split.match(bench_result)
|
||||
time = float(match.group(1))
|
||||
unit = match.group(2)
|
||||
return 1.0 / (time * UNIT_TO_SECOND_MAP[unit])
|
||||
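For reference, a hedged sketch of what `tensor_to_type_str` and `build_benchmark_args` produce for a single float32 image tensor; the vmfb path is a placeholder.

import torch
from shark.iree_utils.benchmark_utils import build_benchmark_args, tensor_to_type_str

inputs = (torch.randn(1, 3, 224, 224),)

# Shapes and dtypes are flattened into MLIR-style type strings.
print(tensor_to_type_str(inputs, mlir_dialect="linalg"))  # ['1x3x224x224xf32']

# The command points iree-benchmark-module at a previously exported vmfb.
cmd = build_benchmark_args(
    input_file="model.vmfb",  # placeholder path
    device="cpu",
    input_tensors=inputs,
    mlir_dialect="linalg",
)
print(" ".join(cmd))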
@@ -1,314 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
|
||||
from shark.iree_utils.benchmark_utils import *
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
if device == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
return get_iree_cpu_args()
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
return get_iree_gpu_args()
|
||||
if device in ["metal", "vulkan"]:
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(extra_args=extra_args)
|
||||
if device == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args()
|
||||
return []
|
||||
|
||||
|
||||
# Get the iree-compiler arguments given frontend.
|
||||
def get_iree_frontend_args(frontend):
|
||||
if frontend in ["torch", "pytorch", "linalg"]:
|
||||
return ["--iree-llvm-target-cpu-features=host"]
|
||||
elif frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
return [
|
||||
"--iree-llvm-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
]
|
||||
else:
|
||||
# Frontend not found.
|
||||
return []
|
||||
|
||||
|
||||
# Common args to be used given any frontend or device.
|
||||
def get_iree_common_args():
|
||||
return [
|
||||
"--iree-stream-resource-index-bits=64",
|
||||
"--iree-vm-target-index-bits=64",
|
||||
"--iree-util-zero-fill-elided-attrs",
|
||||
]
|
||||
|
||||
|
||||
def create_dispatch_dirs(bench_dir, device):
|
||||
bench_dir_path = bench_dir.split("/")
|
||||
bench_dir_path[-1] = "temp_" + bench_dir_path[-1]
|
||||
tmp_bench_dir = "/".join(bench_dir_path)
|
||||
for f_ in os.listdir(bench_dir):
|
||||
if os.path.isfile(f"{bench_dir}/{f_}"):
|
||||
dir_name = re.sub("\.\S*$", "", f_)
|
||||
if os.path.exists(f"{bench_dir}/{dir_name}"):
|
||||
os.system(f"rm -rf {bench_dir}/{dir_name}")
|
||||
os.system(f"mkdir {bench_dir}/{dir_name}")
|
||||
os.system(f"mv {bench_dir}/{f_} {bench_dir}/{dir_name}/{f_}")
|
||||
for f_ in os.listdir(tmp_bench_dir):
|
||||
if os.path.isfile(f"{tmp_bench_dir}/{f_}"):
|
||||
dir_name = ""
|
||||
for d_ in os.listdir(bench_dir):
|
||||
if re.search(f"{d_}(?=\D)", f_):
|
||||
dir_name = d_
|
||||
if dir_name != "":
|
||||
os.system(
|
||||
f"mv {tmp_bench_dir}/{f_} {bench_dir}/{dir_name}/{dir_name}_benchmark.mlir"
|
||||
)
|
||||
|
||||
|
||||
def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
|
||||
dispatch_list = []
|
||||
all_dispatches = False
|
||||
|
||||
if dispatch_benchmarks.lower().strip() == "all":
|
||||
all_dispatches = True
|
||||
else:
|
||||
try:
|
||||
dispatch_list = [
|
||||
int(dispatch_index)
|
||||
for dispatch_index in dispatch_benchmarks.split(" ")
|
||||
]
|
||||
except:
|
||||
print("ERROR: Invalid dispatch benchmarks")
|
||||
return None
|
||||
for d_ in os.listdir(bench_dir):
|
||||
in_dispatches = False
|
||||
for dispatch in dispatch_list:
|
||||
if str(dispatch) in d_:
|
||||
in_dispatches = True
|
||||
if all_dispatches or in_dispatches:
|
||||
for f_ in os.listdir(f"{bench_dir}/{d_}"):
|
||||
|
||||
if "benchmark.mlir" in f_:
|
||||
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
|
||||
module = dispatch_file.read()
|
||||
dispatch_file.close()
|
||||
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module, target_backends=[IREE_TARGET_MAP[device]]
|
||||
)
|
||||
|
||||
vmfb_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_benchmark.vmfb", "wb"
|
||||
)
|
||||
vmfb_file.write(flatbuffer_blob)
|
||||
vmfb_file.close()
|
||||
|
||||
config = ireert.Config(IREE_DEVICE_MAP[device])
|
||||
vm_module = ireert.VmModule.from_flatbuffer(
|
||||
config.vm_instance, flatbuffer_blob
|
||||
)
|
||||
|
||||
benchmark_cl = build_benchmark_args_non_tensor_input(
|
||||
input_file=f"{bench_dir}/{d_}/{d_}_benchmark.vmfb",
|
||||
device=device,
|
||||
inputs=(0,),
|
||||
mlir_dialect="linalg",
|
||||
function_name=vm_module.function_names[0],
|
||||
)
|
||||
|
||||
benchmark_bash = open(
|
||||
f"{bench_dir}/{d_}/{d_}_benchmark.sh", "w+"
|
||||
)
|
||||
benchmark_bash.write("#!/bin/bash\n")
|
||||
benchmark_bash.write(" ".join(benchmark_cl))
|
||||
benchmark_bash.close()
|
||||
|
||||
benchmark_data = run_benchmark_module(benchmark_cl)
|
||||
|
||||
benchmark_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
|
||||
)
|
||||
benchmark_file.write(f"DISPATCH: {d_}\n")
|
||||
benchmark_file.write(str(benchmark_data) + "\n")
|
||||
benchmark_file.write(
|
||||
"SHARK BENCHMARK RESULT: "
|
||||
+ str(1 / (benchmark_data * 0.001))
|
||||
+ "\n"
|
||||
)
|
||||
benchmark_file.close()
|
||||
|
||||
elif ".mlir" in f_ and "benchmark" not in f_:
|
||||
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
|
||||
module = dispatch_file.read()
|
||||
dispatch_file.close()
|
||||
|
||||
module = re.sub(
|
||||
"hal.executable private",
|
||||
"hal.executable public",
|
||||
module,
|
||||
)
|
||||
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[IREE_TARGET_MAP[device]],
|
||||
extra_args=["--compile-mode=hal-executable"],
|
||||
)
|
||||
|
||||
spirv_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_spirv.vmfb", "wb"
|
||||
)
|
||||
spirv_file.write(flatbuffer_blob)
|
||||
spirv_file.close()
|
||||
|
||||
|
||||
def compile_module_to_flatbuffer(
|
||||
module, device, frontend, func_name, model_config_path, extra_args
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args()
|
||||
args += extra_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "mhlo"
|
||||
elif frontend in ["mhlo", "tosa"]:
|
||||
input_type = frontend
|
||||
elif frontend in ["tflite", "tflite-tosa"]:
|
||||
input_type = "tosa"
|
||||
elif frontend in ["tm_tensor"]:
|
||||
input_type = ireec.InputType.TM_TENSOR
|
||||
|
||||
# TODO: make it simpler.
|
||||
# Compile according to the input type, else just try compiling.
|
||||
if input_type != "":
|
||||
# Currently for MHLO/TOSA.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[IREE_TARGET_MAP[device]],
|
||||
extra_args=args,
|
||||
input_type=input_type,
|
||||
)
|
||||
else:
|
||||
# Currently for Torch.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[IREE_TARGET_MAP[device]],
|
||||
extra_args=args,
|
||||
)
|
||||
|
||||
return flatbuffer_blob
|
||||
|
||||
|
||||
def get_iree_module(flatbuffer_blob, device, func_name):
|
||||
# Returns the compiled module and the configs.
|
||||
config = ireert.Config(IREE_DEVICE_MAP[device])
|
||||
vm_module = ireert.VmModule.from_flatbuffer(
|
||||
config.vm_instance, flatbuffer_blob
|
||||
)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
ModuleCompiled = ctx.modules.module[func_name]
|
||||
return ModuleCompiled, config
|
||||
|
||||
|
||||
def get_iree_compiled_module(
|
||||
module,
|
||||
device: str,
|
||||
frontend: str = "torch",
|
||||
func_name: str = "forward",
|
||||
model_config_path: str = None,
|
||||
extra_args: list = [],
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, func_name, model_config_path, extra_args
|
||||
)
|
||||
return get_iree_module(flatbuffer_blob, device, func_name)
|
||||
|
||||
|
||||
def load_flatbuffer(
|
||||
flatbuffer_path: str, device: str, func_name: str = "forward"
|
||||
):
|
||||
|
||||
with open(os.path.join(flatbuffer_path), "rb") as f:
|
||||
flatbuffer_blob = f.read()
|
||||
|
||||
return get_iree_module(flatbuffer_blob, device, func_name)
|
||||
|
||||
|
||||
def export_iree_module_to_vmfb(
|
||||
module,
|
||||
device: str,
|
||||
directory: str,
|
||||
mlir_dialect: str = "linalg",
|
||||
func_name: str = "forward",
|
||||
model_config_path: str = None,
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, mlir_dialect, func_name, model_config_path, extra_args
|
||||
)
|
||||
if module_name is None:
|
||||
module_name = f"{mlir_dialect}_{func_name}_{device}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(flatbuffer_blob)
|
||||
return filename
|
||||
|
||||
|
||||
def export_module_to_mlir_file(module, frontend, directory: str):
|
||||
# TODO: write proper documentation.
|
||||
mlir_str = module
|
||||
if frontend in ["tensorflow", "tf", "mhlo", "tflite"]:
|
||||
mlir_str = module.decode("utf-8")
|
||||
elif frontend in ["pytorch", "torch"]:
|
||||
mlir_str = module.operation.get_asm()
|
||||
filename = os.path.join(directory, "model.mlir")
|
||||
with open(filename, "w") as f:
|
||||
f.write(mlir_str)
|
||||
print(f"Saved mlir in {filename}.")
|
||||
return filename
|
||||
|
||||
|
||||
def get_results(compiled_vm, input, config, frontend="torch"):
|
||||
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
||||
result = compiled_vm(*device_inputs)
|
||||
result_tensors = []
|
||||
if isinstance(result, tuple):
|
||||
for val in result:
|
||||
result_tensors.append(np.copy(np.asarray(val, val.dtype)))
|
||||
return result_tensors
|
||||
elif isinstance(result, dict):
|
||||
data = list(result.items())
|
||||
res = np.array(data, dtype=object)
|
||||
return np.copy(res)
|
||||
else:
|
||||
return np.copy(np.asarray(result, dtype=result.dtype))
|
||||
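A hedged sketch of the save-and-reload flow these helpers support; `mlir_module` is assumed to be an already-imported linalg-on-tensors module (see the earlier compile sketch) and the directory is a placeholder that must already exist.

from shark.iree_utils.compile_utils import export_iree_module_to_vmfb, load_flatbuffer

# Compile once and persist the flatbuffer next to the model.
vmfb_path = export_iree_module_to_vmfb(
    mlir_module,           # assumed linalg-on-tensors module
    device="cpu",
    directory="./vmfbs",   # placeholder directory
    mlir_dialect="linalg",
)

# Later runs can skip compilation and load the saved artifact directly.
compiled_module, config = load_flatbuffer(vmfb_path, device="cpu")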
@@ -1,44 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_cpu related functionalities go here.
|
||||
|
||||
import subprocess
|
||||
|
||||
# Get the default cpu args.
|
||||
def get_iree_cpu_args():
|
||||
find_triple_cmd = "uname -s -m"
|
||||
os_name, proc_name = (
|
||||
subprocess.run(
|
||||
find_triple_cmd, shell=True, stdout=subprocess.PIPE, check=True
|
||||
)
|
||||
.stdout.decode("utf-8")
|
||||
.split()
|
||||
)
|
||||
if os_name == "Darwin":
|
||||
find_kernel_version_cmd = "uname -r"
|
||||
kernel_version = subprocess.run(
|
||||
find_kernel_version_cmd,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
check=True,
|
||||
).stdout.decode("utf-8")
|
||||
target_triple = f"{proc_name}-apple-darwin{kernel_version}"
|
||||
elif os_name == "Linux":
|
||||
target_triple = f"{proc_name}-linux-gnu"
|
||||
else:
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
print(f"Target triple found:{target_triple}")
|
||||
return [f"-iree-llvm-target-triple={target_triple}"]
|
||||
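On a typical x86_64 Linux host the helper above resolves to a single target-triple flag; a short, hedged usage sketch:

from shark.iree_utils.cpu_utils import get_iree_cpu_args

# Expected to print something like ['-iree-llvm-target-triple=x86_64-linux-gnu'] on Linux,
# or an Apple triple that includes the kernel version on macOS.
print(get_iree_cpu_args())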
@@ -1,123 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_gpu related functionalities go here.
|
||||
|
||||
import iree.runtime as ireert
|
||||
import ctypes
|
||||
from shark.parser import shark_args
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
def get_iree_gpu_args():
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
ireert.flags.parse_flags("--cuda_allow_inline_execution")
|
||||
# TODO: Give the user_interface to pass the sm_arch.
|
||||
sm_arch = get_cuda_sm_cc()
|
||||
if (
|
||||
sm_arch in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86"]
|
||||
) and (shark_args.enable_tf32 == True):
|
||||
return [
|
||||
"--iree-hal-cuda-disable-loop-nounroll-wa",
|
||||
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
|
||||
]
|
||||
else:
|
||||
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
def get_iree_rocm_args():
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
# TODO: find a way to get arch from code.
|
||||
rocm_arch = "gfx908"
|
||||
return [
|
||||
f"--iree-rocm-target-chip={rocm_arch}",
|
||||
"--iree-rocm-link-bc=true",
|
||||
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
|
||||
]
|
||||
|
||||
|
||||
# Some constants taken from cuda.h
|
||||
CUDA_SUCCESS = 0
|
||||
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16
|
||||
CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39
|
||||
CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13
|
||||
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
|
||||
|
||||
|
||||
def get_cuda_sm_cc():
|
||||
libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
|
||||
for libname in libnames:
|
||||
try:
|
||||
cuda = ctypes.CDLL(libname)
|
||||
except OSError:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise OSError("could not load any of: " + " ".join(libnames))
|
||||
|
||||
nGpus = ctypes.c_int()
|
||||
name = b" " * 100
|
||||
cc_major = ctypes.c_int()
|
||||
cc_minor = ctypes.c_int()
|
||||
|
||||
result = ctypes.c_int()
|
||||
device = ctypes.c_int()
|
||||
context = ctypes.c_void_p()
|
||||
error_str = ctypes.c_char_p()
|
||||
|
||||
result = cuda.cuInit(0)
|
||||
if result != CUDA_SUCCESS:
|
||||
cuda.cuGetErrorString(result, ctypes.byref(error_str))
|
||||
print(
|
||||
"cuInit failed with error code %d: %s"
|
||||
% (result, error_str.value.decode())
|
||||
)
|
||||
return 1
|
||||
result = cuda.cuDeviceGetCount(ctypes.byref(nGpus))
|
||||
if result != CUDA_SUCCESS:
|
||||
cuda.cuGetErrorString(result, ctypes.byref(error_str))
|
||||
print(
|
||||
"cuDeviceGetCount failed with error code %d: %s"
|
||||
% (result, error_str.value.decode())
|
||||
)
|
||||
return 1
|
||||
print("Found %d device(s)." % nGpus.value)
|
||||
for i in range(nGpus.value):
|
||||
result = cuda.cuDeviceGet(ctypes.byref(device), i)
|
||||
if result != CUDA_SUCCESS:
|
||||
cuda.cuGetErrorString(result, ctypes.byref(error_str))
|
||||
print(
|
||||
"cuDeviceGet failed with error code %d: %s"
|
||||
% (result, error_str.value.decode())
|
||||
)
|
||||
return 1
|
||||
print("Device: %d" % i)
|
||||
if (
|
||||
cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name), device)
|
||||
== CUDA_SUCCESS
|
||||
):
|
||||
print(" Name: %s" % (name.split(b"\0", 1)[0].decode()))
|
||||
if (
|
||||
cuda.cuDeviceComputeCapability(
|
||||
ctypes.byref(cc_major), ctypes.byref(cc_minor), device
|
||||
)
|
||||
== CUDA_SUCCESS
|
||||
):
|
||||
print(
|
||||
" Compute Capability: %d.%d"
|
||||
% (cc_major.value, cc_minor.value)
|
||||
)
|
||||
sm = f"sm_{cc_major.value}{cc_minor.value}"
|
||||
return sm
|
||||
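# Illustrative usage sketch (an assumption, not from the original diff): with a
# CUDA driver installed, get_cuda_sm_cc() above reports the detected compute
# capability and get_iree_gpu_args() turns it into compiler flags; the
# target-arch flag is only added for TF32-capable parts when --enable_tf32 is set.
if __name__ == "__main__":
    print(get_cuda_sm_cc())     # e.g. 'sm_80' on an A100
    print(get_iree_gpu_args())  # e.g. ['--iree-hal-cuda-disable-loop-nounroll-wa', ...]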
@@ -1,70 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_vulkan related functionalities go here.
|
||||
|
||||
from os import linesep
|
||||
from shark.iree_utils._common import run_cmd
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
print(
|
||||
f"Found {len(vulkaninfo_list)} device names. choosing first one: {vulkaninfo_list[0]}"
|
||||
)
|
||||
return vulkaninfo_list[0]
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(extra_args=[]):
|
||||
if "-iree-vulkan-target-triple=" in " ".join(extra_args):
|
||||
print(f"Using target triple from command line args")
|
||||
return None
|
||||
|
||||
vulkan_device = get_vulkan_device_name()
|
||||
if all(x in vulkan_device for x in ("Apple", "M1")):
|
||||
print(f"Found {vulkan_device} Device. Using m1-moltenvk-macos")
|
||||
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
elif all(x in vulkan_device for x in ("Apple", "M2")):
|
||||
print("Found Apple M2 Device. Using m1-moltenvk-macos")
|
||||
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
elif all(x in vulkan_device for x in ("A100", "SXM4")):
|
||||
print(f"Found {vulkan_device} Device. Using ampere-rtx3080-linux")
|
||||
return "-iree-vulkan-target-triple=ampere-rtx3080-linux"
|
||||
elif all(x in vulkan_device for x in ("RTX", "3090")):
|
||||
print(f"Found {vulkan_device} Device. Using ampere-rtx3090-linux")
|
||||
return "-iree-vulkan-target-triple=ampere-rtx3090-linux"
|
||||
elif "AMD" in vulkan_device:
|
||||
print("Found AMD device. Using rdna2-unknown-linux")
|
||||
return "-iree-vulkan-target-triple=rdna2-unknown-linux"
|
||||
else:
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
|
||||
or pull up an issue."""
|
||||
)
|
||||
print(f"Target : {vulkan_device}")
|
||||
return None
|
||||
|
||||
|
||||
def get_iree_vulkan_args(extra_args=[]):
|
||||
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
vulkan_flag = []
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args)
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_flag.append(vulkan_triple_flag)
|
||||
return vulkan_flag
|
||||
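# Illustrative usage sketch (an assumption, not from the original diff): passing
# a triple explicitly through extra_args makes get_vulkan_triple_flag() return
# None, so no extra flag is appended; otherwise the triple is derived from the
# `vulkaninfo` device name above.
if __name__ == "__main__":
    print(get_iree_vulkan_args(
        extra_args=["-iree-vulkan-target-triple=rdna2-unknown-linux"]
    ))  # -> []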
@@ -12,22 +12,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
from typing import List, Dict
|
||||
|
||||
from iree.compiler import ir
|
||||
from iree.compiler.transforms import ireec as ireec_trans
|
||||
|
||||
MATMUL_OP_NAMES = set(
|
||||
["linalg.matmul", "linalg.batch_matmul", "mhlo.dot", "mhlo.dot_general"])
|
||||
idx = 0
|
||||
|
||||
def model_annotation(
|
||||
ctx: ir.Context,
|
||||
*,
|
||||
input_contents: str,
|
||||
config_path: str,
|
||||
search_op: str = "matmul",
|
||||
):
|
||||
|
||||
def model_annotation(ctx: ir.Context, *, input_contents: str, config_path: str):
|
||||
if os.path.isfile(input_contents):
|
||||
with open(input_contents, "rb") as f:
|
||||
input_contents = f.read()
|
||||
@@ -40,35 +38,20 @@ def model_annotation(
|
||||
|
||||
# The Python API does not expose a general walk() function, so we just
|
||||
# do it ourselves.
|
||||
walk_children(module.operation, configs, 0, search_op)
|
||||
walk_children(module.operation, configs)
|
||||
|
||||
if not module.operation.verify():
|
||||
raise RuntimeError("Modified program does not verify!")
|
||||
|
||||
# More efficient than: print(module)
|
||||
# - Disables verification (already done above)
|
||||
# - Writes as binary, avoiding costly unicode conversions
|
||||
sys.stdout.buffer.write(
|
||||
module.operation.get_asm(assume_verified=True, binary=True))
|
||||
return module
|
||||
|
||||
|
||||
def walk_children(
|
||||
op: ir.Operation, configs: List[Dict], idx: int, search_op: str
|
||||
):
|
||||
if search_op == "matmul":
|
||||
op_names = ["linalg.matmul", "mhlo.dot"]
|
||||
elif search_op == "bmm":
|
||||
op_names = ["linalg.batch_matmul", "mhlo.dot_general"]
|
||||
elif search_op == "conv":
|
||||
op_names = ["mhlo.convolution", "linalg.conv_2d_nhwc_hwcf"]
|
||||
elif search_op == "all":
|
||||
op_names = [
|
||||
"mhlo.dot",
|
||||
"mhlo.dot_general",
|
||||
"mhlo.convolution",
|
||||
"linalg.matmul",
|
||||
"linalg.batch_matmul",
|
||||
"linalg.conv_2d_nhwc_hwcf",
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"{search_op} op is not tunable.")
|
||||
|
||||
def walk_children(op: ir.Operation, configs: List[Dict]):
|
||||
for region in op.regions:
|
||||
for block in region.blocks:
|
||||
for child_op in block.operations:
|
||||
@@ -76,41 +59,29 @@ def walk_children(
|
||||
# 'operation' and 'name' attributes.
|
||||
if isinstance(child_op, ir.OpView):
|
||||
child_op = child_op.operation
|
||||
if child_op.name in op_names and idx < len(configs):
|
||||
add_attributes(child_op, configs[idx])
|
||||
if child_op.name in MATMUL_OP_NAMES:
|
||||
global idx
|
||||
tile_sizes, pipeline, workgroup_size, \
|
||||
split_k, pipeline_depth = parse_config(configs[idx])
|
||||
|
||||
add_compilation_info(child_op,
|
||||
tile_sizes=tile_sizes,
|
||||
pipeline=pipeline,
|
||||
workgroup_size=workgroup_size,
|
||||
pipeline_depth=pipeline_depth)
|
||||
|
||||
if split_k:
|
||||
add_split_k(child_op, split_k)
|
||||
|
||||
idx = idx + 1
|
||||
print(f"Updated op {child_op}", file=sys.stderr)
|
||||
walk_children(child_op, configs, idx, search_op)
|
||||
|
||||
|
||||
def add_attributes(op: ir.Operation, config: Dict):
|
||||
(
|
||||
tile_sizes,
|
||||
pipeline,
|
||||
workgroup_size,
|
||||
split_k,
|
||||
pipeline_depth,
|
||||
) = parse_config(config)
|
||||
|
||||
add_compilation_info(
|
||||
op,
|
||||
tile_sizes=tile_sizes,
|
||||
pipeline=pipeline,
|
||||
workgroup_size=workgroup_size,
|
||||
pipeline_depth=pipeline_depth,
|
||||
)
|
||||
|
||||
if split_k:
|
||||
add_attribute_by_name(op, "iree_flow_split_k", split_k)
|
||||
walk_children(child_op, configs)
|
||||
|
||||
|
||||
def parse_config(config: Dict):
|
||||
if config["pipeline"] == "GPU" or config["pipeline"] == "GPU_TENSORCORE":
|
||||
pipeline = (
|
||||
"LLVMGPUMatmulSimt"
|
||||
if config["pipeline"] == "GPU"
|
||||
else "LLVMGPUMatmulTensorCore"
|
||||
)
|
||||
pipeline = "LLVMGPUMatmulSimt" if config[
|
||||
"pipeline"] == "GPU" else "LLVMGPUMatmulTensorCore"
|
||||
tile_sizes = [config["work_group_tile_sizes"]]
|
||||
workgroup_size = config["work_group_sizes"]
|
||||
try:
|
||||
@@ -124,9 +95,8 @@ def parse_config(config: Dict):
|
||||
else:
|
||||
pipeline = config["pipeline"]
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"],
|
||||
config["l1_tile_sizes"],
|
||||
config["vector_tile_sizes"],
|
||||
config["work_group_tile_sizes"], config["l1_tile_sizes"],
|
||||
config["vector_tile_sizes"]
|
||||
]
|
||||
workgroup_size = []
|
||||
split_k = None
|
||||
@@ -134,13 +104,9 @@ def parse_config(config: Dict):
|
||||
return tile_sizes, pipeline, workgroup_size, split_k, pipeline_depth
|
||||
|
||||
|
||||
def add_compilation_info(
|
||||
op: ir.Operation,
|
||||
tile_sizes: List[List[int]],
|
||||
pipeline: str,
|
||||
workgroup_size: List[int],
|
||||
pipeline_depth: int,
|
||||
):
|
||||
def add_compilation_info(op: ir.Operation, tile_sizes: List[List[int]],
|
||||
pipeline: str, workgroup_size: List[int],
|
||||
pipeline_depth: int):
|
||||
# We don't have a Python binding for CompilationInfo, so we just parse
|
||||
# its string form.
|
||||
if pipeline_depth:
|
||||
@@ -148,21 +114,19 @@ def add_compilation_info(
|
||||
f"#iree_codegen.compilation_info<"
|
||||
f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
|
||||
f"translation_info = <{pipeline} pipeline_depth = {pipeline_depth}>, "
|
||||
f"workgroup_size = {repr(workgroup_size)}>"
|
||||
)
|
||||
f"workgroup_size = {repr(workgroup_size)}>")
|
||||
else:
|
||||
attr = ir.Attribute.parse(
|
||||
f"#iree_codegen.compilation_info<"
|
||||
f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
|
||||
f"translation_info = <{pipeline}>, "
|
||||
f"workgroup_size = {repr(workgroup_size)}>"
|
||||
)
|
||||
f"workgroup_size = {repr(workgroup_size)}>")
|
||||
op.attributes["compilation_info"] = attr
|
||||
|
||||
|
||||
def add_attribute_by_name(op: ir.Operation, name: str, val: int):
|
||||
attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
|
||||
op.attributes[name] = attr
|
||||
def add_split_k(op: ir.Operation, k: int):
|
||||
attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), k)
|
||||
op.attributes["iree_flow_split_k"] = attr
|
||||
|
||||
|
||||
def create_context() -> ir.Context:
|
||||
@@ -174,14 +138,6 @@ def create_context() -> ir.Context:
|
||||
|
||||
if __name__ == "__main__":
|
||||
with create_context() as ctx:
|
||||
module = model_annotation(
|
||||
ctx,
|
||||
input_contents=sys.argv[1],
|
||||
config_path=sys.argv[2],
|
||||
search_op="all",
|
||||
)
|
||||
mlir_str = str(module)
|
||||
filename = "tuned_model.mlir"
|
||||
with open(filename, "w") as f:
|
||||
f.write(mlir_str)
|
||||
print(f"Saved mlir in {filename}.")
|
||||
model_annotation(ctx,
|
||||
input_contents=sys.argv[1],
|
||||
config_path=sys.argv[2])
|
||||
|
||||
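# Illustrative sketch (a guess, not from the original diff) of one tuning-config
# entry consumed by parse_config() above; the required key names mirror the
# code, "split_k" and "pipeline_depth" are assumed optional keys, and every
# value is made up.
example_gpu_config = {
    "pipeline": "GPU_TENSORCORE",          # or "GPU", or a CPU pipeline name
    "work_group_tile_sizes": [32, 32, 16],
    "work_group_sizes": [64, 2, 1],
    "pipeline_depth": 4,                   # assumed optional
    "split_k": 2,                          # assumed optional
}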
@@ -20,8 +20,8 @@ def dir_path(path):
|
||||
if os.path.isdir(path):
|
||||
return path
|
||||
else:
|
||||
os.mkdir(path)
|
||||
return path
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"readable_dir:{path} is not a valid path")
|
||||
|
||||
|
||||
def dir_file(path):
|
||||
@@ -29,80 +29,43 @@ def dir_file(path):
|
||||
return path
|
||||
else:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"readable_file:{path} is not a valid file"
|
||||
)
|
||||
f"readable_file:{path} is not a valid file")
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="SHARK runner.")
|
||||
parser = argparse.ArgumentParser(description='SHARK runner.')
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
|
||||
)
|
||||
help="Device on which shark_runner runs. options are cpu, gpu, and vulkan")
|
||||
parser.add_argument(
|
||||
"--repro_dir",
|
||||
help="Directory to which module files will be saved for reproduction or debugging.",
|
||||
help=
|
||||
"Directory to which module files will be saved for reproduction or debugging.",
|
||||
type=dir_path,
|
||||
default="./shark_tmp",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_tf32",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Enables TF32 precision calculations on supported GPUs.",
|
||||
)
|
||||
default="/tmp/")
|
||||
parser.add_argument("--save_mlir",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Saves input MLIR module to /tmp/ directory.")
|
||||
parser.add_argument("--save_vmfb",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Saves iree .vmfb module to /tmp/ directory.")
|
||||
parser.add_argument(
|
||||
"--model_config_path",
|
||||
help="Directory to where the tuned model config file is located.",
|
||||
default=None,
|
||||
)
|
||||
default=None)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_warmup_iterations",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Run the model for the specified number of warmup iterations.",
|
||||
)
|
||||
default=2,
|
||||
help="Run the model for the specified number of warmup iterations.")
|
||||
parser.add_argument(
|
||||
"--num_iterations",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Run the model for the specified number of iterations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--onnx_bench",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled, pytest bench results will include ONNX benchmark results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shark_prefix",
|
||||
default="latest",
|
||||
help="gs://shark_tank/<this_flag>/model_directories",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update_tank",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled, SHARK downloader will update local shark_tank if local hash is different from latest upstream hash.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local_tank_cache",
|
||||
default="",
|
||||
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dispatch_benchmarks",
|
||||
default=None,
|
||||
help='Dispatches to return benchmark data on. Use "All" for all, and None for none.',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="temp_dispatch_benchmarks",
|
||||
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
|
||||
)
|
||||
default=1,
|
||||
help="Run the model for the specified number of iterations.")
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
@@ -1,371 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.iree_utils.compile_utils import export_iree_module_to_vmfb
|
||||
from shark.iree_utils.benchmark_utils import (
|
||||
build_benchmark_args,
|
||||
run_benchmark_module,
|
||||
)
|
||||
from shark.parser import shark_args
|
||||
from datetime import datetime
|
||||
import time
|
||||
import csv
|
||||
import os
|
||||
|
||||
|
||||
class OnnxFusionOptions(object):
|
||||
def __init__(self):
|
||||
self.disable_gelu = False
|
||||
self.disable_layer_norm = False
|
||||
self.disable_attention = False
|
||||
self.disable_skip_layer_norm = False
|
||||
self.disable_embed_layer_norm = False
|
||||
self.disable_bias_skip_layer_norm = False
|
||||
self.disable_bias_gelu = False
|
||||
self.enable_gelu_approximation = False
|
||||
self.use_mask_index = False
|
||||
self.no_attention_mask = False
|
||||
|
||||
|
||||
class SharkBenchmarkRunner(SharkRunner):
|
||||
# SharkRunner derived class with Benchmarking capabilities.
|
||||
def __init__(
|
||||
self,
|
||||
mlir_module: bytes,
|
||||
function_name: str = "forward",
|
||||
device: str = "none",
|
||||
mlir_dialect: str = "linalg",
|
||||
extra_args: list = [],
|
||||
):
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.frontend_model = None
|
||||
self.vmfb_file = None
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
SharkRunner.__init__(
|
||||
self,
|
||||
mlir_module,
|
||||
function_name,
|
||||
device,
|
||||
self.mlir_dialect,
|
||||
self.extra_args,
|
||||
compile_vmfb=True,
|
||||
)
|
||||
if self.vmfb_file == None:
|
||||
self.vmfb_file = export_iree_module_to_vmfb(
|
||||
mlir_module,
|
||||
device,
|
||||
shark_args.repro_dir,
|
||||
self.mlir_dialect,
|
||||
function_name,
|
||||
extra_args=self.extra_args,
|
||||
)
|
||||
|
||||
def setup_cl(self, input_tensors):
|
||||
self.benchmark_cl = build_benchmark_args(
|
||||
self.vmfb_file,
|
||||
self.device,
|
||||
input_tensors,
|
||||
mlir_dialect=self.mlir_dialect,
|
||||
)
|
||||
print(self.benchmark_cl)
|
||||
|
||||
def benchmark_frontend(self, modelname):
|
||||
if self.mlir_dialect in ["linalg", "torch"]:
|
||||
return self.benchmark_torch(modelname)
|
||||
elif self.mlir_dialect in ["mhlo", "tf"]:
|
||||
return self.benchmark_tf(modelname)
|
||||
|
||||
def benchmark_torch(self, modelname):
|
||||
import torch
|
||||
from tank.model_utils import get_torch_model
|
||||
|
||||
if self.device == "cuda":
|
||||
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
torch_device = torch.device(
|
||||
"cuda:0" if self.device == "cuda" else "cpu"
|
||||
)
|
||||
HFmodel, input = get_torch_model(modelname)[:2]
|
||||
frontend_model = HFmodel.model
|
||||
frontend_model.to(torch_device)
|
||||
input.to(torch_device)
|
||||
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(input)
|
||||
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(input)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
print(
|
||||
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
]
|
||||
|
||||
def benchmark_tf(self, modelname):
|
||||
import tensorflow as tf
|
||||
from tank.model_utils_tf import get_tf_model
|
||||
|
||||
model, input = get_tf_model(modelname)[:2]
|
||||
frontend_model = model
|
||||
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(*input)
|
||||
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(*input)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
print(
|
||||
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
]
|
||||
|
||||
def benchmark_c(self):
|
||||
print(self.benchmark_cl)
|
||||
result = run_benchmark_module(self.benchmark_cl)
|
||||
print(f"Shark-IREE-C benchmark:{result} iter/second")
|
||||
return [f"{result}", f"{1000/result}"]
|
||||
|
||||
def benchmark_python(self, inputs):
|
||||
input_list = [x for x in inputs]
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
self.run(input_list)
|
||||
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = self.run(input_list)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
print(
|
||||
f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
]
|
||||
|
||||
def benchmark_onnx(self, modelname, inputs):
|
||||
if self.device == "cuda":
|
||||
print(
|
||||
"Currently GPU benchmarking on ONNX is not supported in SHARK."
|
||||
)
|
||||
return ["N/A", "N/A"]
|
||||
else:
|
||||
from onnxruntime.transformers.benchmark import run_onnxruntime
|
||||
from onnxruntime.transformers.huggingface_models import MODELS
|
||||
from onnxruntime.transformers.benchmark_helper import (
|
||||
ConfigModifier,
|
||||
Precision,
|
||||
)
|
||||
import psutil
|
||||
|
||||
if modelname == "microsoft/MiniLM-L12-H384-uncased":
|
||||
modelname = "bert-base-uncased"
|
||||
if modelname not in MODELS:
|
||||
print(
|
||||
f"{modelname} is currently not supported in ORT's HF. Check \
|
||||
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/huggingface_models.py \
|
||||
for currently supported models. Exiting benchmark ONNX."
|
||||
)
|
||||
return ["N/A", "N/A"]
|
||||
use_gpu = self.device == "cuda"
|
||||
num_threads = psutil.cpu_count(logical=False)
|
||||
batch_sizes = [1]
|
||||
sequence_lengths = [128]
|
||||
cache_dir = os.path.join(".", "cache_models")
|
||||
onnx_dir = os.path.join(".", "onnx_models")
|
||||
verbose = False
|
||||
input_counts = [1]
|
||||
optimize_onnx = True
|
||||
validate_onnx = False
|
||||
disable_ort_io_binding = False
|
||||
use_raw_attention_mask = True
|
||||
model_fusion_statistics = {}
|
||||
overwrite = False
|
||||
model_source = "pt" # Either "pt" or "tf"
|
||||
provider = None
|
||||
config_modifier = ConfigModifier(None)
|
||||
onnx_args = OnnxFusionOptions()
|
||||
result = run_onnxruntime(
|
||||
use_gpu,
|
||||
provider,
|
||||
(modelname,),
|
||||
None,
|
||||
config_modifier,
|
||||
Precision.FLOAT32,
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
input_counts,
|
||||
optimize_onnx,
|
||||
validate_onnx,
|
||||
cache_dir,
|
||||
onnx_dir,
|
||||
verbose,
|
||||
overwrite,
|
||||
disable_ort_io_binding,
|
||||
use_raw_attention_mask,
|
||||
model_fusion_statistics,
|
||||
model_source,
|
||||
onnx_args,
|
||||
)
|
||||
print(
|
||||
f"ONNX ORT-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
result[0]["QPS"],
|
||||
result[0]["average_latency_ms"],
|
||||
]
|
||||
|
||||
def get_metadata(self, modelname):
|
||||
with open("./tank/model_metadata.csv", mode="r") as csvfile:
|
||||
torch_reader = csv.reader(csvfile, delimiter=",")
|
||||
fields = next(torch_reader)
|
||||
for row in torch_reader:
|
||||
torch_model_name = row[0]
|
||||
if torch_model_name == modelname:
|
||||
param_count = row[3]
|
||||
model_tags = row[4]
|
||||
model_notes = row[5]
|
||||
return [param_count, model_tags, model_notes]
|
||||
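# Illustrative sketch (made-up values) of a ./tank/model_metadata.csv row that
# the lookup above expects: column 0 holds the model name and columns 3-5 the
# parameter count, tags and notes; the middle columns are left unspecified
# here, e.g.
#   bert-base-uncased,...,...,110M,nlp;transformer,baseline HuggingFace model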
|
||||
def compare_bench_results(self, baseline: str, result: str):
|
||||
# Takes two numbers represented as strings and returns "<n>x slower/faster", as in "result is <n>x slower than baseline".
|
||||
a = float(baseline)
|
||||
b = float(result)
|
||||
if a < b:
|
||||
# result slower than baseline
|
||||
comparison = (b - a) / a
|
||||
comp_str = f"{round(comparison, 2)}x slower"
|
||||
elif a > b:
|
||||
# result faster than baseline
|
||||
comparison = a / b
|
||||
comp_str = f"{round(comparison, 2)}x faster"
|
||||
else:
|
||||
comp_str = "equal"
|
||||
return comp_str
|
||||
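# Worked example of the comparison above (illustrative, not from the original
# diff): with a 10 ms baseline and a 25 ms result, (25 - 10) / 10 == 1.5, so
# compare_bench_results("10.0", "25.0") returns "1.5x slower"; with a 30 ms
# baseline and a 10 ms result, 30 / 10 == 3.0 gives "3.0x faster".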
|
||||
def benchmark_all_csv(
|
||||
self, inputs: tuple, modelname, dynamic, device_str, frontend
|
||||
):
|
||||
self.setup_cl(inputs)
|
||||
field_names = [
|
||||
"model",
|
||||
"engine",
|
||||
"dialect",
|
||||
"device",
|
||||
"shape_type",
|
||||
"data_type",
|
||||
"iter/sec",
|
||||
"ms/iter",
|
||||
"vs. PyTorch/TF",
|
||||
"iterations",
|
||||
"param_count",
|
||||
"tags",
|
||||
"notes",
|
||||
"datetime",
|
||||
]
|
||||
engines = ["frontend", "shark_python", "shark_iree_c"]
|
||||
if shark_args.onnx_bench == True:
|
||||
engines.append("onnxruntime")
|
||||
|
||||
if not os.path.exists("bench_results.csv"):
|
||||
with open("bench_results.csv", mode="w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(field_names)
|
||||
|
||||
with open("bench_results.csv", mode="a", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=field_names)
|
||||
bench_result = {}
|
||||
bench_result["model"] = modelname
|
||||
if dynamic == True:
|
||||
bench_result["shape_type"] = "dynamic"
|
||||
else:
|
||||
bench_result["shape_type"] = "static"
|
||||
bench_result["device"] = device_str
|
||||
bench_result["data_type"] = inputs[0].dtype
|
||||
for e in engines:
|
||||
(
|
||||
bench_result["param_count"],
|
||||
bench_result["tags"],
|
||||
bench_result["notes"],
|
||||
) = ["", "", ""]
|
||||
if e == "frontend":
|
||||
bench_result["engine"] = frontend
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
) = self.benchmark_frontend(modelname)
|
||||
self.frontend_result = bench_result["ms/iter"]
|
||||
bench_result["vs. PyTorch/TF"] = "="
|
||||
(
|
||||
bench_result["param_count"],
|
||||
bench_result["tags"],
|
||||
bench_result["notes"],
|
||||
) = self.get_metadata(modelname)
|
||||
|
||||
elif e == "shark_python":
|
||||
bench_result["engine"] = "shark_python"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
) = self.benchmark_python(inputs)
|
||||
|
||||
bench_result[
|
||||
"vs. PyTorch/TF"
|
||||
] = self.compare_bench_results(
|
||||
self.frontend_result, bench_result["ms/iter"]
|
||||
)
|
||||
|
||||
elif e == "shark_iree_c":
|
||||
bench_result["engine"] = "shark_iree_c"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
) = self.benchmark_c()
|
||||
|
||||
bench_result[
|
||||
"vs. PyTorch/TF"
|
||||
] = self.compare_bench_results(
|
||||
self.frontend_result, bench_result["ms/iter"]
|
||||
)
|
||||
|
||||
elif e == "onnxruntime":
|
||||
bench_result["engine"] = "onnxruntime"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
) = self.benchmark_onnx(modelname, inputs)
|
||||
|
||||
bench_result["dialect"] = self.mlir_dialect
|
||||
bench_result["iterations"] = shark_args.num_iterations
|
||||
bench_result["datetime"] = str(datetime.now())
|
||||
writer.writerow(bench_result)
|
||||
@@ -1,280 +0,0 @@
|
||||
# Lint as: python3
|
||||
"""SHARK Downloader"""
|
||||
# Requirements : Put shark_tank in SHARK directory
|
||||
# /SHARK
|
||||
# /gen_shark_tank
|
||||
# /tflite
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
# /tf
|
||||
# /pytorch
|
||||
#
|
||||
#
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import urllib.request
|
||||
import json
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from shark.parser import shark_args
|
||||
|
||||
input_type_to_np_dtype = {
|
||||
"float32": np.float32,
|
||||
"float64": np.float64,
|
||||
"bool": np.bool_,
|
||||
"int32": np.int32,
|
||||
"int64": np.int64,
|
||||
"uint8": np.uint8,
|
||||
"int8": np.int8,
|
||||
}
|
||||
|
||||
|
||||
# Save the model in the user's home directory so it needn't be fetched every time in CI.
|
||||
home = str(Path.home())
|
||||
alt_path = os.path.join(os.path.dirname(__file__), "../gen_shark_tank/")
|
||||
custom_path = shark_args.local_tank_cache
|
||||
if os.path.exists(alt_path):
|
||||
WORKDIR = alt_path
|
||||
print(
|
||||
f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank."
|
||||
)
|
||||
if custom_path:
|
||||
if not os.path.exists(custom_path):
|
||||
os.mkdir(custom_path)
|
||||
|
||||
WORKDIR = custom_path
|
||||
|
||||
print(f"Using {WORKDIR} as local shark_tank cache directory.")
|
||||
else:
|
||||
WORKDIR = os.path.join(home, ".local/shark_tank/")
|
||||
print(
|
||||
f"shark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache="
|
||||
" pytest flag"
|
||||
)
|
||||
|
||||
# Checks whether the directory and files exists.
|
||||
def check_dir_exists(model_name, frontend="torch", dynamic=""):
|
||||
model_dir = os.path.join(WORKDIR, model_name)
|
||||
|
||||
# Remove the _tf keyword from end.
|
||||
if frontend in ["tf", "tensorflow"]:
|
||||
model_name = model_name[:-3]
|
||||
elif frontend in ["tflite"]:
|
||||
model_name = model_name[:-7]
|
||||
elif frontend in ["torch", "pytorch"]:
|
||||
model_name = model_name[:-6]
|
||||
|
||||
if os.path.isdir(model_dir):
|
||||
if (
|
||||
os.path.isfile(
|
||||
os.path.join(
|
||||
model_dir,
|
||||
model_name + dynamic + "_" + str(frontend) + ".mlir",
|
||||
)
|
||||
)
|
||||
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
|
||||
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
|
||||
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
|
||||
and os.path.isfile(os.path.join(model_dir, "hash.npy"))
|
||||
):
|
||||
print(
|
||||
f"""The models are present in the {WORKDIR}. If you want a fresh
|
||||
download, consider deleting the directory."""
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Downloads the torch model from gs://shark_tank dir.
|
||||
def download_torch_model(
|
||||
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_torch"
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp -r '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ " "
|
||||
+ WORKDIR
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
|
||||
if not check_dir_exists(model_dir_name, frontend="torch", dynamic=dyn_str):
|
||||
gs_download_model()
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
if shark_args.update_tank == True:
|
||||
gs_download_model()
|
||||
else:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(
|
||||
os.path.join(model_dir, model_name + dyn_str + "_torch.mlir"),
|
||||
mode="rb",
|
||||
) as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
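# Illustrative usage sketch (an assumption, not from the original diff): with
# gsutil configured, the call below would fetch the MiniLM artifacts into
# WORKDIR and return the serialized MLIR together with its golden data.
if __name__ == "__main__":
    mlir_file, func_name, inputs, golden_out = download_torch_model(
        "microsoft/MiniLM-L12-H384-uncased"
    )
    print(func_name, [x.shape for x in inputs])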
|
||||
|
||||
# Downloads the tflite model from gs://shark_tank dir.
|
||||
def download_tflite_model(
|
||||
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
|
||||
):
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_tflite"
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp -r '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ " "
|
||||
+ WORKDIR
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
|
||||
if not check_dir_exists(
|
||||
model_dir_name, frontend="tflite", dynamic=dyn_str
|
||||
):
|
||||
gs_download_model()
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
if shark_args.update_tank == True:
|
||||
gs_download_model()
|
||||
else:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(
|
||||
os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir"),
|
||||
mode="rb",
|
||||
) as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
|
||||
|
||||
def download_tf_model(
|
||||
model_name, tuned=None, tank_url="gs://shark_tank/latest"
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_tf"
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp -r '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ " "
|
||||
+ WORKDIR
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
|
||||
if not check_dir_exists(model_dir_name, frontend="tf"):
|
||||
gs_download_model()
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
if shark_args.update_tank == True:
|
||||
gs_download_model()
|
||||
else:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
suffix = "_tf.mlir" if tuned is None else "_tf_" + tuned + ".mlir"
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
if not os.path.isfile(filename):
|
||||
filename = os.path.join(model_dir, model_name + "_tf.mlir")
|
||||
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
@@ -1,246 +1,136 @@
|
||||
# Lint as: python3
|
||||
"""SHARK Importer"""
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import iree.compiler.tflite as iree_tflite_compile
|
||||
import iree.runtime as iree_rt
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
# List of the supported frontends.
|
||||
supported_frontends = {
|
||||
"tensorflow",
|
||||
"tf",
|
||||
"pytorch",
|
||||
"torch",
|
||||
"tf-lite",
|
||||
"tflite",
|
||||
}
|
||||
import sys
|
||||
import tensorflow.compat.v2 as tf
|
||||
import urllib.request
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
class SharkImporter:
|
||||
"""
|
||||
SharkImporter converts frontend modules into a
|
||||
mlir_module. The supported frameworks are tensorflow,
|
||||
pytorch, and tf-lite.
|
||||
|
||||
...
|
||||
def __init__(self,
|
||||
model_path,
|
||||
model_type: str = "tflite",
|
||||
model_source_hub: str = "tfhub",
|
||||
device: str = None,
|
||||
dynamic: bool = False,
|
||||
jit_trace: bool = False,
|
||||
benchmark_mode: bool = False):
|
||||
self.model_path = model_path
|
||||
self.model_type = model_type
|
||||
self.model_source_hub = model_source_hub
|
||||
self.device = device
|
||||
self.dynamic = dynamic
|
||||
self.jit_trace = jit_trace
|
||||
self.benchmark_mode = benchmark_mode
|
||||
self.inputs = None
|
||||
self.input_details = None
|
||||
self.output_details = None
|
||||
|
||||
Attributes
|
||||
----------
|
||||
module :
|
||||
torch, tensorflow or tf-lite module.
|
||||
inputs :
|
||||
inputs to the module, may be required for the shape
|
||||
information.
|
||||
frontend: str
|
||||
frontend to which the module belongs.
|
||||
raw_model_file: str
|
||||
temp tflite model path
|
||||
# create tmp model file directory
|
||||
if self.model_path is None:
|
||||
print("Error. No model_path, Please input model path.")
|
||||
return
|
||||
|
||||
Methods
|
||||
-------
|
||||
import_mlir(is_dynamic, tracing_required, func_name):
|
||||
is_dynamic: input shapes to be totally dynamic (pytorch specific).
|
||||
tracing_required: whether tracing is required (pytorch specific.
|
||||
func_name: The function to be traced out or imported to mlir.
|
||||
|
||||
import_debug(is_dynamic, tracing_required, func_name):
|
||||
returns the converted (mlir_module,func_name) with inputs and golden
|
||||
outputs.
|
||||
The inputs and outputs are converted into np array.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module,
|
||||
inputs: tuple = (),
|
||||
frontend: str = "torch",
|
||||
raw_model_file: str = "",
|
||||
):
|
||||
self.module = module
|
||||
self.inputs = None if len(inputs) == 0 else inputs
|
||||
self.frontend = frontend
|
||||
if not self.frontend in supported_frontends:
|
||||
print(
|
||||
f"The frontend is not in the supported_frontends: {supported_frontends}"
|
||||
)
|
||||
sys.exit(1)
|
||||
self.raw_model_file = raw_model_file
|
||||
|
||||
# NOTE: The default function for torch is "forward" and tf-lite is "main".
|
||||
|
||||
def _torch_mlir(self, is_dynamic, tracing_required):
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module
|
||||
|
||||
return get_torch_mlir_module(
|
||||
self.module, self.inputs, is_dynamic, tracing_required
|
||||
)
|
||||
|
||||
def _tf_mlir(self, func_name, save_dir="./shark_tmp/"):
|
||||
from iree.compiler import tf as tfc
|
||||
|
||||
return tfc.compile_module(
|
||||
self.module,
|
||||
exported_names=[func_name],
|
||||
import_only=True,
|
||||
output_file=save_dir,
|
||||
)
|
||||
|
||||
def _tflite_mlir(self, func_name, save_dir="./shark_tmp/"):
|
||||
from iree.compiler import tflite as tflitec
|
||||
from shark.iree_utils._common import IREE_TARGET_MAP
|
||||
|
||||
self.mlir_model = tflitec.compile_file(
|
||||
self.raw_model_file, # in tflite, it is a path to .tflite file, not a tflite interpreter
|
||||
input_type="tosa",
|
||||
import_only=True,
|
||||
output_file=save_dir,
|
||||
)
|
||||
return self.mlir_model
|
||||
|
||||
# Adds the conversion of the frontend with the private function.
|
||||
def import_mlir(
|
||||
self,
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
func_name="forward",
|
||||
save_dir="./shark_tmp/",
|
||||
):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
if self.inputs == None:
|
||||
print(
|
||||
"Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
|
||||
if self.model_source_hub == "tfhub":
|
||||
# compile and run tfhub tflite
|
||||
if self.model_type == "tflite":
|
||||
print("Setting up for TMP_DIR")
|
||||
exe_basename = os.path.basename(sys.argv[0])
|
||||
self.workdir = os.path.join(os.path.dirname(__file__), "tmp",
|
||||
exe_basename)
|
||||
print(f"TMP_DIR = {self.workdir}")
|
||||
os.makedirs(self.workdir, exist_ok=True)
|
||||
self.tflite_file = '/'.join([self.workdir, 'model.tflite'])
|
||||
print("Setting up local address for tflite model file: ",
|
||||
self.tflite_file)
|
||||
if os.path.exists(self.model_path):
|
||||
self.tflite_file = self.model_path
|
||||
else:
|
||||
print("Download tflite model")
|
||||
urllib.request.urlretrieve(self.model_path,
|
||||
self.tflite_file)
|
||||
print("Setting up tflite interpreter")
|
||||
self.tflite_interpreter = tf.lite.Interpreter(
|
||||
model_path=self.tflite_file)
|
||||
self.tflite_interpreter.allocate_tensors()
|
||||
# default input initialization
|
||||
self.input_details, self.output_details = self.get_model_details(
|
||||
)
|
||||
sys.exit(1)
|
||||
return self._torch_mlir(is_dynamic, tracing_required), func_name
|
||||
if self.frontend in ["tf", "tensorflow"]:
|
||||
return self._tf_mlir(func_name, save_dir), func_name
|
||||
if self.frontend in ["tflite", "tf-lite"]:
|
||||
func_name = "main"
|
||||
return self._tflite_mlir(func_name, save_dir), func_name
|
||||
inputs = self.generate_inputs(
|
||||
self.input_details) # device_inputs
|
||||
self.setup_inputs(inputs)
|
||||
|
||||
# Converts the frontend specific tensors into np array.
|
||||
def convert_to_numpy(self, array_tuple: tuple):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
return [x.detach().cpu().numpy() for x in array_tuple]
|
||||
if self.frontend in ["tf", "tensorflow"]:
|
||||
return [x.numpy() for x in array_tuple]
|
||||
def generate_inputs(self, input_details):
|
||||
args = []
|
||||
for input in input_details:
|
||||
print(str(input["shape"]), input["dtype"].__name__)
|
||||
args.append(np.zeros(shape=input["shape"], dtype=input["dtype"]))
|
||||
return args
|
||||
|
||||
# Saves `function_name.npy`, `inputs.npz`, `golden_out.npz` and `model_name.mlir` in the directory `dir`.
|
||||
def save_data(
|
||||
self, dir, model_name, mlir_data, func_name, inputs, outputs
|
||||
):
|
||||
import numpy as np
|
||||
def get_model_details(self):
|
||||
if self.model_type == "tflite":
|
||||
print("Get tflite input output details")
|
||||
self.input_details = self.tflite_interpreter.get_input_details()
|
||||
self.output_details = self.tflite_interpreter.get_output_details()
|
||||
return self.input_details, self.output_details
|
||||
|
||||
inputs_name = "inputs.npz"
|
||||
outputs_name = "golden_out.npz"
|
||||
func_file_name = "function_name"
|
||||
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
|
||||
try:
|
||||
inputs = [x.cpu().detach() for x in inputs]
|
||||
except AttributeError:
|
||||
try:
|
||||
inputs = [x.numpy() for x in inputs]
|
||||
except AttributeError:
|
||||
inputs = [x for x in inputs]
|
||||
np.savez(os.path.join(dir, inputs_name), *inputs)
|
||||
np.savez(os.path.join(dir, outputs_name), *outputs)
|
||||
np.save(os.path.join(dir, func_file_name), np.array(func_name))
|
||||
def setup_inputs(self, inputs):
|
||||
print("Setting up inputs")
|
||||
self.inputs = inputs
|
||||
|
||||
if self.frontend == "torch":
|
||||
with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
|
||||
mlir_file.write(mlir_data)
|
||||
def compile(self, inputs=None):
|
||||
if inputs is not None:
|
||||
self.setup_inputs(inputs)
|
||||
# preprocess model_path to get model_type and Model Source Hub
|
||||
print("Shark Importer Intialize SharkInference and Do Compile")
|
||||
if self.model_source_hub == "tfhub":
|
||||
# compile and run tfhub tflite
|
||||
print("Inference tfhub model")
|
||||
self.shark_module = SharkInference(self.tflite_file,
|
||||
self.inputs,
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=self.jit_trace)
|
||||
self.shark_module.set_frontend("tflite")
|
||||
self.shark_module.compile()
|
||||
elif self.model_source_hub == "huggingface":
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
elif self.model_source_hub == "jaxhub":
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
|
||||
return
|
||||
def forward(self, inputs=None):
|
||||
if inputs is not None:
|
||||
self.setup_inputs(inputs)
|
||||
# preprocess model_path to get model_type and Model Source Hub
|
||||
print("Shark Importer forward Model")
|
||||
if self.model_source_hub == "tfhub":
|
||||
shark_results = self.shark_module.forward(self.inputs)
|
||||
# Fix type information for unsigned cases.
|
||||
# for test compare result
|
||||
shark_results = list(shark_results)
|
||||
for i in range(len(self.output_details)):
|
||||
dtype = self.output_details[i]["dtype"]
|
||||
shark_results[i] = shark_results[i].astype(dtype)
|
||||
return shark_results
|
||||
elif self.model_source_hub == "huggingface":
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
elif self.model_source_hub == "jaxhub":
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
|
||||
def import_debug(
|
||||
self,
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
func_name="forward",
|
||||
dir=tempfile.gettempdir(),
|
||||
model_name="model",
|
||||
):
|
||||
if self.inputs == None:
|
||||
print(
|
||||
f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
|
||||
)
|
||||
sys.exit(1)
|
||||
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
|
||||
artifact_path = os.path.join(dir, model_name_mlir)
|
||||
imported_mlir = self.import_mlir(
|
||||
is_dynamic,
|
||||
tracing_required,
|
||||
func_name,
|
||||
save_dir=artifact_path,
|
||||
|
||||
def shark_load(model_name, file_path):
|
||||
file_link = f"https://storage.googleapis.com/shark_tank/users/stanley/{model_name}.mlir"
|
||||
response = urllib.request.urlretrieve(file_link, file_path)
|
||||
if not os.path.isfile(file_path):
|
||||
raise ValueError(
|
||||
f"Tried looking for target mlir in {file_path}, but cannot be found."
|
||||
)
|
||||
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
|
||||
# TODO: Check for multiple outputs.
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
import torch
|
||||
|
||||
golden_out = self.module(*self.inputs)
|
||||
if torch.is_tensor(golden_out):
|
||||
golden_out = tuple(
|
||||
golden_out.detach().cpu().numpy(),
|
||||
)
|
||||
else:
|
||||
golden_out = self.convert_to_numpy(golden_out)
|
||||
# Save the artifacts in the directory dir.
|
||||
self.save_data(
|
||||
dir,
|
||||
model_name,
|
||||
imported_mlir[0],
|
||||
imported_mlir[1],
|
||||
self.inputs,
|
||||
golden_out,
|
||||
)
|
||||
return (
|
||||
imported_mlir,
|
||||
self.convert_to_numpy(self.inputs),
|
||||
golden_out,
|
||||
)
|
||||
if self.frontend in ["tf", "tensorflow"]:
|
||||
import tensorflow as tf
|
||||
|
||||
golden_out = self.module.forward(*self.inputs)
|
||||
if tf.is_tensor(golden_out):
|
||||
golden_out = tuple(
|
||||
golden_out.numpy(),
|
||||
)
|
||||
elif golden_out is tuple:
|
||||
golden_out = self.convert_to_numpy(golden_out)
|
||||
elif hasattr(golden_out, "logits"):
|
||||
# from transformers import TFSequenceClassifierOutput
|
||||
golden_out = golden_out.logits
|
||||
else:
|
||||
golden_out = golden_out.last_hidden_state
|
||||
# Save the artifacts in the directory dir.
|
||||
self.save_data(
|
||||
dir,
|
||||
model_name,
|
||||
imported_mlir[0],
|
||||
imported_mlir[1],
|
||||
self.inputs,
|
||||
golden_out,
|
||||
)
|
||||
return (
|
||||
imported_mlir,
|
||||
self.convert_to_numpy(self.inputs),
|
||||
golden_out,
|
||||
)
|
||||
if self.frontend in ["tflite", "tf-lite"]:
|
||||
# TODO(Chi): Validate it for tflite models.
|
||||
golden_out = self.module.invoke_tflite(self.inputs)
|
||||
self.save_data(
|
||||
dir,
|
||||
model_name,
|
||||
imported_mlir[0],
|
||||
imported_mlir[1],
|
||||
self.inputs,
|
||||
golden_out,
|
||||
)
|
||||
return (
|
||||
imported_mlir,
|
||||
self.inputs,
|
||||
golden_out,
|
||||
)
|
||||
with open(file_path, "rb") as input_file:
|
||||
model_mlir = input_file.read()
|
||||
return model_mlir
|
||||
|
||||
@@ -9,202 +9,107 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from shark.iree_utils.compile_utils import (
|
||||
export_iree_module_to_vmfb,
|
||||
load_flatbuffer,
|
||||
create_dispatch_dirs,
|
||||
compile_benchmark_dirs,
|
||||
)
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module, run_on_refbackend
|
||||
import os
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.parser import shark_args
|
||||
import numpy as np
|
||||
from shark.shark_runner import SharkRunner, SharkBenchmarkRunner
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
dtype_to_np_dtype = {
|
||||
"f32": np.float32,
|
||||
"f64": np.float64,
|
||||
"i32": np.int32,
|
||||
"i64": np.int64,
|
||||
"i1": np.bool_,
|
||||
}
|
||||
# Prints to stderr.
|
||||
def print_err(*a):
|
||||
print(*a, file=sys.stderr)
|
||||
|
||||
|
||||
class SharkInference:
|
||||
"""
|
||||
Runs prediction or inference on mlir_module.
|
||||
"""Inference API targeting pytorch, tensorflow, linalg, mhlo and tosa frontend."""
|
||||
|
||||
...
|
||||
def __init__(self,
|
||||
model,
|
||||
input: tuple,
|
||||
device: str = None,
|
||||
dynamic: bool = False,
|
||||
jit_trace: bool = False,
|
||||
benchmark_mode: bool = False):
|
||||
self.model = model
|
||||
self.input = input
|
||||
self.dynamic = dynamic
|
||||
self.jit_trace = jit_trace
|
||||
self.benchmark_mode = benchmark_mode
|
||||
|
||||
Attributes
|
||||
----------
|
||||
mlir_module : str
|
||||
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
|
||||
function_name : str
|
||||
function to execute in the given mlir_module.
|
||||
device : str
|
||||
device to execute the mlir_module on.
|
||||
currently supports cpu, cuda, vulkan, and metal backends.
|
||||
mlir_dialect: str
|
||||
The dialect in which the given mlir_module is in.
|
||||
Refer to {https://mlir.llvm.org/docs/Dialects/}
|
||||
is_benchmark: bool
|
||||
Whether this SharkInference module should be benchmark-enabled.
|
||||
# By default it's torch frontend.
|
||||
self.frontend = "pytorch"
|
||||
|
||||
Methods
|
||||
-------
|
||||
run(inputs=None):
|
||||
Runs the mlir_module with the given inputs, if the inputs are not
|
||||
given it autogenerates the inputs. Also, the inputs should be a
|
||||
numpy array.
|
||||
input_info():
|
||||
Gives the information about the inputs required by the `function_name`.
|
||||
This can be expensive as it does string matching to do so.
|
||||
# Sets the device.
|
||||
self.device = device if device is not None else shark_args.device
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mlir_module: bytes,
|
||||
function_name: str = "forward",
|
||||
device: str = "none",
|
||||
mlir_dialect: str = "linalg",
|
||||
is_benchmark: bool = False,
|
||||
dispatch_benchmark: str = None,
|
||||
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
self.function_name = function_name
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.is_benchmark = is_benchmark
|
||||
self.dispatch_benchmarks = (
|
||||
shark_args.dispatch_benchmarks
|
||||
if dispatch_benchmark is None
|
||||
else dispatch_benchmark
|
||||
)
|
||||
self.dispatch_benchmarks_dir = (
|
||||
shark_args.dispatch_benchmarks_dir
|
||||
if dispatch_benchmark_dir == "temp_dispatch_benchmarks"
|
||||
else dispatch_benchmark_dir
|
||||
)
|
||||
self.model_config_path = shark_args.model_config_path
|
||||
|
||||
self.shark_runner = None
|
||||
|
||||
def compile(self, extra_args=[]):
|
||||
|
||||
if self.dispatch_benchmarks is not None:
|
||||
extra_args.append(
|
||||
f"--iree-hal-dump-executable-sources-to={self.dispatch_benchmarks_dir}"
|
||||
)
|
||||
temp_dir = self.dispatch_benchmarks_dir.split("/")
|
||||
temp_dir[-1] = "temp_" + temp_dir[-1]
|
||||
temp_dir = "/".join(temp_dir)
|
||||
self.temp_dispatch_benchmarks_dir = temp_dir
|
||||
extra_args.append(
|
||||
f"--iree-hal-dump-executable-benchmarks-to={self.temp_dispatch_benchmarks_dir}"
|
||||
)
|
||||
|
||||
if self.is_benchmark == True:
|
||||
from shark.shark_benchmark_runner import SharkBenchmarkRunner
|
||||
|
||||
self.shark_runner = SharkBenchmarkRunner(
|
||||
self.mlir_module,
|
||||
self.function_name,
|
||||
self.device,
|
||||
self.mlir_dialect,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
# Sets the frontend i.e `pytorch` or `tensorflow`.
|
||||
def set_frontend(self, frontend: str):
|
||||
if frontend not in [
|
||||
"pytorch", "torch", "tensorflow", "tf", "mhlo", "linalg",
|
||||
"tosa", "tflite"
|
||||
]:
|
||||
print_err("frontend not supported.")
|
||||
else:
|
||||
self.shark_runner = SharkRunner(
|
||||
self.mlir_module,
|
||||
self.function_name,
|
||||
self.device,
|
||||
self.mlir_dialect,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
self.frontend = frontend
|
||||
|
||||
if self.dispatch_benchmarks is not None:
|
||||
create_dispatch_dirs(self.dispatch_benchmarks_dir, self.device)
|
||||
compile_benchmark_dirs(
|
||||
self.dispatch_benchmarks_dir,
|
||||
self.device,
|
||||
self.dispatch_benchmarks,
|
||||
)
|
||||
os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")
|
||||
def compile(self):
|
||||
# Inference does not use AOT.
|
||||
from_aot = False
|
||||
if (self.benchmark_mode == True):
|
||||
self.shark_runner = SharkBenchmarkRunner(self.model, self.input,
|
||||
self.dynamic, self.device,
|
||||
self.jit_trace, from_aot,
|
||||
self.frontend)
|
||||
else:
|
||||
self.shark_runner = SharkRunner(self.model, self.input,
|
||||
self.dynamic, self.device,
|
||||
self.jit_trace, from_aot,
|
||||
self.frontend,
|
||||
self.model_config_path)
|
||||
|
||||
# inputs are considered to be tuple of np.array.
|
||||
def forward(self, inputs: tuple):
|
||||
return self.shark_runner.run(inputs)
|
||||
# inputs are considered to be np.array.
|
||||
def forward(self, inputs):
|
||||
input_list = inputs
|
||||
# converts the inputs to numpy.
|
||||
if self.frontend in ["pytorch", "torch"]:
|
||||
input_list = [x.detach().numpy() for x in inputs]
|
||||
elif self.frontend in ["tensorflow", "tf"]:
|
||||
input_list = [x.numpy() for x in inputs]
|
||||
return self.shark_runner.forward(input_list, self.frontend)
|
||||
|
||||
# Captures the static input information from the mlir_module.
|
||||
# TODO(pashu123): Generate the input information for dynamic shapes.
|
||||
def _input_info(self):
|
||||
# func_key to get the line which contains the function.
|
||||
func_key = "func.func @" + self.function_name
|
||||
func_header = None
|
||||
for line in str(self.mlir_module).splitlines():
|
||||
if func_key in line:
|
||||
func_header = line
|
||||
break
|
||||
if func_header is None:
|
||||
print(f"Function: {self.function_name} not found")
|
||||
# Saves the .vmfb module.
|
||||
def save_module(self, dir=None):
|
||||
if dir is None:
|
||||
return self.shark_runner.save_module()
|
||||
return self.shark_runner.save_module(dir)
|
||||
|
||||
import re
|
||||
######### Benchmark Related Functions #########
|
||||
def benchmark_mode(func):
|
||||
|
||||
inputs = re.findall("\(.*?\)", func_header)[0].split(",")
|
||||
shapes = []
|
||||
dtype = []
|
||||
for inp in inputs:
|
||||
shape_dtype = re.findall(r"<[^>]*>", inp)[0].split("x")
|
||||
shape_dtype[0], shape_dtype[-1] = (
|
||||
shape_dtype[0][1:],
|
||||
shape_dtype[-1][:-1],
|
||||
)
|
||||
shapes.append(tuple([int(x) for x in shape_dtype[:-1]]))
|
||||
dtype.append(shape_dtype[-1])
|
||||
def inner(self, *args, **kwargs):
|
||||
assert self.benchmark_mode, "SharkRunner needs to be in benchmark mode to run benchmark methods."
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return shapes, dtype
|
||||
return inner
|
||||
|
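To make the string matching above concrete, here is a small self-contained sketch (assuming a typical static-shaped func signature; the header text is hypothetical) of how the two regexes pull shapes and element types out of a function header:

import re

# Hypothetical func header of the kind _input_info() searches for in the MLIR text.
func_header = "func.func @forward(%arg0: tensor<1x128xi32>, %arg1: tensor<1x128xi32>) -> tensor<1x2xf32>"

inputs = re.findall(r"\(.*?\)", func_header)[0].split(",")
shapes, dtype = [], []
for inp in inputs:
    shape_dtype = re.findall(r"<[^>]*>", inp)[0].split("x")
    shape_dtype[0], shape_dtype[-1] = shape_dtype[0][1:], shape_dtype[-1][:-1]
    shapes.append(tuple(int(x) for x in shape_dtype[:-1]))
    dtype.append(shape_dtype[-1])

print(shapes)  # [(1, 128), (1, 128)]
print(dtype)   # ['i32', 'i32']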
    # Generates random inputs to be fed into the graph.
    def generate_random_inputs(self, low=0, high=1):
        shapes, dtype = self._input_info()
        inputs = []
        for i, j in zip(shapes, dtype):
            inputs.append(
                np.random.uniform(low, high, size=i).astype(
                    dtype_to_np_dtype[j]
                )
            )
        return tuple(inputs)
    @benchmark_mode
    def benchmark_all(self, inputs):
        self.shark_runner.benchmark_all(inputs)

    # TODO: Instead of passing a directory and having names decided by the
    # module, the user may want to save the module with manual names.
    def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
        return export_iree_module_to_vmfb(
            self.mlir_module,
            self.device,
            dir,
            self.mlir_dialect,
            self.function_name,
            module_name=module_name,
            extra_args=extra_args,
        )
    @benchmark_mode
    def benchmark_frontend(self, inputs):
        self.shark_runner.benchmark_frontend(inputs)

    # Load and return the module.
    def load_module(self, path):
        self.shark_runner = SharkRunner(
            function_name=self.function_name,
            device=self.device,
            compile_vmfb=False,
        )
        (
            self.shark_runner.iree_compilation_module,
            self.shark_runner.iree_config,
        ) = load_flatbuffer(
            path,
            self.device,
            self.function_name,
        )
        return
    @benchmark_mode
    def benchmark_python(self, inputs):
        self.shark_runner.benchmark_python(inputs)

    @benchmark_mode
    def benchmark_c(self):
        self.shark_runner.benchmark_c()

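Pulling the pieces above together, a typical SharkInference flow looks roughly like this (a hedged sketch: the module file, device choice, input shape, and saved-artifact name are placeholder assumptions, not values from this diff):

import numpy as np
from shark.shark_inference import SharkInference

# Placeholder: an already-lowered linalg-on-tensors module produced elsewhere
# (e.g. by torch-mlir); not generated by this diff.
mlir_module = open("model.linalg.mlir", "rb").read()

shark_module = SharkInference(
    mlir_module,
    function_name="forward",
    device="cpu",
    mlir_dialect="linalg",
)
shark_module.compile()

inputs = (np.random.uniform(size=(1, 128)).astype(np.float32),)
result = shark_module.forward(inputs)

# Optionally persist the compiled artifact and reload it without recompiling.
shark_module.save_module(".")
shark_module.load_module("forward.vmfb")  # hypothetical file name; the module decides it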
@@ -11,90 +11,195 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from iree.compiler import tf as tfc
import iree.compiler.tflite as ireec_tflite
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch_mlir.eager_mode import torch_mlir_tensor
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend

from shark.iree_utils.compile_utils import (
    get_iree_compiled_module,
    get_results,
    export_iree_module_to_vmfb,
    load_flatbuffer,
)
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.parser import shark_args
from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
from shark.torch_mlir_utils import get_torch_mlir_module, run_on_refbackend
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb, export_module_to_mlir_file, build_benchmark_args, run_benchmark_module
import os
import sys


# supported dialects by the shark-runtime.
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite", "tm_tensor"}
from shark.parser import shark_args
from tqdm import tqdm
import time


class SharkRunner:
    """
    Base class for SharkInference and SharkTrainer,
    used to execute an mlir_module.

    ...

    Attributes
    ----------
    mlir_module : str
        mlir_module represented as a string.
    function_name : str
        function to execute in the given mlir_module.
    device : str
        device to execute the mlir_module on.
        currently supports cpu, cuda, vulkan, and metal backends.
    mlir_dialect: str
        The dialect the given mlir_module is in.
        Refer to {https://mlir.llvm.org/docs/Dialects/}

    Methods
    -------
    run(inputs=None):
        Runs the mlir_module with the given inputs; if no inputs are
        given, it autogenerates them. The inputs should be numpy arrays.
    input_info():
        Gives the information about the inputs required by the `function_name`.
        This can be expensive as it does string matching to do so.
    """
    """Base class for Shark Inference and Shark Runner."""

    def __init__(
        self,
        mlir_module: bytes = None,
        function_name: str = "forward",
        device: str = "none",
        mlir_dialect: str = "linalg",
        extra_args: list = [],
        compile_vmfb: bool = True,
        model,
        input: tuple,
        dynamic: bool = False,
        device: str = None,
        jit_trace: bool = False,
        from_aot: bool = False,
        frontend: str = "torch",
        model_config_path: str = None,
    ):
        self.mlir_module = mlir_module
        self.function_name = function_name
        self.device = shark_args.device if device == "none" else device
        self.mlir_dialect = mlir_dialect
        self.extra_args = extra_args
        self.model = model
        self.frontend_model = model
        self.from_aot = from_aot
        self.input = input
        self.frontend = frontend
        self.vmfb_file = None
        func_name = "forward"
        self.device = device if device is not None else shark_args.device
        if self.frontend in ["pytorch", "torch"]:
            # get torch-mlir dialect
            # self.model = torch.Module
            # TODO assert
            self.model = get_torch_mlir_module(self.model, input, dynamic,
                                               jit_trace, from_aot)
        elif self.frontend in ["tensorflow", "tf"]:
            # get mhlo dialect
            # self.model = tf.Module
            # TODO assert
            self.model = tfc.compile_module(self.model,
                                            exported_names=[func_name],
                                            import_only=True)
        elif self.frontend in ["tflite"]:
            print("Setting up for IREE compiler tflite")
            # get tosa dialect
            # self.model = model.tflite
            # TODO assert
            self.model = ireec_tflite.compile_file(self.model,
                                                   input_type="tosa",
                                                   import_only=True)
            func_name = "main"

        if check_device_drivers(self.device):
            device_driver_info(self.device)
            sys.exit(1)

        if compile_vmfb == True:
            # Compile the module to get the .vmfb.
            (
                self.iree_compilation_module,
                self.iree_config,
            ) = get_iree_compiled_module(
                self.mlir_module,
                self.device,
                self.mlir_dialect,
                func_name=self.function_name,
                extra_args=self.extra_args,
            )

    def run(self, inputs: tuple):
        return get_results(
        # TODO: We can capture the .vmfb module here and later use it for saving
        # rather than recompiling it again, if used for saving.
        (
            self.iree_compilation_module,
            inputs,
            self.iree_config,
            self.mlir_dialect,
        ) = get_iree_compiled_module(self.model,
                                     self.device,
                                     self.frontend,
                                     func_name=func_name,
                                     model_config_path=model_config_path)

        # Debugging Options:
        if shark_args.save_mlir:
            export_module_to_mlir_file(self.model, self.frontend,
                                       shark_args.repro_dir)
        if shark_args.save_vmfb:
            self.vmfb_file = self.save_module(shark_args.repro_dir)

    # All the timings and benchmarking can be done here.
    def forward(self, input, frontend):
        return get_results(self.iree_compilation_module, input,
                           self.iree_config, frontend)

    # TODO: Instead of passing a directory and having names decided by the
    # module, the user may want to save the module with manual names.
    def save_module(self, dir=os.getcwd()):
        return export_iree_module_to_vmfb(self.model, self.device, dir,
                                          self.frontend)

    # TODO: Load a module and directly use it; we will need to set the frontend
    # in this case.
    def load_module(self, name):
        pass

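For orientation, the newer mlir_module-based constructor above is driven roughly like this (a sketch only; the module bytes and device string are assumptions, and run() mirrors what SharkInference.forward ultimately calls):

import numpy as np
from shark.shark_runner import SharkRunner

# Placeholder module text; in practice this comes from the torch-mlir / TF importers.
mlir_module = open("model.linalg.mlir", "rb").read()

runner = SharkRunner(
    mlir_module,
    function_name="forward",
    device="cpu",
    mlir_dialect="linalg",
)
output = runner.run((np.ones((1, 128), dtype=np.float32),))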
class SharkEagerMode:

    def __init__(self, device="cpu"):
        if device == "refbackend":
            torch_mlir_tensor.backend = EagerModeRefBackend()
        else:
            torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend(
                device)
        self.guard = enable_torch_dispatch_mode(TorchMLIRTensor)
        self.guard.__enter__()

    def __del__(self):
        self.guard.__exit__(None, None, None)

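SharkEagerMode is a hold-onto-it guard rather than a context manager; a hedged sketch of its intended use follows (the device string and tensor math are illustrative, and whether every op is intercepted depends on the torch dispatch-mode machinery):

import torch
from shark.shark_runner import SharkEagerMode  # import path assumed from this file's context

# Keep a reference: __init__ installs the dispatch-mode guard, __del__ tears it down.
eager = SharkEagerMode("cpu")

x = torch.randn(2, 3)
y = torch.randn(2, 3)
z = x + y  # intended to route through TorchMLIRTensor and the selected IREE backend

del eager  # restore the default torch dispatch behaviour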
class SharkBenchmarkRunner(SharkRunner):
    # SharkRunner derived class with Benchmarking capabilities.
    def __init__(
        self,
        model,
        input: tuple,
        dynamic: bool = False,
        device: str = None,
        jit_trace: bool = False,
        from_aot: bool = False,
        frontend: str = "torch",
    ):
        SharkRunner.__init__(self, model, input, dynamic, device, jit_trace,
                             from_aot, frontend)
        if (self.vmfb_file == None):
            self.vmfb_file = export_iree_module_to_vmfb(self.model, device,
                                                        shark_args.repro_dir,
                                                        frontend)
        self.benchmark_cl = build_benchmark_args(self.vmfb_file, device, input,
                                                 frontend, from_aot)

    def benchmark_frontend(self, inputs):
        if self.frontend in ["pytorch", "torch"]:
            self.benchmark_torch(inputs)
        elif self.frontend in ["tensorflow", "tf"]:
            self.benchmark_tf(inputs)

    def benchmark_torch(self, inputs):
        inputs = self.input if self.from_aot else inputs
        inputs = inputs[0]
        for i in range(shark_args.num_warmup_iterations):
            self.frontend_model.forward(inputs)

        begin = time.time()
        for i in range(shark_args.num_iterations):
            out = self.frontend_model.forward(inputs)
            if i == shark_args.num_iterations - 1:
                end = time.time()
                break
        print(
            f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
        )

    def benchmark_tf(self, inputs):
        for i in range(shark_args.num_warmup_iterations):
            self.frontend_model.forward(*inputs)

        begin = time.time()
        for i in range(shark_args.num_iterations):
            out = self.frontend_model.forward(*inputs)
            if i == shark_args.num_iterations - 1:
                end = time.time()
                break
        print(
            f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
        )
        return

    def benchmark_c(self):
        result = run_benchmark_module(self.benchmark_cl)
        print(f"Shark-{self.frontend} C-benchmark:{result} iter/second")

    def benchmark_python(self, inputs):
        inputs = self.input if self.from_aot else inputs
        input_list = [x for x in inputs]
        for i in range(shark_args.num_warmup_iterations):
            self.forward(input_list, self.frontend)

        begin = time.time()
        for i in range(shark_args.num_iterations):
            out = self.forward(input_list, self.frontend)
            if i == shark_args.num_iterations - 1:
                end = time.time()
        print(
            f"Shark-{self.frontend} Python-benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
        )

    def benchmark_all(self, inputs):
        self.benchmark_frontend(inputs)
        self.benchmark_python(inputs)
        self.benchmark_c()

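A rough sketch of how this benchmarking class is driven end to end (the toy model, input, and import path are assumptions; warmup and iteration counts come from shark_args):

import torch
from shark.shark_runner import SharkBenchmarkRunner  # import path assumed from this file's context

class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

example_input = (torch.randn(1, 16),)
benchmarker = SharkBenchmarkRunner(ToyModel(), example_input,
                                   dynamic=False, device="cpu",
                                   frontend="torch")
# Times the native frontend, the SHARK Python path, and iree-benchmark-module in turn.
benchmarker.benchmark_all(example_input)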
@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from shark.torch_mlir_utils import get_torch_mlir_module, run_on_refbackend
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
import os
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
from shark.backward_makefx import MakeFxModule
import numpy as np
from tqdm import tqdm
import time
import sys


@@ -54,13 +58,7 @@ class SharkTrainer:
    # Sets the frontend, i.e. `pytorch` or `tensorflow`.
    def set_frontend(self, frontend: str):
        if frontend not in [
            "pytorch",
            "torch",
            "tensorflow",
            "tf",
            "mhlo",
            "linalg",
            "tosa",
            "pytorch", "torch", "tensorflow", "tf", "mhlo", "linalg", "tosa"
        ]:
            print_err("frontend not supported.")
        else:
@@ -69,32 +67,22 @@ class SharkTrainer:
    # A training function is needed in the case of torch_fn.
    def compile(self, training_fn=None):
        if self.frontend in ["torch", "pytorch"]:
            aot_module = MakeFxModule(
                self.model, tuple(self.input), custom_inference_fn=training_fn
            )
            aot_module = MakeFxModule(self.model,
                                      tuple(self.input),
                                      custom_inference_fn=training_fn)
            aot_module.generate_graph()
            # Returns the backward graph.
            training_graph = aot_module.training_graph
            weights = self.get_torch_params()
            self.shark_runner = SharkRunner(
                training_graph,
                weights + self.input,
                self.dynamic,
                self.device,
                self.jit_trace,
                self.from_aot,
                self.frontend,
            )
            self.shark_runner = SharkRunner(training_graph,
                                            weights + self.input, self.dynamic,
                                            self.device, self.jit_trace,
                                            self.from_aot, self.frontend)
        elif self.frontend in ["tensorflow", "tf", "mhlo"]:
            self.shark_runner = SharkRunner(
                self.model,
                self.input,
                self.dynamic,
                self.device,
                self.jit_trace,
                self.from_aot,
                self.frontend,
            )
            self.shark_runner = SharkRunner(self.model, self.input,
                                            self.dynamic, self.device,
                                            self.jit_trace, self.from_aot,
                                            self.frontend)
        else:
            print_err("Unknown frontend")
            return
@@ -112,9 +100,8 @@ class SharkTrainer:
        params = [x.numpy() for x in params]
        print(f"Training started for {num_iters} iterations:")
        for i in tqdm(range(num_iters)):
            params = self.shark_runner.forward(
                params + self.input, self.frontend
            )
            params = self.shark_runner.forward(params + self.input,
                                               self.frontend)

        return params

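The loop above threads the flattened parameter list through the compiled training graph, which hands back the updated parameters each step. A minimal stand-in for that contract (pure numpy, illustrative only, not SHARK code):

import numpy as np

def training_step(weights, x):
    # Stand-in for shark_runner.forward(params + inputs, frontend):
    # consumes parameters plus inputs, returns the updated parameters.
    fake_grad = x.mean() * np.ones_like(weights)
    return [weights - 0.01 * fake_grad]

params = [np.zeros(4, dtype=np.float32)]
inputs = [np.random.rand(4).astype(np.float32)]
for _ in range(3):
    params = training_step(params[0], inputs[0])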
@@ -124,15 +111,15 @@
    def _train_tf(self, num_iters):
        input_list = []
        for x in self.input:
            if isinstance(x, list):
            if (isinstance(x, list)):
                nested_list = []
                for val in x:
                    if isinstance(val, np.ndarray):
                    if (isinstance(val, np.ndarray)):
                        nested_list.append(val)
                    else:
                        nested_list.append(val.numpy())
                input_list.append(nested_list)
            elif isinstance(x, np.ndarray):
            elif (isinstance(x, np.ndarray)):
                input_list.append(x)
            else:
                input_list.append(x.numpy())

@@ -1,11 +0,0 @@
1. Install torchdynamo
  - `git clone https://github.com/pytorch/torchdynamo.git`
  - `cd torchdynamo`
  - `python -m pip install -r requirements.txt`
  - `python setup.py develop`

2. Install functorch
  - `python -m pip install -v "git+https://github.com/pytorch/pytorch.git@$(python -c "import torch.version; print(torch.version.git_version)")#subdirectory=functorch"`

3. Run examples (a minimal sketch of such an example follows below).
  - `python shark/examples/shark_dynamo/basic_examples.py`
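A minimal basic example in the spirit of the steps above might look like this (hedged: it uses the standalone torchdynamo package's optimize decorator with a pass-through backend, not SHARK's actual dynamo integration):

import torch
import torchdynamo

def passthrough_backend(gm: torch.fx.GraphModule, example_inputs):
    # A real integration would hand the captured FX graph to torch-mlir / IREE here.
    return gm.forward

@torchdynamo.optimize(passthrough_backend)
def toy_fn(x, y):
    return torch.relu(x + y)

print(toy_fn(torch.randn(4), torch.randn(4)))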
Some files were not shown because too many files have changed in this diff.