mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-11 14:58:11 -05:00
Compare commits
186 Commits
github-pag
...
20221213.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48ec11c514 | ||
|
|
8ae76d18b5 | ||
|
|
e5be1790e5 | ||
|
|
e64aa40b17 | ||
|
|
eb8114ece8 | ||
|
|
616ee9b824 | ||
|
|
57c94f8f80 | ||
|
|
2a59c4f670 | ||
|
|
192ff487c4 | ||
|
|
b62ee3fcb9 | ||
|
|
0225292a44 | ||
|
|
589a7ed02f | ||
|
|
b3a42cd0b1 | ||
|
|
e3e1ca7cc6 | ||
|
|
57e417d174 | ||
|
|
1699db79b5 | ||
|
|
dab9403b8f | ||
|
|
9a14298146 | ||
|
|
40eea21863 | ||
|
|
d2475ec169 | ||
|
|
b3bcf4bf44 | ||
|
|
6049f86bc4 | ||
|
|
ff649b52ef | ||
|
|
e9e138c757 | ||
|
|
1096936a15 | ||
|
|
29cc478525 | ||
|
|
05e9eb40b5 | ||
|
|
c4444ff695 | ||
|
|
27b34f3929 | ||
|
|
2b8d784660 | ||
|
|
18f447d8d8 | ||
|
|
d7e1078d68 | ||
|
|
6be592653f | ||
|
|
8859853b41 | ||
|
|
3c46021102 | ||
|
|
bba8646669 | ||
|
|
b0dc19a910 | ||
|
|
df79ebd0f2 | ||
|
|
e19a97f316 | ||
|
|
482ffd6275 | ||
|
|
5117e50602 | ||
|
|
83b138208d | ||
|
|
1870cb4557 | ||
|
|
42ad5b9c5c | ||
|
|
333975eb8f | ||
|
|
aa0195e4ef | ||
|
|
56109fe09b | ||
|
|
e74046478b | ||
|
|
aa5a60812f | ||
|
|
ebb60019aa | ||
|
|
6393dc5d14 | ||
|
|
8c158f2452 | ||
|
|
8c3eabdcee | ||
|
|
8aa0ce6a24 | ||
|
|
a27ee141b3 | ||
|
|
1106456651 | ||
|
|
8856878cbd | ||
|
|
a9bac0287d | ||
|
|
efbd3dc778 | ||
|
|
a0d0eaa408 | ||
|
|
e2bf734b67 | ||
|
|
a333a90441 | ||
|
|
6dc0057d3d | ||
|
|
0f9e69d48c | ||
|
|
e6a7c019ab | ||
|
|
1d32eabd14 | ||
|
|
53d03f06a6 | ||
|
|
a2d8c40455 | ||
|
|
4f7d950c8d | ||
|
|
cac54b8c26 | ||
|
|
cd0e881d7d | ||
|
|
fee406e220 | ||
|
|
128342f47f | ||
|
|
024487c5fe | ||
|
|
879ba27ccb | ||
|
|
6d6d9627e7 | ||
|
|
af4bc82543 | ||
|
|
439a18bcc3 | ||
|
|
e12a1e0444 | ||
|
|
4400b0d3c3 | ||
|
|
5dff28ff99 | ||
|
|
d5ac841a1a | ||
|
|
232ce12e9b | ||
|
|
9a8638a6d0 | ||
|
|
a5445866b8 | ||
|
|
e8ded71a7b | ||
|
|
a14c615def | ||
|
|
3903b6ff0c | ||
|
|
41bf262482 | ||
|
|
645b658da0 | ||
|
|
6ee8f61fbe | ||
|
|
3c4c4231ce | ||
|
|
d0eef19eba | ||
|
|
6ca2eb3ad7 | ||
|
|
74aeb55733 | ||
|
|
3eb7965ca0 | ||
|
|
04f20070d1 | ||
|
|
88937fcb2f | ||
|
|
f80b85f10c | ||
|
|
32a2ec432d | ||
|
|
f4821d0d39 | ||
|
|
fdf2aa54ef | ||
|
|
275c032264 | ||
|
|
d88979fe19 | ||
|
|
e67bcffea7 | ||
|
|
005ded3c6f | ||
|
|
d624940e12 | ||
|
|
7763403b0e | ||
|
|
88c58244b9 | ||
|
|
0754c6ea20 | ||
|
|
7b1f04d121 | ||
|
|
d8a9bee244 | ||
|
|
ac0ea6bd3c | ||
|
|
45677c1e23 | ||
|
|
d9f4a9954a | ||
|
|
ec461a4456 | ||
|
|
559928e93b | ||
|
|
a526f7d5b8 | ||
|
|
749a2c2dec | ||
|
|
29a317dbb6 | ||
|
|
2f36de319a | ||
|
|
2005bce419 | ||
|
|
8a02d7729d | ||
|
|
1cdf301c14 | ||
|
|
9a86e5c476 | ||
|
|
32d3f4bd5f | ||
|
|
18689afc1a | ||
|
|
64d6da75c7 | ||
|
|
1e95e4b502 | ||
|
|
c63009a6db | ||
|
|
88f8718635 | ||
|
|
a081733a42 | ||
|
|
06ccfb0533 | ||
|
|
b18d75e3f7 | ||
|
|
3e7efaa048 | ||
|
|
a3fdfc81db | ||
|
|
f4c91df1df | ||
|
|
32e1ba8c0d | ||
|
|
1939376d72 | ||
|
|
25931d48a3 | ||
|
|
024c5e153a | ||
|
|
83f34b645d | ||
|
|
3f9f450e0d | ||
|
|
fd89b06641 | ||
|
|
f8dc996004 | ||
|
|
e6a964088b | ||
|
|
e3e767c7eb | ||
|
|
239c19eb12 | ||
|
|
7f37599a60 | ||
|
|
77c9a2c5ea | ||
|
|
fd7baae548 | ||
|
|
01fdf5ee16 | ||
|
|
e52f533c16 | ||
|
|
fbd77dc936 | ||
|
|
cdc6dd19e3 | ||
|
|
fd578a48a9 | ||
|
|
9956099516 | ||
|
|
f97b8fffed | ||
|
|
7b9e309724 | ||
|
|
1d33913d48 | ||
|
|
a48eaaed20 | ||
|
|
2741b8be53 | ||
|
|
4f906a265c | ||
|
|
0dff8d7af0 | ||
|
|
4f0d0d8167 | ||
|
|
d513060b21 | ||
|
|
d1a25ce4f3 | ||
|
|
51c98695b2 | ||
|
|
b448770ec2 | ||
|
|
5fe22a7980 | ||
|
|
38ae6b5af4 | ||
|
|
0bfe30d75d | ||
|
|
7be1d7d0be | ||
|
|
0d74c873f0 | ||
|
|
139aff2938 | ||
|
|
a3f733490c | ||
|
|
8a11f138d1 | ||
|
|
3405607917 | ||
|
|
7c99a6bd33 | ||
|
|
3fba8ce0e6 | ||
|
|
f3bde3c7fc | ||
|
|
21fee8ef33 | ||
|
|
0e217d6180 | ||
|
|
00a8ce75d1 | ||
|
|
8f3f00cd99 | ||
|
|
13bae2538a |
2
.github/workflows/gh-pages-releases.yml
vendored
2
.github/workflows/gh-pages-releases.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- run: git fetch --all
|
||||
- run: git switch github-pages
|
||||
- run: git config --global user.email "none@none.com"
|
||||
- run: git config --global user.name "nod-team"
|
||||
- run: git config --global user.name "nod-ai"
|
||||
- run: mv /tmp/index.html package-index/index.html
|
||||
- run: git add package-index/index.html
|
||||
|
||||
|
||||
137
.github/workflows/nightly.yml
vendored
137
.github/workflows/nightly.yml
vendored
@@ -9,7 +9,80 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
windows-build:
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Compute version
|
||||
shell: powershell
|
||||
run: |
|
||||
$package_version = $(Get-Date -UFormat "%Y%m%d")+"."+${{ github.run_number }}
|
||||
$package_version_ = $(Get-Date -UFormat "%Y%m%d")+"_"+${{ github.run_number }}
|
||||
$tag_name=$package_version
|
||||
echo "package_version=$package_version" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
echo "package_version_=$package_version_" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
echo "tag_name=$tag_name" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ env.tag_name }}
|
||||
release_name: nod.ai SHARK ${{ env.tag_name }}
|
||||
body: |
|
||||
Automatic snapshot release of nod.ai SHARK.
|
||||
draft: true
|
||||
prerelease: false
|
||||
|
||||
- name: Build Package
|
||||
shell: powershell
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
pyinstaller web/shark_sd.spec
|
||||
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
|
||||
|
||||
# GHA windows VM OOMs so disable for now
|
||||
#- name: Build and validate the SHARK Runtime package
|
||||
# shell: powershell
|
||||
# run: |
|
||||
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
- uses: actions/upload-artifact@v2
|
||||
with:
|
||||
path: dist/*
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ./dist/*
|
||||
|
||||
- name: Publish Release
|
||||
id: publish_release
|
||||
uses: eregon/publish-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
|
||||
linux-build:
|
||||
|
||||
runs-on: a100
|
||||
strategy:
|
||||
@@ -32,40 +105,12 @@ jobs:
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Compute version
|
||||
run: |
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
tag_name="${package_version}"
|
||||
echo "package_version=${package_version}" >> $GITHUB_ENV
|
||||
echo "tag_name=${tag_name}" >> $GITHUB_ENV
|
||||
- name: Set Environment Variables
|
||||
run: |
|
||||
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ env.tag_name }}
|
||||
release_name: nod.ai SHARK ${{ env.tag_name }}
|
||||
body: |
|
||||
Automatic snapshot release of nod.ai SHARK.
|
||||
draft: true
|
||||
prerelease: false
|
||||
- name: Find Torch-MLIR Release
|
||||
run: |
|
||||
TM_HTML_URL="$(python3 -c "import urllib.request, json, sys; u=json.loads(urllib.request.urlopen('https://api.github.com/repos/llvm/torch-mlir/releases/latest').read().decode()).get('html_url', False); print(u) if u else sys.exit(1);")"
|
||||
TM_RELEASE_DIR=${TM_HTML_URL/"tag"/"expanded_assets"}
|
||||
echo "TM_RELEASE_DIR=${TM_RELEASE_DIR}" >> $GITHUB_ENV
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
echo "Torch-MLIR Release DIR is ${{ env.TM_RELEASE_DIR }}"
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/nod-ai/SHARK-Runtime/releases; fi
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
@@ -74,13 +119,14 @@ jobs:
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
|
||||
- name: Build and validate the IREE package
|
||||
if: ${{ matrix.backend == 'IREE' }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
|
||||
source iree.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/iree-org/iree/releases
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://iree-org.github.io/iree/pip-release-links.html
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
@@ -91,8 +137,8 @@ jobs:
|
||||
if !(grep -Fxq " failed" pytest_results.txt)
|
||||
then
|
||||
export SHA=$(git log -1 --format='%h')
|
||||
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/$SHA
|
||||
gsutil -m cp -r gs://shark_tank/$SHA/* gs://shark_tank/latest/
|
||||
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
|
||||
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/latest/
|
||||
fi
|
||||
rm -rf ./wheelhouse/nodai*
|
||||
|
||||
@@ -104,29 +150,10 @@ jobs:
|
||||
source shark.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/nod-ai/SHARK-Runtime/releases
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" tank/test_models.py |
|
||||
pytest --ci --ci_sha=${SHORT_SHA} tank/test_models.py |
|
||||
tail -n 1 |
|
||||
tee -a pytest_results.txt
|
||||
|
||||
- name: Upload Release Assets
|
||||
if: ${{ matrix.backend == 'SHARK' }}
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ${GITHUB_WORKSPACE}/wheelhouse/nodai_*.whl
|
||||
|
||||
- name: Publish Release
|
||||
if: ${{ matrix.backend == 'SHARK' }}
|
||||
id: publish_release
|
||||
uses: eregon/publish-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
|
||||
37
.github/workflows/test-models.yml
vendored
37
.github/workflows/test-models.yml
vendored
@@ -10,6 +10,14 @@ on:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
|
||||
# Ensure that only a single job or workflow using the same
|
||||
# concurrency group will run at a time. This would cancel
|
||||
# any in-progress jobs in the same github workflow and github
|
||||
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-validate:
|
||||
strategy:
|
||||
@@ -32,8 +40,6 @@ jobs:
|
||||
suite: cuda
|
||||
- os: MacStudio
|
||||
suite: cpu
|
||||
- os: MacStudio
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
@@ -90,7 +96,7 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cpu
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cpu --update_tank
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
|
||||
|
||||
@@ -100,14 +106,29 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cuda
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cuda --update_tank
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
|
||||
|
||||
- name: Validate Vulkan Models
|
||||
if: matrix.suite == 'vulkan'
|
||||
- name: Validate Vulkan Models (MacOS)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k vulkan
|
||||
echo "VULKAN SDK PATH wo setup: $VULKAN_SDK"
|
||||
cd /Users/anush/VulkanSDK/1.3.224.1/
|
||||
source setup-env.sh
|
||||
cd $GITHUB_WORKSPACE
|
||||
echo "VULKAN SDK PATH with setup: $VULKAN_SDK"
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os != 'MacStudio'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -31,7 +31,6 @@ MANIFEST
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
@@ -163,7 +162,13 @@ cython_debug/
|
||||
# Shark related artefacts
|
||||
*venv/
|
||||
shark_tmp/
|
||||
*.vmfb
|
||||
.use-iree
|
||||
|
||||
# ORT related artefacts
|
||||
cache_models/
|
||||
onnx_models/
|
||||
|
||||
#web logging
|
||||
web/logs/
|
||||
web/stored_results/stable_diffusion/
|
||||
|
||||
412
README.md
412
README.md
@@ -5,25 +5,123 @@ High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerator
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
|
||||
|
||||
## Communication Channels
|
||||
|
||||
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
|
||||
## Installation (Windows, Linux and macOS)
|
||||
|
||||
## Check out the code
|
||||
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
cd SHARK
|
||||
```
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
|
||||
|
||||
* Install Git for Windows from [here](https://git-scm.com/download/win)
|
||||
|
||||
#### Allow the install script to run in Powershell
|
||||
```powershell
|
||||
set-executionpolicy remotesigned
|
||||
```
|
||||
|
||||
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
|
||||
```powershell
|
||||
./setup_venv.ps1 #You can re-run this script to get the latest version
|
||||
```
|
||||
|
||||
### Linux / macOS Users
|
||||
|
||||
```shell
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
```
|
||||
|
||||
|
||||
## Installation
|
||||
### Run Stable Diffusion on your device - WebUI
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\Users\nod\SHARK> cd web
|
||||
(shark.venv) PS C:\Users\nod\SHARK\web> python index.py
|
||||
```
|
||||
#### Linux Users
|
||||
```shell
|
||||
(shark.venv) > cd web
|
||||
(shark.venv) > python index.py
|
||||
```
|
||||
|
||||
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
|
||||
|
||||
|
||||
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
|
||||
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - Commandline
|
||||
|
||||
#### Install your hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mril-iree)
|
||||
* [macOS Users] Download and install the latest Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home)
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
Other users please ensure you have your latest vendor drivers and Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home) and if you are using vulkan check `vulkaninfo` works in a terminal window
|
||||
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> python .\shark\examples\shark_inference\stable_diffusion\main.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
```
|
||||
|
||||
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
|
||||
|
||||
The output on a 6900XT would like:
|
||||
|
||||
```shell
|
||||
44it [00:08, 5.14it/s]i = 44 t = 120 (191ms)
|
||||
45it [00:08, 5.15it/s]i = 45 t = 100 (191ms)
|
||||
46it [00:08, 5.16it/s]i = 46 t = 80 (191ms)
|
||||
47it [00:09, 5.16it/s]i = 47 t = 60 (193ms)
|
||||
48it [00:09, 5.15it/s]i = 48 t = 40 (195ms)
|
||||
49it [00:09, 5.12it/s]i = 49 t = 20 (196ms)
|
||||
50it [00:09, 5.14it/s]
|
||||
Average step time: 192.8154182434082ms/it
|
||||
Total image generation runtime (s): 10.390909433364868
|
||||
(shark.venv) PS C:\g\shark>
|
||||
```
|
||||
|
||||
Here are some samples generated:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
|
||||
|
||||
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
|
||||
|
||||
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Installation (Linux and macOS)</summary>
|
||||
<summary>Binary Installation</summary>
|
||||
|
||||
### Setup a new pip Virtual Environment
|
||||
|
||||
This step sets up a new VirtualEnv for Python
|
||||
|
||||
```shell
|
||||
python --version #Check you have 3.7->3.10 on Linux or 3.10 on macOS
|
||||
python --version #Check you have 3.10 on Linux, macOS or Windows Powershell
|
||||
python -m venv shark_venv
|
||||
source shark_venv/bin/activate
|
||||
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
|
||||
|
||||
# If you are using conda create and activate a new conda env
|
||||
|
||||
@@ -38,9 +136,14 @@ python -m pip install --upgrade pip
|
||||
This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10
|
||||
|
||||
```shell
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://github.com/nod-ai/shark-runtime/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
```
|
||||
If you are on an Intel macOS machine you need this [workaround](https://github.com/nod-ai/SHARK/issues/102) for an upstream issue.
|
||||
|
||||
### Run shark tank model tests.
|
||||
```shell
|
||||
pytest tank/test_models.py
|
||||
```
|
||||
See tank/README.md for a more detailed walkthrough of our pytest suite and CLI.
|
||||
|
||||
### Download and run Resnet50 sample
|
||||
|
||||
@@ -61,29 +164,27 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Source Installation</summary>
|
||||
<summary>Development, Testing and Benchmarks</summary>
|
||||
|
||||
## Check out the code
|
||||
If you want to use Python3.10 and with TF Import tools you can use the environment variables like:
|
||||
Set `USE_IREE=1` to use upstream IREE
|
||||
```
|
||||
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
|
||||
```
|
||||
|
||||
### Run any of the hundreds of SHARK tank models via the test framework
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
```
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
```shell
|
||||
# Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
```
|
||||
For example if you want to use Python3.10 and upstream IREE with TF Import tools you can use the environment variables like:
|
||||
```
|
||||
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 USE_IREE=1 ./setup_venv.sh
|
||||
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
# Or a pytest
|
||||
pytest tank/test_models.py -k "MiniLM"
|
||||
```
|
||||
|
||||
|
||||
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
|
||||
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
|
||||
with Python bindings and set your PYTHONPATH as mentioned [here](https://google.github.io/iree/bindings/python/)
|
||||
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
|
||||
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
|
||||
for Torch-MLIR.
|
||||
|
||||
@@ -102,82 +203,39 @@ for Torch-MLIR.
|
||||
```
|
||||
Now the SHARK will use your locally build Torch-MLIR repo.
|
||||
|
||||
### Run a demo script
|
||||
```shell
|
||||
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
# Or a pytest
|
||||
pytest tank/test_models.py -k "MiniLM"
|
||||
|
||||
## Benchmarking Dispatches
|
||||
|
||||
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your command line argument.
|
||||
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`
|
||||
|
||||
if you want to instead incorporate this into a python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `SharkInference`, and the benchmarks will be generated when compiled. E.G:
|
||||
|
||||
```
|
||||
shark_module = SharkInference(
|
||||
mlir_model,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
dispatch_benchmarks="all",
|
||||
dispatch_benchmarks_dir="results"
|
||||
)
|
||||
```
|
||||
|
||||
Output will include:
|
||||
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
|
||||
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
|
||||
- An .mlir file containing the dispatch benchmark
|
||||
- A compiled .vmfb file containing the dispatch benchmark
|
||||
- An .mlir file containing just the hal executable
|
||||
- A compiled .vmfb file of the hal executable
|
||||
- A .txt file containing benchmark output
|
||||
|
||||
|
||||
See tank/README.md for instructions on how to run model tests and benchmarks from the SHARK tank.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Testing and Benchmarks</summary>
|
||||
|
||||
### Run all model tests on CPU/GPU/VULKAN/Metal
|
||||
```shell
|
||||
pytest tank/test_models.py
|
||||
|
||||
# If on Linux for multithreading on CPU (faster results):
|
||||
pytest tank/test_models.py -n auto
|
||||
```
|
||||
|
||||
### Running specific tests
|
||||
```shell
|
||||
|
||||
# Search for test cases by including a keyword that matches all or part of the test case's name;
|
||||
pytest tank/test_models.py -k "keyword"
|
||||
|
||||
# Test cases are named uniformly by format test_module_<model_name_underscores_only>_<torch/tf>_<static/dynamic>_<device>.
|
||||
|
||||
# Example: Test all models on nvidia gpu:
|
||||
pytest tank/test_models.py -k "cuda"
|
||||
|
||||
# Example: Test all tensorflow resnet models on Vulkan backend:
|
||||
pytest tank/test_models.py -k "resnet and tf and vulkan"
|
||||
|
||||
# Exclude a test case:
|
||||
pytest tank/test_models.py -k "not ..."
|
||||
|
||||
### Run benchmarks on SHARK tank pytests and generate bench_results.csv with results.
|
||||
|
||||
(the following requires source installation with `IMPORTER=1 ./setup_venv.sh`)
|
||||
|
||||
```shell
|
||||
pytest --benchmark tank/test_models.py
|
||||
|
||||
# Just do static GPU benchmarks for PyTorch tests:
|
||||
pytest --benchmark tank/test_models.py -k "pytorch and static and cuda"
|
||||
|
||||
```
|
||||
|
||||
### Benchmark Resnet50, MiniLM on CPU
|
||||
|
||||
(requires source installation with `IMPORTER=1 ./setup_venv.sh`)
|
||||
|
||||
```shell
|
||||
# We suggest running the following commands as root before running benchmarks on CPU:
|
||||
|
||||
cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list | awk -F, '{print $2}' | sort -n | uniq | ( while read X ; do echo $X ; echo 0 > /sys/devices/system/cpu/cpu$X/online ; done )
|
||||
echo 1 > /sys/devices/system/cpu/intel_pstate/no_turbo
|
||||
|
||||
# Benchmark canonical Resnet50 on CPU via pytest
|
||||
pytest --benchmark tank/test_models -k "resnet50 and tf_static_cpu"
|
||||
|
||||
# Benchmark canonical MiniLM on CPU via pytest
|
||||
pytest --benchmark tank/test_models -k "MiniLM and cpu"
|
||||
|
||||
# Benchmark MiniLM on CPU via transformer-benchmarks:
|
||||
git clone --recursive https://github.com/nod-ai/transformer-benchmarks.git
|
||||
cd transformer-benchmarks
|
||||
./perf-ci.sh -n
|
||||
# Check detail.csv for MLIR/IREE results.
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>API Reference</summary>
|
||||
|
||||
@@ -228,160 +286,26 @@ result = shark_module.forward((arg0, arg1))
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
## Supported and Validated Models
|
||||
|
||||
<details>
|
||||
<summary>PyTorch Models</summary>
|
||||
SHARK is maintained to support the latest innovations in ML Models:
|
||||
|
||||
### Huggingface PyTorch Models
|
||||
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| GPT2 | :green_heart: | :green_heart: | :green_heart: |
|
||||
| BLOOM | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Stable Diffusion | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Vision Transformer | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ResNet50 | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Albert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| BigBird | :green_heart: (AOT) | | | |
|
||||
| DistilBERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| GPT2 | :broken_heart: (AOT) | | | |
|
||||
| MobileBert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
For a complete list of the models supported in SHARK, please refer to [tank/README.md](https://github.com/nod-ai/SHARK/blob/main/tank/README.md).
|
||||
|
||||
### Torchvision Models
|
||||
## Communication Channels
|
||||
|
||||
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|--------------------|----------------------|----------|----------|-------------|
|
||||
| AlexNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DenseNet121 | :green_heart: (Script) | | | |
|
||||
| MNasNet1_0 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| MobileNetV2 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| MobileNetV3 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Unet | :broken_heart: (Script) | | | |
|
||||
| Resnet18 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnet50 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnet101 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnext50_32x4d | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ShuffleNet_v2 | :broken_heart: (Script) | | | |
|
||||
| SqueezeNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| EfficientNet | :green_heart: (Script) | | | |
|
||||
| Regnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnest | :broken_heart: (Script) | | | |
|
||||
| Vision Transformer | :green_heart: (Script) | | | |
|
||||
| VGG 16 | :green_heart: (Script) | :green_heart: | :green_heart: | |
|
||||
| Wide Resnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| RAFT | :broken_heart: (JIT) | | | |
|
||||
|
||||
For more information refer to [MODEL TRACKING SHEET](https://docs.google.com/spreadsheets/d/15PcjKeHZIrB5LfDyuw7DGEEE8XnQEX2aX8lm8qbxV8A/edit#gid=0)
|
||||
|
||||
### PyTorch Training Models
|
||||
|
||||
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>JAX Models</summary>
|
||||
|
||||
|
||||
### JAX Models
|
||||
|
||||
| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| DALL-E | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>TFLite Models</summary>
|
||||
|
||||
### TFLite Models
|
||||
|
||||
| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
| albert | :green_heart: | :green_heart: | | |
|
||||
| asr_conformer | :green_heart: | :green_heart: | | |
|
||||
| bird_classifier | :green_heart: | :green_heart: | | |
|
||||
| cartoon_gan | :green_heart: | :green_heart: | | |
|
||||
| craft_text | :green_heart: | :green_heart: | | |
|
||||
| deeplab_v3 | :green_heart: | :green_heart: | | |
|
||||
| densenet | :green_heart: | :green_heart: | | |
|
||||
| east_text_detector | :green_heart: | :green_heart: | | |
|
||||
| efficientnet_lite0_int8 | :green_heart: | :green_heart: | | |
|
||||
| efficientnet | :green_heart: | :green_heart: | | |
|
||||
| gpt2 | :green_heart: | :green_heart: | | |
|
||||
| image_stylization | :green_heart: | :green_heart: | | |
|
||||
| inception_v4 | :green_heart: | :green_heart: | | |
|
||||
| inception_v4_uint8 | :green_heart: | :green_heart: | | |
|
||||
| lightning_fp16 | :green_heart: | :green_heart: | | |
|
||||
| lightning_i8 | :green_heart: | :green_heart: | | |
|
||||
| lightning | :green_heart: | :green_heart: | | |
|
||||
| magenta | :green_heart: | :green_heart: | | |
|
||||
| midas | :green_heart: | :green_heart: | | |
|
||||
| mirnet | :green_heart: | :green_heart: | | |
|
||||
| mnasnet | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_edgetpu_s_float | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_edgetpu_s_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilebert | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_tf2_float | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_tf2_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_ssd_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v1 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v3-large | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v3-large_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v35-int8 | :green_heart: | :green_heart: | | |
|
||||
| nasnet | :green_heart: | :green_heart: | | |
|
||||
| person_detect | :green_heart: | :green_heart: | | |
|
||||
| posenet | :green_heart: | :green_heart: | | |
|
||||
| resnet_50_int8 | :green_heart: | :green_heart: | | |
|
||||
| rosetta | :green_heart: | :green_heart: | | |
|
||||
| spice | :green_heart: | :green_heart: | | |
|
||||
| squeezenet | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v1 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_fpnlite | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_fpnlite_uint8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2 | :green_heart: | :green_heart: | | |
|
||||
| ssd_spaghettinet_large | :green_heart: | :green_heart: | | |
|
||||
| ssd_spaghettinet_large_uint8 | :green_heart: | :green_heart: | | |
|
||||
| visual_wake_words_i8 | :green_heart: | :green_heart: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>TF Models</summary>
|
||||
|
||||
### Tensorflow Models (Inference)
|
||||
|
||||
| Hugging Face Models | tf-mhlo lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| albert-base-v2 | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| CamemBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ConvBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Deberta | | | | |
|
||||
| electra | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| funnel | | | | |
|
||||
| layoutlm | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| longformer | | | | |
|
||||
| mobile-bert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| remembert | | | | |
|
||||
| tapas | | | | |
|
||||
| flaubert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| xlm-roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| mpnet | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
</details>
|
||||
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
|
||||
|
||||
## Related Projects
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=tf_bert_input)
|
||||
@tf.function(input_signature=tf_bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
@@ -36,6 +36,12 @@ def pytest_addoption(parser):
|
||||
default="False",
|
||||
help="Enables uploading of reproduction artifacts upon test case failure during iree-compile or validation. Must be passed with --ci_sha option ",
|
||||
)
|
||||
parser.addoption(
|
||||
"--update_tank",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Update local shark tank with latest artifacts.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--ci_sha",
|
||||
action="store",
|
||||
|
||||
3
cpp/.gitignore
vendored
Normal file
3
cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
*.mlir
|
||||
*.vmfb
|
||||
*.ini
|
||||
@@ -54,5 +54,29 @@ python -m pip install tensorflow
|
||||
|
||||
*Run the vulkan_gui*
|
||||
```bash
|
||||
./build/vulkan_gui/iree-samples-vulkan-gui
|
||||
./build/vulkan_gui/iree-samples-resnet-vulkan-gui
|
||||
```
|
||||
|
||||
## Other models
|
||||
A tool for benchmarking other models is built and can be invoked with a command like the following
|
||||
```bash
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=path/to/.vmfb --function_input=...
|
||||
```
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
|
||||
|
||||
def load_and_preprocess_image(fname: str):
|
||||
|
||||
@@ -40,45 +40,77 @@ set(IMGUI_DIR ${CMAKE_BINARY_DIR}/_deps/imgui-src)
|
||||
message("Looking for Imgui in ${IMGUI_DIR}")
|
||||
include_directories(${IMGUI_DIR} ${IMGUI_DIR}/backends ..)
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "iree-samples-vulkan-gui")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
vulkan_inference_gui.cc
|
||||
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
|
||||
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
|
||||
"${IMGUI_DIR}/imgui.cpp"
|
||||
"${IMGUI_DIR}/imgui_draw.cpp"
|
||||
"${IMGUI_DIR}/imgui_demo.cpp"
|
||||
"${IMGUI_DIR}/imgui_tables.cpp"
|
||||
"${IMGUI_DIR}/imgui_widgets.cpp"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "iree-samples-vulkan-gui")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
SDL2::SDL2
|
||||
Vulkan::Vulkan
|
||||
iree_runtime_runtime
|
||||
iree_base_internal_main
|
||||
iree_hal_drivers_vulkan_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_vm_cc
|
||||
|
||||
function(iree_vulkan_sample)
|
||||
|
||||
cmake_parse_arguments(
|
||||
_RULE
|
||||
""
|
||||
"NAME"
|
||||
"SRCS"
|
||||
${ARGN}
|
||||
)
|
||||
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "${_RULE_NAME}")
|
||||
set(SRCS "${_RULE_SRCS}")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
${SRCS}
|
||||
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
|
||||
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
|
||||
"${IMGUI_DIR}/imgui.cpp"
|
||||
"${IMGUI_DIR}/imgui_draw.cpp"
|
||||
"${IMGUI_DIR}/imgui_demo.cpp"
|
||||
"${IMGUI_DIR}/imgui_tables.cpp"
|
||||
"${IMGUI_DIR}/imgui_widgets.cpp"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_NAME}")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
SDL2::SDL2
|
||||
Vulkan::Vulkan
|
||||
iree_runtime_runtime
|
||||
iree_base_internal_main
|
||||
iree_hal_drivers_vulkan_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_vm_cc
|
||||
iree_tooling_vm_util_cc
|
||||
iree_tooling_context_util
|
||||
)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
|
||||
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
|
||||
else()
|
||||
set(_GUI_LINKOPTS "")
|
||||
endif()
|
||||
|
||||
target_link_options(${_NAME}
|
||||
PRIVATE
|
||||
${_GUI_LINKOPTS}
|
||||
)
|
||||
endfunction()
|
||||
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-samples-resnet-vulkan-gui
|
||||
|
||||
SRCS
|
||||
vulkan_resnet_inference_gui.cc
|
||||
)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
|
||||
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
|
||||
else()
|
||||
set(_GUI_LINKOPTS "")
|
||||
endif()
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-vulkan-gui
|
||||
|
||||
target_link_options(${_NAME}
|
||||
PRIVATE
|
||||
${_GUI_LINKOPTS}
|
||||
SRCS
|
||||
vulkan_inference_gui.cc
|
||||
)
|
||||
|
||||
message(STATUS "Configured vulkan_gui sample successfully")
|
||||
|
||||
@@ -18,6 +18,12 @@
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <array>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "iree/hal/drivers/vulkan/api.h"
|
||||
|
||||
@@ -30,6 +36,15 @@
|
||||
#include "iree/vm/bytecode_module.h"
|
||||
#include "iree/vm/ref_cc.h"
|
||||
|
||||
// iree-run-module
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/base/status_cc.h"
|
||||
#include "iree/base/tracing.h"
|
||||
#include "iree/modules/hal/types.h"
|
||||
#include "iree/tooling/comparison.h"
|
||||
#include "iree/tooling/context_util.h"
|
||||
#include "iree/tooling/vm_util_cc.h"
|
||||
|
||||
// Other dependencies (helpers, etc.)
|
||||
#include "iree/base/internal/main.h"
|
||||
|
||||
@@ -38,6 +53,49 @@
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
IREE_FLAG(string, entry_function, "",
|
||||
"Name of a function contained in the module specified by module_file "
|
||||
"to run.");
|
||||
|
||||
// TODO(benvanik): move --function_input= flag into a util.
|
||||
static iree_status_t parse_function_io(iree_string_view_t flag_name,
|
||||
void* storage,
|
||||
iree_string_view_t value) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
list->push_back(std::string(value.data, value.size));
|
||||
return iree_ok_status();
|
||||
}
|
||||
static void print_function_io(iree_string_view_t flag_name, void* storage,
|
||||
FILE* file) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
if (list->empty()) {
|
||||
fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
|
||||
} else {
|
||||
for (size_t i = 0; i < list->size(); ++i) {
|
||||
fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
|
||||
list->at(i).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
static std::vector<std::string> FLAG_function_inputs;
|
||||
IREE_FLAG_CALLBACK(
|
||||
parse_function_io, print_function_io, &FLAG_function_inputs, function_input,
|
||||
"An input (a) value or (b) buffer of the format:\n"
|
||||
" (a) scalar value\n"
|
||||
" value\n"
|
||||
" e.g.: --function_input=\"3.14\"\n"
|
||||
" (b) buffer:\n"
|
||||
" [shape]xtype=[value]\n"
|
||||
" e.g.: --function_input=\"2x2xi32=1 2 3 4\"\n"
|
||||
"Optionally, brackets may be used to separate the element values:\n"
|
||||
" 2x2xi32=[[1 2][3 4]]\n"
|
||||
"Raw binary files can be read to provide buffer contents:\n"
|
||||
" 2x2xi32=@some/file.bin\n"
|
||||
"numpy npy files (from numpy.save) can be read to provide 1+ values:\n"
|
||||
" @some.npy\n"
|
||||
"Each occurrence of the flag indicates an input in the order they were\n"
|
||||
"specified on the command line.");
|
||||
|
||||
typedef struct iree_file_toc_t {
|
||||
const char* name; // the file's original name
|
||||
char* data; // beginning of the file
|
||||
@@ -87,225 +145,6 @@ static void check_vk_result(VkResult err) {
|
||||
abort();
|
||||
}
|
||||
|
||||
// Helper function to find Vulkan memory type bits. See ImGui_ImplVulkan_MemoryType() in imgui_impl_vulkan.cpp
|
||||
uint32_t findMemoryType(uint32_t type_filter, VkMemoryPropertyFlags properties)
|
||||
{
|
||||
VkPhysicalDeviceMemoryProperties mem_properties;
|
||||
vkGetPhysicalDeviceMemoryProperties(g_PhysicalDevice, &mem_properties);
|
||||
|
||||
for (uint32_t i = 0; i < mem_properties.memoryTypeCount; i++)
|
||||
{
|
||||
if ((type_filter & (1 << i)) && (mem_properties.memoryTypes[i].propertyFlags & properties) == properties)
|
||||
{
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
return 0xFFFFFFFF; // Unable to find memoryType
|
||||
}
|
||||
|
||||
// Helper function to load an image with common settings and return a VkDescriptorSet as a sort of Vulkan pointer
|
||||
bool LoadTextureFromFile(const char* filename, VkDescriptorSet* img_ds, int* image_width, int* image_height)
|
||||
{
|
||||
// Specifying 4 channels forces stb to load the image in RGBA which is an easy format for Vulkan
|
||||
int image_channels = 4;
|
||||
unsigned char* image_data = stbi_load(filename, image_width, image_height, 0, image_channels);
|
||||
|
||||
if (image_data == NULL)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Calculate allocation size (in number of bytes)
|
||||
size_t image_size = (*image_width)*(*image_height)*image_channels;
|
||||
|
||||
VkResult err;
|
||||
|
||||
// Create the Vulkan image.
|
||||
VkImage texture_image;
|
||||
VkDeviceMemory texture_image_memory;
|
||||
{
|
||||
VkImageCreateInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO;
|
||||
info.imageType = VK_IMAGE_TYPE_2D;
|
||||
info.format = VK_FORMAT_R8G8B8A8_UNORM;
|
||||
info.extent.width = *image_width;
|
||||
info.extent.height = *image_height;
|
||||
info.extent.depth = 1;
|
||||
info.mipLevels = 1;
|
||||
info.arrayLayers = 1;
|
||||
info.samples = VK_SAMPLE_COUNT_1_BIT;
|
||||
info.tiling = VK_IMAGE_TILING_OPTIMAL;
|
||||
info.usage = VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT;
|
||||
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
|
||||
info.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED;
|
||||
err = vkCreateImage(g_Device, &info, g_Allocator, &texture_image);
|
||||
check_vk_result(err);
|
||||
VkMemoryRequirements req;
|
||||
vkGetImageMemoryRequirements(g_Device, texture_image, &req);
|
||||
VkMemoryAllocateInfo alloc_info = {};
|
||||
alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
|
||||
alloc_info.allocationSize = req.size;
|
||||
alloc_info.memoryTypeIndex = findMemoryType(req.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
|
||||
err = vkAllocateMemory(g_Device, &alloc_info, g_Allocator, &texture_image_memory);
|
||||
check_vk_result(err);
|
||||
err = vkBindImageMemory(g_Device, texture_image, texture_image_memory, 0);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Create the Image View
|
||||
VkImageView image_view;
|
||||
{
|
||||
VkImageViewCreateInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO;
|
||||
info.image = texture_image;
|
||||
info.viewType = VK_IMAGE_VIEW_TYPE_2D;
|
||||
info.format = VK_FORMAT_R8G8B8A8_UNORM;
|
||||
info.subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
info.subresourceRange.levelCount = 1;
|
||||
info.subresourceRange.layerCount = 1;
|
||||
err = vkCreateImageView(g_Device, &info, g_Allocator, &image_view);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Create Sampler
|
||||
VkSampler sampler;
|
||||
{
|
||||
VkSamplerCreateInfo sampler_info{};
|
||||
sampler_info.sType = VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO;
|
||||
sampler_info.magFilter = VK_FILTER_LINEAR;
|
||||
sampler_info.minFilter = VK_FILTER_LINEAR;
|
||||
sampler_info.mipmapMode = VK_SAMPLER_MIPMAP_MODE_LINEAR;
|
||||
sampler_info.addressModeU = VK_SAMPLER_ADDRESS_MODE_REPEAT; // outside image bounds just use border color
|
||||
sampler_info.addressModeV = VK_SAMPLER_ADDRESS_MODE_REPEAT;
|
||||
sampler_info.addressModeW = VK_SAMPLER_ADDRESS_MODE_REPEAT;
|
||||
sampler_info.minLod = -1000;
|
||||
sampler_info.maxLod = 1000;
|
||||
sampler_info.maxAnisotropy = 1.0f;
|
||||
err = vkCreateSampler(g_Device, &sampler_info, g_Allocator, &sampler);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Create Descriptor Set using ImGUI's implementation
|
||||
*img_ds = ImGui_ImplVulkan_AddTexture(sampler, image_view, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL);
|
||||
|
||||
// Create Upload Buffer
|
||||
VkBuffer upload_buffer;
|
||||
VkDeviceMemory upload_buffer_memory;
|
||||
{
|
||||
VkBufferCreateInfo buffer_info = {};
|
||||
buffer_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
|
||||
buffer_info.size = image_size;
|
||||
buffer_info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
|
||||
buffer_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
|
||||
err = vkCreateBuffer(g_Device, &buffer_info, g_Allocator, &upload_buffer);
|
||||
check_vk_result(err);
|
||||
VkMemoryRequirements req;
|
||||
vkGetBufferMemoryRequirements(g_Device, upload_buffer, &req);
|
||||
VkMemoryAllocateInfo alloc_info = {};
|
||||
alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
|
||||
alloc_info.allocationSize = req.size;
|
||||
alloc_info.memoryTypeIndex = findMemoryType(req.memoryTypeBits, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
|
||||
err = vkAllocateMemory(g_Device, &alloc_info, g_Allocator, &upload_buffer_memory);
|
||||
check_vk_result(err);
|
||||
err = vkBindBufferMemory(g_Device, upload_buffer, upload_buffer_memory, 0);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Upload to Buffer:
|
||||
{
|
||||
void* map = NULL;
|
||||
err = vkMapMemory(g_Device, upload_buffer_memory, 0, image_size, 0, &map);
|
||||
check_vk_result(err);
|
||||
memcpy(map, image_data, image_size);
|
||||
VkMappedMemoryRange range[1] = {};
|
||||
range[0].sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
|
||||
range[0].memory = upload_buffer_memory;
|
||||
range[0].size = image_size;
|
||||
err = vkFlushMappedMemoryRanges(g_Device, 1, range);
|
||||
check_vk_result(err);
|
||||
vkUnmapMemory(g_Device, upload_buffer_memory);
|
||||
}
|
||||
|
||||
// Release image memory using stb
|
||||
stbi_image_free(image_data);
|
||||
|
||||
// Create a command buffer that will perform following steps when hit in the command queue.
|
||||
// TODO: this works in the example, but may need input if this is an acceptable way to access the pool/create the command buffer.
|
||||
VkCommandPool command_pool = g_MainWindowData.Frames[g_MainWindowData.FrameIndex].CommandPool;
|
||||
VkCommandBuffer command_buffer;
|
||||
{
|
||||
VkCommandBufferAllocateInfo alloc_info{};
|
||||
alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
|
||||
alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
|
||||
alloc_info.commandPool = command_pool;
|
||||
alloc_info.commandBufferCount = 1;
|
||||
|
||||
err = vkAllocateCommandBuffers(g_Device, &alloc_info, &command_buffer);
|
||||
check_vk_result(err);
|
||||
|
||||
VkCommandBufferBeginInfo begin_info = {};
|
||||
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
err = vkBeginCommandBuffer(command_buffer, &begin_info);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Copy to Image
|
||||
{
|
||||
VkImageMemoryBarrier copy_barrier[1] = {};
|
||||
copy_barrier[0].sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
|
||||
copy_barrier[0].dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
|
||||
copy_barrier[0].oldLayout = VK_IMAGE_LAYOUT_UNDEFINED;
|
||||
copy_barrier[0].newLayout = VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL;
|
||||
copy_barrier[0].srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
copy_barrier[0].dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
copy_barrier[0].image = texture_image;
|
||||
copy_barrier[0].subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
copy_barrier[0].subresourceRange.levelCount = 1;
|
||||
copy_barrier[0].subresourceRange.layerCount = 1;
|
||||
vkCmdPipelineBarrier(command_buffer, VK_PIPELINE_STAGE_HOST_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 0, NULL, 0, NULL, 1, copy_barrier);
|
||||
|
||||
VkBufferImageCopy region = {};
|
||||
region.imageSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
region.imageSubresource.layerCount = 1;
|
||||
region.imageExtent.width = *image_width;
|
||||
region.imageExtent.height = *image_height;
|
||||
region.imageExtent.depth = 1;
|
||||
vkCmdCopyBufferToImage(command_buffer, upload_buffer, texture_image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1, ®ion);
|
||||
|
||||
VkImageMemoryBarrier use_barrier[1] = {};
|
||||
use_barrier[0].sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
|
||||
use_barrier[0].srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
|
||||
use_barrier[0].dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
|
||||
use_barrier[0].oldLayout = VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL;
|
||||
use_barrier[0].newLayout = VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL;
|
||||
use_barrier[0].srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
use_barrier[0].dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
use_barrier[0].image = texture_image;
|
||||
use_barrier[0].subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
use_barrier[0].subresourceRange.levelCount = 1;
|
||||
use_barrier[0].subresourceRange.layerCount = 1;
|
||||
vkCmdPipelineBarrier(command_buffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, 0, 0, NULL, 0, NULL, 1, use_barrier);
|
||||
}
|
||||
|
||||
// End command buffer
|
||||
{
|
||||
VkSubmitInfo end_info = {};
|
||||
end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
end_info.commandBufferCount = 1;
|
||||
end_info.pCommandBuffers = &command_buffer;
|
||||
err = vkEndCommandBuffer(command_buffer);
|
||||
check_vk_result(err);
|
||||
err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
|
||||
check_vk_result(err);
|
||||
err = vkDeviceWaitIdle(g_Device);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan layers used for the given IREE
|
||||
// |extensibility_set| and |features|.
|
||||
std::vector<const char*> GetIreeLayers(
|
||||
@@ -723,7 +562,16 @@ namespace iree {
|
||||
|
||||
extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
fprintf(stdout, "starting yo\n");
|
||||
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
|
||||
if (argc > 1) {
|
||||
// Avoid iree-run-module spinning endlessly on stdin if the user uses single
|
||||
// dashes for flags.
|
||||
printf(
|
||||
"[ERROR] unexpected positional argument (expected none)."
|
||||
" Did you use pass a flag with a single dash ('-')?"
|
||||
" Use '--' instead.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Create a window.
|
||||
@@ -835,8 +683,6 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
// Demo state.
|
||||
bool show_iree_window = true;
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Setup IREE.
|
||||
|
||||
@@ -900,69 +746,44 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
|
||||
// Load bytecode module
|
||||
iree_file_toc_t module_file_toc;
|
||||
const char network_model[] = "resnet50_tf.vmfb";
|
||||
fprintf(stdout, "Loading: %s\n", network_model);
|
||||
if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
|
||||
{
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
fprintf(stdout, "module size: %zu\n", module_file_toc.size);
|
||||
|
||||
static float input_res50[224*224*3];
|
||||
static float output_res50[1000];
|
||||
|
||||
char filename[] = "dog_imagenet.jpg";
|
||||
fprintf(stdout, "loading: %s\n", filename);
|
||||
int x,y,n;
|
||||
//unsigned char *image_raw = stbi_load(filename, &x, &y, &n, 3);
|
||||
stbi_load(filename, &x, &y, &n, 3);
|
||||
fprintf(stdout, "res: %i x %i x %i\n", x, y, n);
|
||||
|
||||
/* Preprocessing needs to go here. For now use a buffer preprocessed in python.
|
||||
|
||||
//convert image into floating point format
|
||||
for(int i=0;i<224*224*3;i++)
|
||||
{
|
||||
input_res50[i]= ((float)image_raw[i])/255.0f;
|
||||
}*/
|
||||
|
||||
std::ifstream fin("dog.bin", std::ifstream::in | std::ifstream::binary);
|
||||
fin.read((char*)input_res50, 224*224*3*sizeof(float));
|
||||
|
||||
// load image again so imgui can display it
|
||||
int my_image_width = 0;
|
||||
int my_image_height = 0;
|
||||
VkDescriptorSet my_image_texture = 0;
|
||||
bool ret = LoadTextureFromFile(filename, &my_image_texture, &my_image_width, &my_image_height);
|
||||
fprintf(stdout, "creating vulkan image: %s\n", ret ?"OK":"FAIL");
|
||||
IM_ASSERT(ret);
|
||||
//iree_file_toc_t module_file_toc;
|
||||
//const char network_model[] = "resnet50_tf.vmfb";
|
||||
//fprintf(stdout, "Loading: %s\n", network_model);
|
||||
//if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
|
||||
//{
|
||||
// abort();
|
||||
// return 1;
|
||||
//}
|
||||
//fprintf(stdout, "module size: %zu\n", module_file_toc.size);
|
||||
|
||||
iree_vm_module_t* bytecode_module = nullptr;
|
||||
IREE_CHECK_OK(iree_vm_bytecode_module_create(
|
||||
iree_instance,
|
||||
iree_const_byte_span_t{
|
||||
reinterpret_cast<const uint8_t*>(module_file_toc.data),
|
||||
module_file_toc.size},
|
||||
iree_allocator_null(), iree_allocator_system(), &bytecode_module));
|
||||
// Query for details about what is in the loaded module.
|
||||
iree_vm_module_signature_t bytecode_module_signature =
|
||||
iree_vm_module_signature(bytecode_module);
|
||||
fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
|
||||
bytecode_module_signature.export_function_count);
|
||||
for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
|
||||
iree_vm_function_t function;
|
||||
IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
|
||||
bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
|
||||
auto function_name = iree_vm_function_name(&function);
|
||||
auto function_signature = iree_vm_function_signature(&function);
|
||||
iree_status_t module_status = iree_tooling_load_module_from_flags(
|
||||
iree_instance, iree_allocator_system(), &bytecode_module);
|
||||
if (!iree_status_is_ok(module_status))
|
||||
return -1;
|
||||
//IREE_CHECK_OK(iree_vm_bytecode_module_create(
|
||||
// iree_instance,
|
||||
// iree_const_byte_span_t{
|
||||
// reinterpret_cast<const uint8_t*>(module_file_toc.data),
|
||||
// module_file_toc.size},
|
||||
// iree_allocator_null(), iree_allocator_system(), &bytecode_module));
|
||||
//// Query for details about what is in the loaded module.
|
||||
//iree_vm_module_signature_t bytecode_module_signature =
|
||||
// iree_vm_module_signature(bytecode_module);
|
||||
//fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
|
||||
// bytecode_module_signature.export_function_count);
|
||||
//for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
|
||||
// iree_vm_function_t function;
|
||||
// IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
|
||||
// bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
|
||||
// auto function_name = iree_vm_function_name(&function);
|
||||
// auto function_signature = iree_vm_function_signature(&function);
|
||||
|
||||
fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
|
||||
(int)function_name.size, function_name.data,
|
||||
(int)function_signature.calling_convention.size,
|
||||
function_signature.calling_convention.data);
|
||||
}
|
||||
// fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
|
||||
// (int)function_name.size, function_name.data,
|
||||
// (int)function_signature.calling_convention.size,
|
||||
// function_signature.calling_convention.data);
|
||||
//}
|
||||
|
||||
// Allocate a context that will hold the module state across invocations.
|
||||
iree_vm_context_t* iree_context = nullptr;
|
||||
@@ -988,33 +809,42 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
// Write inputs into mappable buffers.
|
||||
iree_hal_allocator_t* allocator =
|
||||
iree_hal_device_allocator(iree_vk_device);
|
||||
iree_hal_memory_type_t input_memory_type =
|
||||
static_cast<iree_hal_memory_type_t>(
|
||||
IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
|
||||
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
|
||||
iree_hal_buffer_usage_t input_buffer_usage =
|
||||
static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
|
||||
iree_hal_buffer_params_t buffer_params;
|
||||
buffer_params.type = input_memory_type;
|
||||
buffer_params.usage = input_buffer_usage;
|
||||
buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
|
||||
//iree_hal_memory_type_t input_memory_type =
|
||||
// static_cast<iree_hal_memory_type_t>(
|
||||
// IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
|
||||
// IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
|
||||
//iree_hal_buffer_usage_t input_buffer_usage =
|
||||
// static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
|
||||
//iree_hal_buffer_params_t buffer_params;
|
||||
//buffer_params.type = input_memory_type;
|
||||
//buffer_params.usage = input_buffer_usage;
|
||||
//buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
|
||||
|
||||
// Wrap input buffers in buffer views.
|
||||
|
||||
iree_hal_buffer_view_t* input0_buffer_view = nullptr;
|
||||
constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
|
||||
IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
|
||||
allocator,
|
||||
/*shape_rank=*/4, /*shape=*/input_buffer_shape,
|
||||
IREE_HAL_ELEMENT_TYPE_FLOAT_32,
|
||||
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
|
||||
iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
|
||||
&input0_buffer_view));
|
||||
|
||||
vm::ref<iree_vm_list_t> inputs;
|
||||
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
|
||||
auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
|
||||
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
|
||||
iree_status_t input_status = ParseToVariantList(
|
||||
allocator,
|
||||
iree::span<const std::string>{FLAG_function_inputs.data(),
|
||||
FLAG_function_inputs.size()},
|
||||
iree_allocator_system(), &inputs);
|
||||
if (!iree_status_is_ok(input_status))
|
||||
return -1;
|
||||
//vm::ref<iree_vm_list_t> inputs;
|
||||
//IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
|
||||
|
||||
//iree_hal_buffer_view_t* input0_buffer_view = nullptr;
|
||||
//constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
|
||||
//IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
|
||||
// allocator,
|
||||
// /*shape_rank=*/4, /*shape=*/input_buffer_shape,
|
||||
// IREE_HAL_ELEMENT_TYPE_FLOAT_32,
|
||||
// IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
|
||||
// iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
|
||||
// &input0_buffer_view));
|
||||
|
||||
//auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
|
||||
//IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
|
||||
|
||||
// Prepare outputs list to accept results from the invocation.
|
||||
|
||||
@@ -1023,6 +853,7 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, kOutputCount * sizeof(float), iree_allocator_system(), &outputs));
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Main loop.
|
||||
bool done = false;
|
||||
while (!done) {
|
||||
@@ -1076,46 +907,11 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
/*policy=*/nullptr, inputs.get(),
|
||||
outputs.get(), iree_allocator_system()));
|
||||
|
||||
// Read back the results.
|
||||
auto* output_buffer_view = reinterpret_cast<iree_hal_buffer_view_t*>(
|
||||
iree_vm_list_get_ref_deref(outputs.get(),
|
||||
0,
|
||||
iree_hal_buffer_view_get_descriptor()));
|
||||
IREE_CHECK_OK(iree_hal_device_transfer_d2h(
|
||||
iree_vk_device,
|
||||
iree_hal_buffer_view_buffer(output_buffer_view),
|
||||
0,
|
||||
output_res50, sizeof(output_res50),
|
||||
IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
|
||||
|
||||
// we want to run continuously so we can use tools like RenderDoc, RGP, etc...
|
||||
dirty = true;
|
||||
}
|
||||
|
||||
// find maxarg from results
|
||||
float max = 0.0f;
|
||||
int max_idx = -1;
|
||||
for(int i=0;i<1000;i++)
|
||||
{
|
||||
if (output_res50[i] > max)
|
||||
{
|
||||
max = output_res50[i];
|
||||
max_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
ImGui::Text("pointer = %p", my_image_texture);
|
||||
ImGui::Text("size = %d x %d", my_image_width, my_image_height);
|
||||
ImGui::Image((ImTextureID)my_image_texture, ImVec2(my_image_width, my_image_height));
|
||||
|
||||
// Display the latest computation output.
|
||||
ImGui::Text("Max idx = [%i]", max_idx);
|
||||
ImGui::Text("Max value = [%f]", max);
|
||||
|
||||
ImGui::Text("Resnet50 categories:");
|
||||
ImGui::PlotHistogram("Histogram", output_res50, IM_ARRAYSIZE(output_res50), 0, NULL, 0.0f, 1.0f, ImVec2(0,80));
|
||||
ImGui::Separator();
|
||||
|
||||
// Framerate counter.
|
||||
ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
|
||||
1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
|
||||
@@ -1137,6 +933,7 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
iree_vm_module_release(bytecode_module);
|
||||
iree_vm_context_release(iree_context);
|
||||
iree_hal_device_release(iree_vk_device);
|
||||
iree_hal_allocator_release(allocator);
|
||||
iree_hal_driver_release(iree_vk_driver);
|
||||
iree_hal_vulkan_syms_release(iree_vk_syms);
|
||||
iree_vm_instance_release(iree_instance);
|
||||
|
||||
1160
cpp/vulkan_gui/vulkan_resnet_inference_gui.cc
Normal file
1160
cpp/vulkan_gui/vulkan_resnet_inference_gui.cc
Normal file
File diff suppressed because it is too large
Load Diff
@@ -205,14 +205,14 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--torch_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/pytorch/torch_model_list.csv",
|
||||
default="./tank/torch_model_list.csv",
|
||||
help="""Contains the file with torch_model name and args.
|
||||
Please see: https://github.com/nod-ai/SHARK/blob/main/tank/pytorch/torch_model_list.csv""",
|
||||
Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tf_model_csv",
|
||||
type=lambda x: is_valid_file(x),
|
||||
default="./tank/tf/tf_model_list.csv",
|
||||
default="./tank/tf_model_list.csv",
|
||||
help="Contains the file with tf model name and args.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -4,9 +4,9 @@ requires = [
|
||||
"wheel",
|
||||
"packaging",
|
||||
|
||||
"numpy==1.22.4",
|
||||
"torch-mlir>=20220428.420",
|
||||
"iree-compiler>=20220427.13",
|
||||
"iree-runtime>=20220427.13",
|
||||
"numpy>=1.22.4",
|
||||
"torch-mlir>=20221021.633",
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/
|
||||
--pre
|
||||
|
||||
numpy
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
--pre
|
||||
|
||||
numpy==1.22.4
|
||||
torch
|
||||
torchvision
|
||||
|
||||
tqdm
|
||||
@@ -14,7 +13,8 @@ iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow
|
||||
tensorflow==2.10
|
||||
keras==2.10
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
setuptools
|
||||
wheel
|
||||
pyinstaller
|
||||
|
||||
# SHARK Runner
|
||||
tqdm
|
||||
|
||||
# SHARK Downloader
|
||||
gsutil
|
||||
google-cloud-storage
|
||||
|
||||
# Testing
|
||||
pytest
|
||||
pytest-xdist
|
||||
Pillow
|
||||
parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
transformers
|
||||
diffusers
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
|
||||
8
setup.py
8
setup.py
@@ -10,8 +10,8 @@ PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.4"
|
||||
backend_deps = []
|
||||
if "NO_BACKEND" in os.environ.keys():
|
||||
backend_deps = [
|
||||
"iree-compiler>=20220427.13",
|
||||
"iree-runtime>=20220427.13",
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
|
||||
setup(
|
||||
@@ -33,11 +33,11 @@ setup(
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
packages=find_packages(exclude=("examples")),
|
||||
python_requires=">=3.7",
|
||||
python_requires=">=3.9",
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"PyYAML",
|
||||
"torch-mlir>=20220428.420",
|
||||
"torch-mlir>=20221021.633",
|
||||
]
|
||||
+ backend_deps,
|
||||
)
|
||||
|
||||
40
setup_venv.ps1
Normal file
40
setup_venv.ps1
Normal file
@@ -0,0 +1,40 @@
|
||||
#Write-Host "Installing python"
|
||||
|
||||
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow
|
||||
|
||||
#Write-Host "python installation completed successfully"
|
||||
|
||||
#Write-Host "Reload environment variables"
|
||||
#$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||
#Write-Host "Reloaded environment variables"
|
||||
|
||||
|
||||
# redirect stderr into stdout
|
||||
$p = &{python -V} 2>&1
|
||||
# check if an ErrorRecord was returned
|
||||
$version = if($p -is [System.Management.Automation.ErrorRecord])
|
||||
{
|
||||
# grab the version string from the error message
|
||||
$p.Exception.Message
|
||||
}
|
||||
else
|
||||
{
|
||||
# otherwise return as is
|
||||
$p
|
||||
}
|
||||
|
||||
Write-Host "Python version found is"
|
||||
Write-Host $p
|
||||
|
||||
|
||||
Write-Host "Installing Build Dependencies"
|
||||
python -m venv .\shark.venv\
|
||||
.\shark.venv\Scripts\activate
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
pip install diffusers transformers scipy pillow gradio
|
||||
Write-Host "Build and installation completed successfully"
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
@@ -76,11 +76,16 @@ fi
|
||||
$PYTHON -m pip install --upgrade pip || die "Could not upgrade pip"
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
|
||||
if [ "$torch_mlir_bin" = true ]; then
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch-mlir"
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
echo "Could not install torch-mlir" >&2
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch-mlir"
|
||||
else
|
||||
echo "Could not install torch-mlir" >&2
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
|
||||
@@ -89,37 +94,41 @@ else
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "${USE_IREE}" ]]; then
|
||||
RUNTIME="nod-ai/SHARK-Runtime"
|
||||
rm .use-iree
|
||||
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
|
||||
else
|
||||
RUNTIME="google/iree"
|
||||
touch ./.use-iree
|
||||
RUNTIME="https://iree-org.github.io/iree/pip-release-links.html"
|
||||
fi
|
||||
if [[ -z "${NO_BACKEND}" ]]; then
|
||||
echo "Installing ${RUNTIME}..."
|
||||
$PYTHON -m pip install --find-links https://github.com/${RUNTIME}/releases iree-compiler iree-runtime
|
||||
$PYTHON -m pip install --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
|
||||
else
|
||||
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
|
||||
fi
|
||||
|
||||
if [[ ! -z "${IMPORTER}" ]]; then
|
||||
echo "${Yellow}Installing importer tools.."
|
||||
if [[ $(uname -s) = 'Linux' ]]; then
|
||||
echo "${Yellow}Linux detected.. installing Linux importer tools"
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer.txt" -f https://github.com/${RUNTIME}/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
#Always get the importer tools from upstream IREE
|
||||
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://iree-org.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
elif [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "${Yellow}macOS detected.. installing macOS importer tools"
|
||||
#Conda seems to have some problems installing these packages and hope they get resolved upstream.
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer-macos.txt" -f https://github.com/${RUNTIME}/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer-macos.txt" -f ${RUNTIME} --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
fi
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://github.com/${RUNTIME}/releases
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu116
|
||||
$PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu117
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu116."
|
||||
echo "Successfully Installed torch + cu117."
|
||||
else
|
||||
echo "Could not install torch + cu116." >&2
|
||||
echo "Could not install torch + cu117." >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class CLIPModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, pixel_values=z
|
||||
)
|
||||
|
||||
@tf.function(input_signature=clip_vit_inputs)
|
||||
@tf.function(input_signature=clip_vit_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, pixel_values):
|
||||
return self.m.predict(
|
||||
input_ids, attention_mask, pixel_values
|
||||
|
||||
15
shark/examples/shark_inference/ESRGAN/README.md
Normal file
15
shark/examples/shark_inference/ESRGAN/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
## Running ESRGAN
|
||||
|
||||
```
|
||||
1. pip install numpy opencv-python
|
||||
2. mkdir InputImages
|
||||
(this is where all the input images will reside in)
|
||||
3. mkdir OutputImages
|
||||
(this is where the model will generate all the images)
|
||||
4. mkdir models
|
||||
(save the .pth checkpoint file here)
|
||||
5. python esrgan.py
|
||||
```
|
||||
|
||||
- Download [RRDB_ESRGAN_x4.pth](https://drive.google.com/drive/u/0/folders/17VYV_SoZZesU6mbxz2dMAIccSSlqLecY) and place it in the `models` directory as mentioned above in step 4.
|
||||
- Credits : [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
240
shark/examples/shark_inference/ESRGAN/esrgan.py
Normal file
240
shark/examples/shark_inference/ESRGAN/esrgan.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from ast import arg
|
||||
import os.path as osp
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from shark.shark_inference import SharkInference
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
"""Residual in Residual Dense Block"""
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(
|
||||
self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
fea = self.lrelu(
|
||||
self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument(
|
||||
"--mlir_loc",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location of the model's mlir file",
|
||||
)
|
||||
args = p.parse_args()
|
||||
###################################################
|
||||
|
||||
|
||||
def inference(input_m):
|
||||
return model(input_m)
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc == None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
|
||||
module = load_mlir(mlir_loc)
|
||||
if module == None:
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
print("Torchscript graph generated successfully")
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
inputs,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mlir_model = str(module)
|
||||
func_name = "forward"
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
model_path = "models/RRDB_ESRGAN_x4.pth" # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
|
||||
# device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu
|
||||
device = torch.device("cpu")
|
||||
|
||||
test_img_folder = "InputImages/*"
|
||||
|
||||
model = RRDBNet(3, 3, 64, 23, gc=32)
|
||||
model.load_state_dict(torch.load(model_path), strict=True)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
|
||||
print("Model path {:s}. \nTesting...".format(model_path))
|
||||
|
||||
if __name__ == "__main__":
|
||||
idx = 0
|
||||
for path in glob.glob(test_img_folder):
|
||||
idx += 1
|
||||
base = osp.splitext(osp.basename(path))[0]
|
||||
print(idx, base)
|
||||
# read images
|
||||
img = cv2.imread(path, cv2.IMREAD_COLOR)
|
||||
img = img * 1.0 / 255
|
||||
img = torch.from_numpy(
|
||||
np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))
|
||||
).float()
|
||||
img_LR = img.unsqueeze(0)
|
||||
img_LR = img_LR.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
shark_module = compile_through_fx(inference, img_LR)
|
||||
shark_output = shark_module.forward((img_LR,))
|
||||
shark_output = torch.from_numpy(shark_output)
|
||||
shark_output = (
|
||||
shark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
)
|
||||
esrgan_output = (
|
||||
model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
)
|
||||
# SHARK OUTPUT
|
||||
shark_output = np.transpose(shark_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
shark_output = (shark_output * 255.0).round()
|
||||
cv2.imwrite(
|
||||
"OutputImages/{:s}_rlt_shark_output.png".format(base), shark_output
|
||||
)
|
||||
print("Generated SHARK's output")
|
||||
# ESRGAN OUTPUT
|
||||
esrgan_output = np.transpose(esrgan_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
esrgan_output = (esrgan_output * 255.0).round()
|
||||
cv2.imwrite(
|
||||
"OutputImages/{:s}_rlt_esrgan_output.png".format(base),
|
||||
esrgan_output,
|
||||
)
|
||||
print("Generated ESRGAN's output")
|
||||
@@ -28,7 +28,7 @@ class AlbertModule(tf.Module):
|
||||
self.m = TFAutoModelForMaskedLM.from_pretrained("albert-base-v2")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs)
|
||||
@tf.function(input_signature=t5_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("bloom")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"bloom", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
|
||||
|
||||
@@ -19,7 +19,7 @@ class GPT2Module(tf.Module):
|
||||
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
|
||||
|
||||
@tf.function(input_signature=gpt2_inputs)
|
||||
@tf.function(input_signature=gpt2_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class BertModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased",
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class BertModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
@@ -66,10 +66,12 @@ labels = load_labels()
|
||||
|
||||
|
||||
## Can pass any img or input to the forward module.
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"resnet50", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
|
||||
# shark_module.compile()
|
||||
shark_module.compile()
|
||||
path = shark_module.save_module()
|
||||
shark_module.load_module(path)
|
||||
result = shark_module.forward((img.detach().numpy(),))
|
||||
|
||||
@@ -47,7 +47,7 @@ def load_mlir(mlir_loc):
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
def compile_through_fx(model, inputs, mlir_loc=None, extra_args=[]):
|
||||
|
||||
module = load_mlir(mlir_loc)
|
||||
if mlir_loc == None:
|
||||
@@ -98,9 +98,12 @@ def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="tm_tensor"
|
||||
mlir_model,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile()
|
||||
shark_module.compile(extra_args)
|
||||
|
||||
return shark_module
|
||||
|
||||
@@ -161,6 +164,7 @@ if __name__ == "__main__":
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings),
|
||||
args.mlir_loc,
|
||||
["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
|
||||
)
|
||||
|
||||
# torch.jit.script(unet)
|
||||
|
||||
@@ -10,20 +10,59 @@ from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
##############################################################################
|
||||
# pip install diffusers
|
||||
# pip install scipy
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a photograph of an astronaut riding a horse",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument("--steps", type=int, default=50, help="the device to use")
|
||||
p.add_argument("--mlir_loc", type=str, default=None, help="the device to use")
|
||||
p.add_argument("--vae_loc", type=str, default=None, help="the device to use")
|
||||
args = p.parse_args()
|
||||
|
||||
#####################################################
|
||||
|
||||
|
||||
def fp16_unet():
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"stable_diff_f16_18_OCT",
|
||||
tank_url="gs://shark_tank/prashant_nod",
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
return shark_module
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc == None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, device, mlir_loc=None):
|
||||
def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
|
||||
module = load_mlir(mlir_loc)
|
||||
if mlir_loc == None:
|
||||
@@ -74,106 +113,79 @@ def compile_through_fx(model, inputs, device, mlir_loc=None):
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=device, mlir_dialect="tm_tensor"
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
##############################################################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
DEBUG = False
|
||||
compiled_module = {}
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
|
||||
def stable_diff_inf(prompt: str, steps, device: str):
|
||||
|
||||
args = {}
|
||||
args["prompt"] = [prompt]
|
||||
args["steps"] = steps
|
||||
args["device"] = device
|
||||
args["mlir_loc"] = "./stable_diffusion.mlir"
|
||||
output_loc = (
|
||||
f"stored_results/stable_diffusion/{prompt}_{int(steps)}_{device}.jpg"
|
||||
# 1. Load the autoencoder model which will be used to decode the latents into image space.
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
global DEBUG
|
||||
global compiled_module
|
||||
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
|
||||
DEBUG = False
|
||||
log_write = open(r"logs/stable_diffusion_log.txt", "w")
|
||||
if log_write:
|
||||
DEBUG = True
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
if args["device"] not in compiled_module.keys():
|
||||
YOUR_TOKEN = "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"
|
||||
def forward(self, input):
|
||||
return self.vae.decode(input, return_dict=False)[0]
|
||||
|
||||
# 1. Load the autoencoder model which will be used to decode the latents into image space.
|
||||
compiled_module["vae"] = AutoencoderKL.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
vae = VaeModel()
|
||||
vae_input = torch.rand(1, 4, 64, 64)
|
||||
shark_vae = compile_through_fx(vae, (vae_input,), args.vae_loc)
|
||||
|
||||
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
|
||||
compiled_module["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
compiled_module["text_encoder"] = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
if DEBUG:
|
||||
log_write.write("Compiling the Unet module.\n")
|
||||
# Wrap the unet model to return tuples.
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
# Wrap the unet model to return tuples.
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
# # 3. The UNet model for generating the latents.
|
||||
unet = UnetModel()
|
||||
|
||||
# 3. The UNet model for generating the latents.
|
||||
unet = UnetModel()
|
||||
latent_model_input = torch.rand([2, 4, 64, 64])
|
||||
text_embeddings = torch.rand([2, 77, 768])
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings),
|
||||
args["device"],
|
||||
args["mlir_loc"],
|
||||
)
|
||||
compiled_module[args["device"]] = shark_unet
|
||||
if DEBUG:
|
||||
log_write.write("Compilation successful.\n")
|
||||
shark_unet = fp16_unet()
|
||||
|
||||
compiled_module["unet"] = unet
|
||||
compiled_module["scheduler"] = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
scheduler = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
shark_unet = compiled_module[args["device"]]
|
||||
vae = compiled_module["vae"]
|
||||
unet = compiled_module["unet"]
|
||||
tokenizer = compiled_module["tokenizer"]
|
||||
text_encoder = compiled_module["text_encoder"]
|
||||
scheduler = compiled_module["scheduler"]
|
||||
prompt = [args.prompt]
|
||||
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
|
||||
num_inference_steps = int(args["steps"]) # Number of denoising steps
|
||||
num_inference_steps = args.steps # Number of denoising steps
|
||||
|
||||
guidance_scale = 7.5 # Scale for classifier-free guidance
|
||||
|
||||
@@ -181,10 +193,10 @@ def stable_diff_inf(prompt: str, steps, device: str):
|
||||
42
|
||||
) # Seed generator to create the inital latent noise
|
||||
|
||||
batch_size = len(args["prompt"])
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_input = tokenizer(
|
||||
args["prompt"],
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
@@ -208,30 +220,41 @@ def stable_diff_inf(prompt: str, steps, device: str):
|
||||
(batch_size, unet.in_channels, height // 8, width // 8),
|
||||
generator=generator,
|
||||
)
|
||||
# latents = latents.to(torch_device)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
latents = latents * scheduler.sigmas[0]
|
||||
# print(latents, latents.shape)
|
||||
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
|
||||
if DEBUG:
|
||||
log_write.write(f"i = {i} t = {t}\n")
|
||||
print(f"i = {i} t = {t}")
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
sigma = scheduler.sigmas[i]
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
latent_model_input_numpy = latent_model_input.detach().numpy()
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
|
||||
# with torch.no_grad():
|
||||
# noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
|
||||
|
||||
latent_model_input_numpy = (
|
||||
latent_model_input.detach().numpy().astype(np.half)
|
||||
)
|
||||
text_embeddings_numpy = (
|
||||
text_embeddings.detach().numpy().astype(np.half)
|
||||
)
|
||||
|
||||
noise_pred = shark_unet.forward(
|
||||
(
|
||||
latent_model_input_numpy,
|
||||
np.array([t]).astype(np.float32),
|
||||
np.array([t]).astype(np.half),
|
||||
text_embeddings_numpy,
|
||||
)
|
||||
)
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
noise_pred = torch.from_numpy(noise_pred).to(torch.float32)
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
@@ -242,21 +265,16 @@ def stable_diff_inf(prompt: str, steps, device: str):
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
|
||||
# print("Latents shape : ", latents.shape)
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = vae.decode(latents).sample
|
||||
latents_numpy = latents.detach().numpy()
|
||||
image = shark_vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
output = pil_images[0]
|
||||
# save the output image with the prompt name.
|
||||
output.save(os.path.join(output_loc))
|
||||
log_write.close()
|
||||
|
||||
std_output = ""
|
||||
with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
|
||||
std_output = log_read.read()
|
||||
|
||||
return output, std_output
|
||||
pil_images[0].save("astro.jpg")
|
||||
@@ -17,7 +17,7 @@ from keras_cv.models.generative.stable_diffusion.text_encoder import (
|
||||
)
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
from shark.shark_downloader import download_model
|
||||
from PIL import Image
|
||||
|
||||
# pip install "git+https://github.com/keras-team/keras-cv.git"
|
||||
@@ -75,8 +75,8 @@ class SharkStableDiffusion:
|
||||
# Create models
|
||||
self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_tf_model(
|
||||
"stable_diff", tank_url="gs://shark_tank/quinn"
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"stable_diff", tank_url="gs://shark_tank/quinn", frontend="tf"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=device, mlir_dialect="mhlo"
|
||||
|
||||
2
shark/examples/shark_inference/stable_diffusion/.gitignore
vendored
Normal file
2
shark/examples/shark_inference/stable_diffusion/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.vmfb
|
||||
*.jpg
|
||||
44
shark/examples/shark_inference/stable_diffusion/README.md
Normal file
44
shark/examples/shark_inference/stable_diffusion/README.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# STABLE DIFFUSION
|
||||
|
||||
## Installation
|
||||
|
||||
Follow setup instructions in the main [README.md](https://github.com/nod-ai/SHARK#readme) for regular usage.
|
||||
|
||||
## Debug commands and other advanced usage follows.
|
||||
|
||||
```shell
|
||||
python main.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
|
||||
|
||||
```
|
||||
|
||||
## dump all dispatch .spv and isa using amdllpc
|
||||
|
||||
```shell
|
||||
python main.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
|
||||
```
|
||||
|
||||
## Compile and save the .vmfb (using vulkan fp16 as an example):
|
||||
|
||||
```shell
|
||||
python shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
|
||||
```
|
||||
|
||||
## Capture an RGP trace
|
||||
|
||||
```shell
|
||||
python shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
|
||||
```
|
||||
|
||||
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
|
||||
|
||||
```shell
|
||||
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --device=vulkan --function_input=1x4x64x64xf16
|
||||
```
|
||||
|
||||
## Run the unet module with iree-benchmark-module (same config as above):
|
||||
```shell
|
||||
##if you want to use .npz inputs:
|
||||
unzip ~/.local/shark_tank/<your unet>/inputs.npz
|
||||
|
||||
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --function_input=@arr_0.npy --function_input=1xf16 --function_input=@arr_2.npy --function_input=@arr_3.npy --function_input=@arr_4.npy
|
||||
```
|
||||
@@ -0,0 +1,25 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
text=["a photo of a cat", "a photo of a dog"],
|
||||
images=image,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
|
||||
outputs = model(**inputs)
|
||||
logits_per_image = (
|
||||
outputs.logits_per_image
|
||||
) # this is the image-text similarity score
|
||||
probs = logits_per_image.softmax(
|
||||
dim=1
|
||||
) # we can take the softmax to get the label probabilities
|
||||
188
shark/examples/shark_inference/stable_diffusion/main.py
Normal file
188
shark/examples/shark_inference/stable_diffusion/main.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from stable_args import args
|
||||
from utils import get_shark_model, set_iree_runtime_flags
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
import time
|
||||
from model_wrappers import get_vae_mlir
|
||||
from shark.iree_utils.compile_utils import dump_isas
|
||||
|
||||
# Helper function to profile the vulkan device.
|
||||
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
|
||||
if args.vulkan_debug_utils and "vulkan" in args.device:
|
||||
import iree
|
||||
|
||||
print(f"Profiling and saving to {file_path}.")
|
||||
vulkan_device = iree.runtime.get_device(args.device)
|
||||
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
|
||||
return vulkan_device
|
||||
return None
|
||||
|
||||
|
||||
def end_profiling(device):
|
||||
if device:
|
||||
return device.end_profiling()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
|
||||
prompt = args.prompts
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
if args.version == "v2":
|
||||
height = 768
|
||||
width = 768
|
||||
|
||||
num_inference_steps = args.steps # Number of denoising steps
|
||||
|
||||
# Scale for classifier-free guidance
|
||||
guidance_scale = torch.tensor(args.guidance_scale).to(torch.float32)
|
||||
|
||||
generator = torch.manual_seed(
|
||||
args.seed
|
||||
) # Seed generator to create the inital latent noise
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
set_iree_runtime_flags()
|
||||
unet = get_unet()
|
||||
vae = get_vae()
|
||||
clip = get_clip()
|
||||
if args.dump_isa:
|
||||
dump_isas(args.dispatch_benchmarks_dir)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
if args.version == "v2":
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", subfolder="tokenizer"
|
||||
)
|
||||
|
||||
scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
if args.version == "v2.1base":
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
|
||||
)
|
||||
|
||||
scheduler = EulerDiscreteScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
start = time.time()
|
||||
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=args.max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
clip_inf_start = time.time()
|
||||
text_embeddings = clip.forward((text_input.input_ids,))
|
||||
clip_inf_end = time.time()
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
[""] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_clip_inf_start = time.time()
|
||||
uncond_embeddings = clip.forward((uncond_input.input_ids,))
|
||||
uncond_clip_inf_end = time.time()
|
||||
uncond_embeddings = torch.from_numpy(uncond_embeddings).to(dtype)
|
||||
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, 4, height // 8, width // 8),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler.is_scale_input_called = True
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
avg_ms = 0
|
||||
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
step_start = time.time()
|
||||
print(f"i = {i} t = {t}", end="")
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
latents_numpy = latent_model_input.detach().numpy()
|
||||
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
|
||||
noise_pred = unet.forward(
|
||||
(
|
||||
latents_numpy,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
)
|
||||
)
|
||||
|
||||
end_profiling(profile_device)
|
||||
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
step_time = time.time() - step_start
|
||||
avg_ms += step_time
|
||||
step_ms = int((step_time) * 1000)
|
||||
print(f" ({step_ms}ms)")
|
||||
|
||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||
|
||||
avg_ms = 1000 * avg_ms / args.steps
|
||||
print(f"Average step time: {avg_ms}ms/it")
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
# latents = latents.
|
||||
latents_numpy = latents.detach().numpy()
|
||||
profile_device = start_profiling(file_path="vae.rdc")
|
||||
vae_start = time.time()
|
||||
image = vae.forward((latents_numpy,))
|
||||
vae_end = time.time()
|
||||
end_profiling(profile_device)
|
||||
image = torch.from_numpy(image)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1) * 255.0
|
||||
images = image.numpy().round().astype("uint8")
|
||||
total_end = time.time()
|
||||
|
||||
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
|
||||
uncond_clip_inf_time = (uncond_clip_inf_end - uncond_clip_inf_start) * 1000
|
||||
avg_clip_inf = (clip_inf_time + uncond_clip_inf_time) / 2
|
||||
vae_inf_time = (vae_end - vae_start) * 1000
|
||||
print(
|
||||
f"Clip Inference Avg time (ms) = ({clip_inf_time:.3f} + {uncond_clip_inf_time:.3f}) / 2 = {avg_clip_inf:.3f}"
|
||||
)
|
||||
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
|
||||
print(f"Total image generation runtime (s): {total_end - start:.4f}")
|
||||
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
for i in range(batch_size):
|
||||
pil_images[i].save(f"{args.prompts[i]}_{i}.jpg")
|
||||
@@ -0,0 +1,173 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
|
||||
from transformers import CLIPTextModel
|
||||
from utils import compile_through_fx
|
||||
from stable_args import args
|
||||
import torch
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
|
||||
model_config = {
|
||||
"v2": "stabilityai/stable-diffusion-2",
|
||||
"v1.4": "CompVis/stable-diffusion-v1-4",
|
||||
}
|
||||
|
||||
model_input = {
|
||||
"v2": {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 96, 96),),
|
||||
"unet": (
|
||||
torch.randn(1, 4, 96, 96), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, 77, 1024), # embedding
|
||||
torch.tensor(1).to(torch.float32), # guidance_scale
|
||||
),
|
||||
},
|
||||
"v1.4": {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 64, 64),),
|
||||
"unet": (
|
||||
torch.randn(1, 4, 64, 64),
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, 77, 768),
|
||||
torch.tensor(1).to(torch.float32),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# revision param for from_pretrained defaults to "main" => fp32
|
||||
model_revision = "fp16" if args.precision == "fp16" else "main"
|
||||
|
||||
|
||||
def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
if args.version == "v2":
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_config[args.version], subfolder="text_encoder"
|
||||
)
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
model_input[args.version]["clip"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
|
||||
def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_config[args.version],
|
||||
subfolder="vae",
|
||||
revision=model_revision,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
return (x / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
vae = VaeModel()
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda()
|
||||
for inputs in model_input[args.version]["vae"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["vae"]
|
||||
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_vae_encode_mlir(model_name="vae_encode", extra_args=[]):
|
||||
class VaeEncodeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_config[args.version],
|
||||
subfolder="vae",
|
||||
revision="fp16",
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
input = 2 * (x - 0.5)
|
||||
return self.vae.encode(input, return_dict=False)[0]
|
||||
|
||||
vae = VaeEncodeModel()
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[inputs.half().cuda() for inputs in model_input[args.version]["vae"]]
|
||||
)
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_config[args.version],
|
||||
subfolder="unet",
|
||||
revision=model_revision,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, latent, timestep, text_embedding, guidance_scale):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latents = torch.cat([latent] * 2)
|
||||
unet_out = self.unet.forward(
|
||||
latents, timestep, text_embedding, return_dict=False
|
||||
)[0]
|
||||
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel()
|
||||
if args.precision == "fp16":
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input[args.version]["unet"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["unet"]
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_unet
|
||||
153
shark/examples/shark_inference/stable_diffusion/opt_params.py
Normal file
153
shark/examples/shark_inference/stable_diffusion/opt_params.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import sys
|
||||
from model_wrappers import (
|
||||
get_vae_mlir,
|
||||
get_vae_encode_mlir,
|
||||
get_unet_mlir,
|
||||
get_clip_mlir,
|
||||
)
|
||||
from stable_args import args
|
||||
from utils import get_shark_model
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
if BATCH_SIZE != 1:
|
||||
sys.exit("Only batch size 1 is supported.")
|
||||
|
||||
|
||||
def get_unet():
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Tuned model is present for `fp16` precision.
|
||||
if args.precision == "fp16":
|
||||
if args.use_tuned:
|
||||
bucket = "gs://shark_tank/vivian"
|
||||
model_name = "unet_1dec_fp16_tuned"
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
else:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "unet_1dec_fp16"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "unet2base_8dec_fp16"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_unet_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
# Tuned model is not present for `fp32` case.
|
||||
if args.precision == "fp32":
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "unet_1dec_fp32"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_unet_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
if args.precision == "int8":
|
||||
bucket = "gs://shark_tank/prashant_nod"
|
||||
model_name = "unet_int8"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
]
|
||||
sys.exit("int8 model is currently in maintenance.")
|
||||
# # TODO: Pass iree_flags to the exported model.
|
||||
# if args.import_mlir:
|
||||
# sys.exit(
|
||||
# "--import_mlir is not supported for the int8 model, try --no-import_mlir flag."
|
||||
# )
|
||||
# return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae():
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
if args.precision in ["fp16", "int8"]:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "vae_1dec_fp16"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "vae2base_8dec_fp16"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_vae_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
if args.precision == "fp32":
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "vae_1dec_fp32"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_vae_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae_encode():
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
if args.precision in ["fp16", "int8"]:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "vae_encode_1dec_fp16"
|
||||
if args.version == "v2":
|
||||
model_name = "vae2_encode_29nov_fp16"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_vae_encode_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
if args.precision == "fp32":
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "vae_encode_1dec_fp32"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_vae_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_clip():
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "clip_1dec_fp32"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "clip2base_8dec_fp32"
|
||||
iree_flags += [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_clip_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
@@ -0,0 +1,44 @@
|
||||
Compile / Run Instructions:
|
||||
|
||||
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
|
||||
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
|
||||
|
||||
Compile Commands FP32/FP16:
|
||||
|
||||
```shell
|
||||
Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use –iree-input-type=mhlo for tf models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
CPU:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
```
|
||||
|
||||
|
||||
|
||||
Run / Benchmark Command (FP32 - NCHW):
|
||||
(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.)
|
||||
|
||||
```shell
|
||||
## Vulkan AMD:
|
||||
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --device=vulkan --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
|
||||
|
||||
## CUDA:
|
||||
iree-benchmark-module --module_file=/path/to/vmfb --entry_function=forward --device=cuda --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
|
||||
|
||||
## CPU:
|
||||
iree-benchmark-module --module_file=/path/to/vmfb --entry_function=forward --device=local-task --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
|
||||
|
||||
```
|
||||
|
||||
Run via vulkan_gui for RGP Profiling:
|
||||
|
||||
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
|
||||
```shell
|
||||
./build/vulkan_gui/iree-vulkan-gui --module_file=/path/to/unet.vmfb --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
|
||||
```
|
||||
128
shark/examples/shark_inference/stable_diffusion/stable_args.py
Normal file
128
shark/examples/shark_inference/stable_diffusion/stable_args.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=["a photograph of an astronaut riding a horse"],
|
||||
help="text of which images to be generated.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--device", type=str, default="cpu", help="device to run the model."
|
||||
)
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="the no. of steps to do the sampling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="v1.4",
|
||||
help="Specify version of stable diffusion model",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed to use.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="the value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp32", help="precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=77,
|
||||
help="max length of the tokenizer output.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="saves the compiled flatbuffer to the local directory",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree-vulkan-target-triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dump_isa",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks",
|
||||
default=None,
|
||||
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="temp_dispatch_benchmarks",
|
||||
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="4294967296",
|
||||
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--enable_rgp",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for inserting debug frames between iterations for use with rgp.",
|
||||
)
|
||||
|
||||
args = p.parse_args()
|
||||
@@ -0,0 +1,111 @@
|
||||
# Stable Diffusion optimized for AMD RDNA2/RDNA3 GPUs
|
||||
|
||||
## Install the latest AMD Drivers
|
||||
|
||||
### RDNA2 Drivers:
|
||||
*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*
|
||||
|
||||
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree
|
||||
|
||||
Note that if you previously tried Stable Diffusion with a different driver, it may be necessary to clear vulkan cache after changing drivers.
|
||||
|
||||
For Windows users this can be done by clearing the contents of `C:\Users\<username>\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
|
||||
|
||||
## Installation
|
||||
|
||||
Download the latest Windows `shark_sd.exe` [here](https://storage.googleapis.com/shark-public/anush/windows/shark_sd_05122022v3.exe). Accept if Windows warns of an unsigned .exe.
|
||||
|
||||
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
|
||||
|
||||
|
||||
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
|
||||
|
||||
|
||||
Here are some samples generated:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
|
||||
<details>
|
||||
<summary>Advanced Installation </summary>
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
|
||||
|
||||
* Install Git for Windows from [here](https://git-scm.com/download/win)
|
||||
|
||||
#### Allow the install script to run in Powershell
|
||||
```powershell
|
||||
set-executionpolicy remotesigned
|
||||
```
|
||||
|
||||
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
|
||||
```powershell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
cd SHARK
|
||||
./setup_venv.ps1 #You can re-run this script to get the latest version
|
||||
```
|
||||
|
||||
### Linux
|
||||
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
cd SHARK
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
```
|
||||
|
||||
### Run Stable Diffusion on your device - WebUI
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\Users\nod\SHARK> cd web
|
||||
(shark.venv) PS C:\Users\nod\SHARK\web> python index.py
|
||||
```
|
||||
#### Linux Users
|
||||
```shell
|
||||
(shark.venv) > cd web
|
||||
(shark.venv) > python index.py
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - Commandline
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> python .\shark\examples\shark_inference\stable_diffusion\main.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
#### Linux
|
||||
```shell
|
||||
python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
```
|
||||
|
||||
The output on a 6900XT would like:
|
||||
|
||||
```shell
|
||||
44it [00:08, 5.14it/s]i = 44 t = 120 (191ms)
|
||||
45it [00:08, 5.15it/s]i = 45 t = 100 (191ms)
|
||||
46it [00:08, 5.16it/s]i = 46 t = 80 (191ms)
|
||||
47it [00:09, 5.16it/s]i = 47 t = 60 (193ms)
|
||||
48it [00:09, 5.15it/s]i = 48 t = 40 (195ms)
|
||||
49it [00:09, 5.12it/s]i = 49 t = 20 (196ms)
|
||||
50it [00:09, 5.14it/s]
|
||||
Average step time: 192.8154182434082ms/it
|
||||
Total image generation runtime (s): 10.390909433364868
|
||||
(shark.venv) PS C:\g\shark>
|
||||
```
|
||||
|
||||
|
||||
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
|
||||
</details>
|
||||
<details>
|
||||
<summary>Discord link</summary>
|
||||
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
|
||||
</details>
|
||||
83
shark/examples/shark_inference/stable_diffusion/utils.py
Normal file
83
shark/examples/shark_inference/stable_diffusion/utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from stable_args import args
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
else:
|
||||
if args.save_vmfb:
|
||||
print("Saving to {}".format(vmfb_path))
|
||||
else:
|
||||
print(
|
||||
"No vmfb found. Compiling and saving to {}".format(
|
||||
vmfb_path
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=tank_url,
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into a shark_module.
|
||||
def compile_through_fx(model, inputs, model_name, extra_args=[]):
|
||||
|
||||
mlir_module, func_name = import_with_fx(model, inputs)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
|
||||
]
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if "vulkan" in args.device:
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
return
|
||||
@@ -18,7 +18,7 @@ class T5Module(tf.Module):
|
||||
self.m = TFT5Model.from_pretrained("t5-small")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, decoder_input_ids=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs)
|
||||
@tf.function(input_signature=t5_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, decoder_input_ids):
|
||||
return self.m.predict(input_ids, decoder_input_ids)
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"v_diffusion", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
|
||||
@@ -52,7 +52,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
# Stable Diffusion Img2Img model
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary>Installation (Linux)</summary>
|
||||
|
||||
### Activate shark.venv Virtual Environment
|
||||
|
||||
```shell
|
||||
source shark.venv/bin/activate
|
||||
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
python -m pip install --upgrade pip
|
||||
```
|
||||
|
||||
### Install dependencies
|
||||
|
||||
# Run the setup.sh script
|
||||
|
||||
```shell
|
||||
./setup.sh
|
||||
```
|
||||
|
||||
### Run the Stable diffusion Img2Img model
|
||||
|
||||
To run the model with the default set of images and params, run:
|
||||
```shell
|
||||
python stable_diffusion_img2img.py
|
||||
```
|
||||
To run the model with your set of images, and parameters you need to specify the following params:
|
||||
1.) Input images directory with the arg `--input_dir` containing 3-5 images.
|
||||
2.) What to teach the model? Using the arg `--what_to_teach`, allowed values are `object` or `style`.
|
||||
3.) Placeholder token using the arg `--placeholder_token`, that represents your new concept. It should be passed with the opening and closing angle brackets. For ex: token is `cat-toy`, it should be passed as `<cat-toy>`.
|
||||
4.) Initializer token using the arg `--initializer_token`, which summarise what is your new concept.
|
||||
|
||||
For the result, you need to pass the text prompt with the arg: `--prompt`. The prompt string should contain a "*s" in it, which will be replaced by the placeholder token during the inference.
|
||||
|
||||
By default the result images will go into the `sd_result` dir. To specify your output dir use the arg: `--output_dir`.
|
||||
|
||||
The default value of max_training_steps is `3000`, which takes some hours to complete. You can pass the smaller value with the arg `--training_steps`. Specify the number of images to be sampled for the result with the `--num_inference_samples` arg.
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
TD="$(cd $(dirname $0) && pwd)"
|
||||
if [ -z "$PYTHON" ]; then
|
||||
PYTHON="$(which python3)"
|
||||
fi
|
||||
|
||||
function die() {
|
||||
echo "Error executing command: $*"
|
||||
exit 1
|
||||
}
|
||||
|
||||
PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; print("{0}.{1}".format(*version))'`
|
||||
|
||||
echo "Python: $PYTHON"
|
||||
echo "Python version: $PYTHON_VERSION_X_Y"
|
||||
|
||||
mkdir input_images
|
||||
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg -P input_images/
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg -P input_images/
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg -P input_images/
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg -P input_images/
|
||||
|
||||
pip install diffusers["training"]==0.4.1 transformers ftfy opencv-python
|
||||
@@ -0,0 +1,597 @@
|
||||
# Textual-inversion fine-tuning for Stable Diffusion using diffusers
|
||||
# This script shows how to "teach" Stable Diffusion a new concept via
|
||||
# textual-inversion using 🤗 Hugging Face [🧨 Diffusers library](https://github.com/huggingface/diffusers).
|
||||
# By using just 3-5 images you can teach new concepts to Stable Diffusion
|
||||
# and personalize the model on your own images.
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.hub_utils import init_git_repo, push_to_hub
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
p.add_argument(
|
||||
"--input_dir",
|
||||
type=str,
|
||||
default="input_images/",
|
||||
help="the directory contains the images used for fine tuning",
|
||||
)
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="sd_result",
|
||||
help="the directory contains the images used for fine tuning",
|
||||
)
|
||||
p.add_argument(
|
||||
"--training_steps",
|
||||
type=int,
|
||||
default=3000,
|
||||
help="the maximum number of training steps",
|
||||
)
|
||||
p.add_argument("--seed", type=int, default=42, help="the random seed")
|
||||
p.add_argument(
|
||||
"--what_to_teach",
|
||||
type=str,
|
||||
choices=["object", "style"],
|
||||
default="object",
|
||||
help="what is it that you are teaching?",
|
||||
)
|
||||
p.add_argument(
|
||||
"--placeholder_token",
|
||||
type=str,
|
||||
default="<cat-toy>",
|
||||
help="It is the token you are going to use to represent your new concept",
|
||||
)
|
||||
p.add_argument(
|
||||
"--initializer_token",
|
||||
type=str,
|
||||
default="toy",
|
||||
help="It is a word that can summarise what is your new concept",
|
||||
)
|
||||
p.add_argument(
|
||||
"--inference_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="the number of steps for inference",
|
||||
)
|
||||
p.add_argument(
|
||||
"--num_inference_samples",
|
||||
type=int,
|
||||
default=4,
|
||||
help="the number of samples for inference",
|
||||
)
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a grafitti in a wall with a *s on it",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
if "*s" not in args.prompt:
|
||||
raise ValueError(
|
||||
f'The prompt should have a "*s" which will be replaced by a placeholder token.'
|
||||
)
|
||||
|
||||
prompt1, prompt2 = args.prompt.split("*s")
|
||||
args.prompt = prompt1 + args.placeholder_token + prompt2
|
||||
|
||||
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
# Load input images.
|
||||
images = []
|
||||
for filename in os.listdir(args.input_dir):
|
||||
img = cv2.imread(os.path.join(args.input_dir, filename))
|
||||
if img is not None:
|
||||
images.append(img)
|
||||
|
||||
# Setup the prompt templates for training
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
# Setup the dataset
|
||||
class TextualInversionDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
learnable_property="object", # [object, style]
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5,
|
||||
set="train",
|
||||
placeholder_token="*",
|
||||
center_crop=False,
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.learnable_property = learnable_property
|
||||
self.size = size
|
||||
self.placeholder_token = placeholder_token
|
||||
self.center_crop = center_crop
|
||||
self.flip_p = flip_p
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
|
||||
self.templates = (
|
||||
imagenet_style_templates_small
|
||||
if learnable_property == "style"
|
||||
else imagenet_templates_small
|
||||
)
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
placeholder_string = self.placeholder_token
|
||||
text = random.choice(self.templates).format(placeholder_string)
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2,
|
||||
(w - crop) // 2 : (w + crop) // 2,
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize(
|
||||
(self.size, self.size), resample=self.interpolation
|
||||
)
|
||||
|
||||
image = self.flip_transform(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
# Setting up the model
|
||||
# Load the tokenizer and add the placeholder token as a additional special token.
|
||||
# Please read and if you agree accept the LICENSE
|
||||
# [here](https://huggingface.co/CompVis/stable-diffusion-v1-4) if you see an error
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Get token ids for our placeholder and initializer token.
|
||||
# This code block will complain if initializer string is not a single token
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# We have added the `placeholder_token` in the `tokenizer` so we resize the token embeddings here,
|
||||
# this will a new embedding vector in the token embeddings for our `placeholder_token`
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
# In Textual-Inversion we only train the newly added embedding vector,
|
||||
# so lets freeze rest of the model parameters here.
|
||||
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
# Freeze vae and unet
|
||||
freeze_params(vae.parameters())
|
||||
freeze_params(unet.parameters())
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
params_to_freeze = itertools.chain(
|
||||
text_encoder.text_model.encoder.parameters(),
|
||||
text_encoder.text_model.final_layer_norm.parameters(),
|
||||
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
||||
)
|
||||
freeze_params(params_to_freeze)
|
||||
|
||||
# Creating our training data
|
||||
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=args.input_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=512,
|
||||
placeholder_token=args.placeholder_token,
|
||||
repeats=100,
|
||||
learnable_property=args.what_to_teach, # Option selected above between object and style
|
||||
center_crop=False,
|
||||
set="train",
|
||||
)
|
||||
|
||||
|
||||
def create_dataloader(train_batch_size=1):
|
||||
return torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=train_batch_size, shuffle=True
|
||||
)
|
||||
|
||||
|
||||
# Create noise_scheduler for training.
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
tensor_format="pt",
|
||||
)
|
||||
|
||||
# Define hyperparameters for our training
|
||||
hyperparameters = {
|
||||
"learning_rate": 5e-04,
|
||||
"scale_lr": True,
|
||||
"max_train_steps": args.training_steps,
|
||||
"train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"seed": args.seed,
|
||||
"output_dir": "sd-concept-output",
|
||||
}
|
||||
|
||||
|
||||
def training_function(text_encoder, vae, unet):
|
||||
logger = get_logger(__name__)
|
||||
|
||||
train_batch_size = hyperparameters["train_batch_size"]
|
||||
gradient_accumulation_steps = hyperparameters[
|
||||
"gradient_accumulation_steps"
|
||||
]
|
||||
learning_rate = hyperparameters["learning_rate"]
|
||||
max_train_steps = hyperparameters["max_train_steps"]
|
||||
output_dir = hyperparameters["output_dir"]
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
train_dataloader = create_dataloader(train_batch_size)
|
||||
|
||||
if hyperparameters["scale_lr"]:
|
||||
learning_rate = (
|
||||
learning_rate
|
||||
* gradient_accumulation_steps
|
||||
* train_batch_size
|
||||
* accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=learning_rate,
|
||||
)
|
||||
|
||||
text_encoder, optimizer, train_dataloader = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader
|
||||
)
|
||||
|
||||
# Move vae and unet to device
|
||||
vae.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
|
||||
# Keep vae and unet in eval model as we don't train these
|
||||
vae.eval()
|
||||
unet.eval()
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Train!
|
||||
total_batch_size = (
|
||||
train_batch_size
|
||||
* accelerator.num_processes
|
||||
* gradient_accumulation_steps
|
||||
)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
|
||||
)
|
||||
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(
|
||||
range(max_train_steps), disable=not accelerator.is_local_main_process
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = (
|
||||
vae.encode(batch["pixel_values"])
|
||||
.latent_dist.sample()
|
||||
.detach()
|
||||
)
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(latents.shape).to(latents.device)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(
|
||||
latents, noise, timesteps
|
||||
)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(
|
||||
noisy_latents, timesteps, encoder_hidden_states
|
||||
).sample
|
||||
|
||||
loss = (
|
||||
F.mse_loss(noise_pred, noise, reduction="none")
|
||||
.mean([1, 2, 3])
|
||||
.mean()
|
||||
)
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Zero out the gradients for all token embeddings except the newly added
|
||||
# embeddings for the concept, as we only want to optimize the concept embeddings
|
||||
if accelerator.num_processes > 1:
|
||||
grads = (
|
||||
text_encoder.module.get_input_embeddings().weight.grad
|
||||
)
|
||||
else:
|
||||
grads = text_encoder.get_input_embeddings().weight.grad
|
||||
# Get the index for tokens that we want to zero the grads for
|
||||
index_grads_to_zero = (
|
||||
torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
)
|
||||
grads.data[index_grads_to_zero, :] = grads.data[
|
||||
index_grads_to_zero, :
|
||||
].fill_(0)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {"loss": loss.detach().item()}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
skip_prk_steps=True,
|
||||
),
|
||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker"
|
||||
),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained(
|
||||
"openai/clip-vit-base-patch32"
|
||||
),
|
||||
)
|
||||
pipeline.save_pretrained(output_dir)
|
||||
# Also save the newly trained embeddings
|
||||
learned_embeds = (
|
||||
accelerator.unwrap_model(text_encoder)
|
||||
.get_input_embeddings()
|
||||
.weight[placeholder_token_id]
|
||||
)
|
||||
learned_embeds_dict = {
|
||||
args.placeholder_token: learned_embeds.detach().cpu()
|
||||
}
|
||||
torch.save(
|
||||
learned_embeds_dict, os.path.join(output_dir, "learned_embeds.bin")
|
||||
)
|
||||
|
||||
|
||||
import accelerate
|
||||
|
||||
accelerate.notebook_launcher(
|
||||
training_function, args=(text_encoder, vae, unet), num_processes=1
|
||||
)
|
||||
|
||||
# Set up the pipeline
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
hyperparameters["output_dir"],
|
||||
# torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
all_images = []
|
||||
for _ in range(args.num_inference_samples):
|
||||
images = pipe(
|
||||
[args.prompt],
|
||||
num_inference_steps=args.inference_steps,
|
||||
guidance_scale=7.5,
|
||||
).images
|
||||
all_images.extend(images)
|
||||
|
||||
# output_path = os.path.abspath(os.path.join(os.getcwd(), args.output_dir))
|
||||
if not os.path.isdir(args.output_dir):
|
||||
os.mkdir(args.output_dir)
|
||||
|
||||
[
|
||||
image.save(f"{args.output_dir}/{i}.jpeg")
|
||||
for i, image in enumerate(all_images)
|
||||
]
|
||||
@@ -37,7 +37,64 @@ def run_cmd(cmd):
|
||||
sys.exit("Exiting program due to error running:", cmd)
|
||||
|
||||
|
||||
IREE_DEVICE_MAP = {
|
||||
def iree_device_map(device):
|
||||
|
||||
from iree.runtime import get_driver, get_device
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list = []
|
||||
for device_dict in device_list_src:
|
||||
device_list.append(f"{driver_name}://{device_dict['path']}")
|
||||
device_list.sort()
|
||||
return device_list
|
||||
|
||||
# only supported for vulkan as of now
|
||||
if "vulkan://" in device:
|
||||
device_list = get_all_devices("vulkan")
|
||||
_, d_index = device.split("://")
|
||||
matched_index = None
|
||||
match_with_index = False
|
||||
if 0 <= len(d_index) <= 2:
|
||||
try:
|
||||
d_index = int(d_index)
|
||||
except:
|
||||
print(
|
||||
f"{d_index} is not valid index or uri. Will choose device 0"
|
||||
)
|
||||
d_index = 0
|
||||
match_with_index = True
|
||||
|
||||
if len(device_list) > 1:
|
||||
print("List of available vulkan devices:")
|
||||
for i, d in enumerate(device_list):
|
||||
print(f"vulkan://{i} => {d}")
|
||||
if (match_with_index and d_index == i) or (
|
||||
not match_with_index and d == device
|
||||
):
|
||||
matched_index = i
|
||||
print(
|
||||
f"Choosing device vulkan://{matched_index}\nTo choose another device please specify device index or uri accordingly."
|
||||
)
|
||||
return get_device(device_list[matched_index])
|
||||
elif len(device_list) == 1:
|
||||
print(f"Found one vulkan device: {device_list[0]}. Using this.")
|
||||
return get_device(device_list[0])
|
||||
else:
|
||||
print(
|
||||
f"No device found! returning device corresponding to driver name: vulkan"
|
||||
)
|
||||
return _IREE_DEVICE_MAP["vulkan"]
|
||||
else:
|
||||
return _IREE_DEVICE_MAP[device]
|
||||
|
||||
|
||||
def get_supported_device_list():
|
||||
return list(_IREE_DEVICE_MAP.keys())
|
||||
|
||||
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -46,7 +103,14 @@ IREE_DEVICE_MAP = {
|
||||
"intel-gpu": "level_zero",
|
||||
}
|
||||
|
||||
IREE_TARGET_MAP = {
|
||||
|
||||
def iree_target_map(device):
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
return _IREE_TARGET_MAP[device]
|
||||
|
||||
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -58,6 +122,9 @@ IREE_TARGET_MAP = {
|
||||
# Finds whether the required drivers are installed for the given device.
|
||||
def check_device_drivers(device):
|
||||
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
|
||||
if device == "cuda":
|
||||
try:
|
||||
subprocess.check_output("nvidia-smi")
|
||||
|
||||
@@ -13,12 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
|
||||
from shark.iree_utils._common import run_cmd, IREE_DEVICE_MAP
|
||||
from shark.iree_utils._common import run_cmd, iree_device_map
|
||||
from shark.iree_utils.cpu_utils import get_cpu_count
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
|
||||
UNIT_TO_SECOND_MAP = {"ms": 0.001, "s": 1}
|
||||
UNIT_TO_SECOND_MAP = {"us": 1e-6, "ms": 0.001, "s": 1}
|
||||
|
||||
|
||||
def tensor_to_type_str(input_tensors: tuple, mlir_dialect: str):
|
||||
@@ -69,10 +70,40 @@ def build_benchmark_args(
|
||||
# TODO: Replace name of train with actual train fn name.
|
||||
fn_name = "train"
|
||||
benchmark_cl.append(f"--entry_function={fn_name}")
|
||||
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
|
||||
benchmark_cl.append(f"--device={iree_device_map(device)}")
|
||||
mlir_input_types = tensor_to_type_str(input_tensors, mlir_dialect)
|
||||
for mlir_input in mlir_input_types:
|
||||
benchmark_cl.append(f"--function_input={mlir_input}")
|
||||
if device == "cpu":
|
||||
num_cpus = get_cpu_count()
|
||||
if num_cpus is not None:
|
||||
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl.append(time_extractor)
|
||||
return benchmark_cl
|
||||
|
||||
|
||||
def build_benchmark_args_non_tensor_input(
|
||||
input_file: str,
|
||||
device: str,
|
||||
inputs: tuple,
|
||||
mlir_dialect: str,
|
||||
function_name: str,
|
||||
):
|
||||
"""
|
||||
Inputs: input_file leading to vmfb, input_tensor to function, target device,
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
benchmarker_path = os.path.join(path, "..", "..", "iree-benchmark-module")
|
||||
benchmark_cl = [benchmarker_path, f"--module_file={input_file}"]
|
||||
# TODO: The function named can be passed as one of the args.
|
||||
if function_name:
|
||||
benchmark_cl.append(f"--entry_function={function_name}")
|
||||
benchmark_cl.append(f"--device={iree_device_map(device)}")
|
||||
for input in inputs:
|
||||
benchmark_cl.append(f"--function_input={input}")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl.append(time_extractor)
|
||||
return benchmark_cl
|
||||
|
||||
@@ -13,12 +13,17 @@
|
||||
# limitations under the License.
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
|
||||
from shark.iree_utils._common import iree_device_map, iree_target_map
|
||||
from shark.iree_utils.benchmark_utils import *
|
||||
from shark.parser import shark_args
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device):
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
if device == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
@@ -30,7 +35,7 @@ def get_iree_device_args(device):
|
||||
if device in ["metal", "vulkan"]:
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args()
|
||||
return get_iree_vulkan_args(extra_args=extra_args)
|
||||
if device == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
@@ -62,14 +67,178 @@ def get_iree_common_args():
|
||||
]
|
||||
|
||||
|
||||
# Args that are suitable only for certain models or groups of models.
|
||||
# shark_args are passed down from pytests to control which models compile with these flags,
|
||||
# but they can also be set in shark/parser.py
|
||||
def get_model_specific_args():
|
||||
ms_args = []
|
||||
if shark_args.enable_conv_transform == True:
|
||||
ms_args += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"]
|
||||
return ms_args
|
||||
|
||||
|
||||
def create_dispatch_dirs(bench_dir, device):
|
||||
protected_files = ["ordered-dispatches.txt"]
|
||||
bench_dir_path = bench_dir.split("/")
|
||||
bench_dir_path[-1] = "temp_" + bench_dir_path[-1]
|
||||
tmp_bench_dir = "/".join(bench_dir_path)
|
||||
for f_ in os.listdir(bench_dir):
|
||||
if os.path.isfile(f"{bench_dir}/{f_}") and f_ not in protected_files:
|
||||
dir_name = re.sub("\.\S*$", "", f_)
|
||||
if os.path.exists(f"{bench_dir}/{dir_name}"):
|
||||
os.system(f"rm -rf {bench_dir}/{dir_name}")
|
||||
os.system(f"mkdir {bench_dir}/{dir_name}")
|
||||
os.system(f"mv {bench_dir}/{f_} {bench_dir}/{dir_name}/{f_}")
|
||||
for f_ in os.listdir(tmp_bench_dir):
|
||||
if os.path.isfile(f"{tmp_bench_dir}/{f_}"):
|
||||
dir_name = ""
|
||||
for d_ in os.listdir(bench_dir):
|
||||
if re.search(f"{d_}(?=\D)", f_):
|
||||
dir_name = d_
|
||||
if dir_name != "":
|
||||
os.system(
|
||||
f"mv {tmp_bench_dir}/{f_} {bench_dir}/{dir_name}/{dir_name}_benchmark.mlir"
|
||||
)
|
||||
|
||||
|
||||
def dump_isas(bench_dir):
|
||||
for d_ in os.listdir(bench_dir):
|
||||
if os.path.isdir(f"{bench_dir}/{d_}"):
|
||||
for f_ in os.listdir(f"{bench_dir}/{d_}"):
|
||||
if f_.endswith(".spv"):
|
||||
os.system(
|
||||
f"amdllpc -gfxip 11.0 {bench_dir}/{d_}/{f_} -v > \
|
||||
{bench_dir}/{d_}/isa.txt"
|
||||
)
|
||||
|
||||
|
||||
def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
|
||||
benchmark_runtimes = {}
|
||||
dispatch_list = []
|
||||
all_dispatches = False
|
||||
|
||||
if dispatch_benchmarks.lower().strip() == "all":
|
||||
all_dispatches = True
|
||||
else:
|
||||
try:
|
||||
dispatch_list = [
|
||||
int(dispatch_index)
|
||||
for dispatch_index in dispatch_benchmarks.split(" ")
|
||||
]
|
||||
except:
|
||||
print("ERROR: Invalid dispatch benchmarks")
|
||||
return None
|
||||
for d_ in os.listdir(bench_dir):
|
||||
if os.path.isdir(f"{bench_dir}/{d_}"):
|
||||
in_dispatches = False
|
||||
for dispatch in dispatch_list:
|
||||
if str(dispatch) in d_:
|
||||
in_dispatches = True
|
||||
if all_dispatches or in_dispatches:
|
||||
for f_ in os.listdir(f"{bench_dir}/{d_}"):
|
||||
|
||||
if "benchmark.mlir" in f_:
|
||||
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
|
||||
module = dispatch_file.read()
|
||||
dispatch_file.close()
|
||||
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module, target_backends=[iree_target_map(device)]
|
||||
)
|
||||
|
||||
vmfb_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_benchmark.vmfb", "wb"
|
||||
)
|
||||
vmfb_file.write(flatbuffer_blob)
|
||||
vmfb_file.close()
|
||||
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(
|
||||
config.vm_instance, flatbuffer_blob
|
||||
)
|
||||
|
||||
benchmark_cl = build_benchmark_args_non_tensor_input(
|
||||
input_file=f"{bench_dir}/{d_}/{d_}_benchmark.vmfb",
|
||||
device=device,
|
||||
inputs=(0,),
|
||||
mlir_dialect="linalg",
|
||||
function_name="",
|
||||
)
|
||||
|
||||
benchmark_bash = open(
|
||||
f"{bench_dir}/{d_}/{d_}_benchmark.sh", "w+"
|
||||
)
|
||||
benchmark_bash.write("#!/bin/bash\n")
|
||||
benchmark_bash.write(" ".join(benchmark_cl))
|
||||
benchmark_bash.close()
|
||||
|
||||
benchmark_data = run_benchmark_module(benchmark_cl)
|
||||
|
||||
benchmark_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
|
||||
)
|
||||
benchmark_file.write(f"DISPATCH: {d_}\n")
|
||||
benchmark_file.write(str(benchmark_data) + "\n")
|
||||
benchmark_file.write(
|
||||
"SHARK BENCHMARK RESULT: "
|
||||
+ str(1 / (benchmark_data * 0.001))
|
||||
+ "\n"
|
||||
)
|
||||
benchmark_file.close()
|
||||
|
||||
benchmark_runtimes[d_] = 1 / (benchmark_data * 0.001)
|
||||
|
||||
elif ".mlir" in f_ and "benchmark" not in f_:
|
||||
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
|
||||
module = dispatch_file.read()
|
||||
dispatch_file.close()
|
||||
|
||||
module = re.sub(
|
||||
"hal.executable private",
|
||||
"hal.executable public",
|
||||
module,
|
||||
)
|
||||
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=["--compile-mode=hal-executable"],
|
||||
)
|
||||
|
||||
spirv_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_spirv.vmfb", "wb"
|
||||
)
|
||||
spirv_file.write(flatbuffer_blob)
|
||||
spirv_file.close()
|
||||
|
||||
ordered_dispatches = [
|
||||
(k, v)
|
||||
for k, v in sorted(
|
||||
benchmark_runtimes.items(), key=lambda item: item[1]
|
||||
)
|
||||
][::-1]
|
||||
f_ = open(f"{bench_dir}/ordered-dispatches.txt", "w+")
|
||||
for dispatch in ordered_dispatches:
|
||||
f_.write(f"{dispatch[0]}: {dispatch[1]}ms\n")
|
||||
f_.close()
|
||||
|
||||
|
||||
def compile_module_to_flatbuffer(
|
||||
module, device, frontend, func_name, model_config_path
|
||||
module,
|
||||
device,
|
||||
frontend,
|
||||
func_name,
|
||||
model_config_path,
|
||||
extra_args,
|
||||
model_name="None",
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args()
|
||||
args += get_model_specific_args()
|
||||
args += extra_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "mhlo"
|
||||
@@ -78,25 +247,23 @@ def compile_module_to_flatbuffer(
|
||||
elif frontend in ["tflite", "tflite-tosa"]:
|
||||
input_type = "tosa"
|
||||
elif frontend in ["tm_tensor"]:
|
||||
input_type = frontend
|
||||
input_type = ireec.InputType.TM_TENSOR
|
||||
|
||||
# TODO: make it simpler.
|
||||
# Compile according to the input type, else just try compiling.
|
||||
if input_type not in ["mhlo", "tosa"]:
|
||||
module = str(module)
|
||||
if input_type != "":
|
||||
# Currently for MHLO/TOSA.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[IREE_TARGET_MAP[device]],
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=args,
|
||||
input_type=input_type,
|
||||
)
|
||||
else:
|
||||
# Currently for Torch.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
str(module),
|
||||
target_backends=[IREE_TARGET_MAP[device]],
|
||||
module,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=args,
|
||||
)
|
||||
|
||||
@@ -105,7 +272,7 @@ def compile_module_to_flatbuffer(
|
||||
|
||||
def get_iree_module(flatbuffer_blob, device, func_name):
|
||||
# Returns the compiled module and the configs.
|
||||
config = ireert.Config(IREE_DEVICE_MAP[device])
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(
|
||||
config.vm_instance, flatbuffer_blob
|
||||
)
|
||||
@@ -121,10 +288,11 @@ def get_iree_compiled_module(
|
||||
frontend: str = "torch",
|
||||
func_name: str = "forward",
|
||||
model_config_path: str = None,
|
||||
extra_args: list = [],
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, func_name, model_config_path
|
||||
module, device, frontend, func_name, model_config_path, extra_args
|
||||
)
|
||||
return get_iree_module(flatbuffer_blob, device, func_name)
|
||||
|
||||
@@ -146,12 +314,18 @@ def export_iree_module_to_vmfb(
|
||||
mlir_dialect: str = "linalg",
|
||||
func_name: str = "forward",
|
||||
model_config_path: str = None,
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, mlir_dialect, func_name, model_config_path
|
||||
module, device, mlir_dialect, func_name, model_config_path, extra_args
|
||||
)
|
||||
module_name = f"{mlir_dialect}_{func_name}_{device}"
|
||||
if module_name is None:
|
||||
device_name = (
|
||||
device if "://" not in device else "-".join(device.split("://"))
|
||||
)
|
||||
module_name = f"{mlir_dialect}_{func_name}_{device_name}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, "wb") as f:
|
||||
@@ -187,4 +361,14 @@ def get_results(compiled_vm, input, config, frontend="torch"):
|
||||
res = np.array(data, dtype=object)
|
||||
return np.copy(res)
|
||||
else:
|
||||
return np.copy(np.asarray(result, dtype=result.dtype))
|
||||
return result.to_host()
|
||||
|
||||
|
||||
def get_iree_runtime_config(device):
|
||||
device = iree_device_map(device)
|
||||
if type(device) == ireert.HalDevice:
|
||||
config = ireert.Config(device=device)
|
||||
else:
|
||||
driver_name = device.split("://")[0] if "://" in device else device
|
||||
config = ireert.Config(driver_name=driver_name)
|
||||
return config
|
||||
|
||||
@@ -16,6 +16,17 @@
|
||||
|
||||
import subprocess
|
||||
|
||||
|
||||
def get_cpu_count():
|
||||
import multiprocessing
|
||||
|
||||
try:
|
||||
cpu_count = multiprocessing.cpu_count()
|
||||
return cpu_count
|
||||
except NotImplementedError:
|
||||
return None
|
||||
|
||||
|
||||
# Get the default cpu args.
|
||||
def get_iree_cpu_args():
|
||||
find_triple_cmd = "uname -s -m"
|
||||
|
||||
@@ -14,12 +14,42 @@
|
||||
|
||||
# All the iree_vulkan related functionalities go here.
|
||||
|
||||
from os import linesep
|
||||
from shark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
|
||||
|
||||
def get_vulkan_triple_flag():
|
||||
vulkan_device_cmd = "vulkaninfo | grep deviceName"
|
||||
vulkan_device = run_cmd(vulkan_device_cmd).strip()
|
||||
def get_vulkan_device_name():
|
||||
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
print(
|
||||
f"Found {len(vulkaninfo_list)} device names. choosing first one: {vulkaninfo_list[0]}"
|
||||
)
|
||||
return vulkaninfo_list[0]
|
||||
|
||||
|
||||
def get_os_name():
|
||||
if platform.startswith("linux"):
|
||||
return "linux"
|
||||
elif platform == "darwin":
|
||||
return "macos"
|
||||
elif platform == "win32":
|
||||
return "windows"
|
||||
else:
|
||||
print("Cannot detect OS type, defaulting to linux.")
|
||||
return "linux"
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(extra_args=[]):
|
||||
if "-iree-vulkan-target-triple=" in " ".join(extra_args):
|
||||
print(f"Using target triple from command line args")
|
||||
return None
|
||||
system_os = get_os_name()
|
||||
vulkan_device = get_vulkan_device_name()
|
||||
if all(x in vulkan_device for x in ("Apple", "M1")):
|
||||
print(f"Found {vulkan_device} Device. Using m1-moltenvk-macos")
|
||||
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
@@ -27,21 +57,26 @@ def get_vulkan_triple_flag():
|
||||
print("Found Apple M2 Device. Using m1-moltenvk-macos")
|
||||
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
elif all(x in vulkan_device for x in ("A100", "SXM4")):
|
||||
print(f"Found {vulkan_device} Device. Using ampere-rtx3080-linux")
|
||||
return "-iree-vulkan-target-triple=ampere-rtx3080-linux"
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using ampere-rtx3080-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=ampere-rtx3080-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "3090")):
|
||||
print(f"Found {vulkan_device} Device. Using ampere-rtx3090-linux")
|
||||
return "-iree-vulkan-target-triple=ampere-rtx3090-linux"
|
||||
elif any(x in vulkan_device for x in ("Radeon", "RX 5")):
|
||||
print(
|
||||
"Found AMD Radeon RX 5000 series device. Using rdna1-5700xt-linux"
|
||||
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
|
||||
)
|
||||
return "-iree-vulkan-target-triple=rdna1-5700xt-linux"
|
||||
elif all(x in vulkan_device for x in ("Radeon", "RX 6")):
|
||||
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "4090")):
|
||||
print(
|
||||
"Found AMD Radeon RX 6000 series device. Using rdna2-unknown-linux"
|
||||
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
|
||||
)
|
||||
return "-iree-vulkan-target-triple=rdna2-unknown-linux"
|
||||
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("AMD", "7900")):
|
||||
print(f"Found {vulkan_device} Device. Using rdna3-7900-{system_os}")
|
||||
return f"-iree-vulkan-target-triple=rdna3-7900-{system_os}"
|
||||
elif any(x in vulkan_device for x in ("AMD", "Radeon")):
|
||||
print(f"Found AMD device. Using rdna2-unknown-{system_os}")
|
||||
return f"-iree-vulkan-target-triple=rdna2-unknown-{system_os}"
|
||||
else:
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
@@ -52,10 +87,16 @@ def get_vulkan_triple_flag():
|
||||
return None
|
||||
|
||||
|
||||
def get_iree_vulkan_args():
|
||||
def get_iree_vulkan_args(extra_args=[]):
|
||||
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
vulkan_flag = []
|
||||
vulkan_triple_flag = get_vulkan_triple_flag()
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args)
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_flag.append(vulkan_triple_flag)
|
||||
return vulkan_flag
|
||||
|
||||
|
||||
def set_iree_vulkan_runtime_flags(flags):
|
||||
for flag in flags:
|
||||
ireert.flags.parse_flags(flag)
|
||||
return
|
||||
|
||||
@@ -12,6 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
This function takes the model mlir file and the tuned config file as input,
|
||||
and output a new mlir file with lowering configs annotated on certain ops.
|
||||
There are two ways to utilize the function:
|
||||
1. Call model_annotation function within another python script
|
||||
from shark.model_annotation import model_annotation
|
||||
with create_context() as ctx:
|
||||
module = model_annotation(ctx, input_contents=..., config_path=..., search_op=...)
|
||||
2. Run model_annotation.py directly
|
||||
python model_annotation.py path_to_original_mlir path_to_config_file
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@@ -105,7 +118,9 @@ def add_attributes(op: ir.Operation, config: Dict):
|
||||
|
||||
|
||||
def parse_config(config: Dict):
|
||||
if config["pipeline"] == "GPU" or config["pipeline"] == "GPU_TENSORCORE":
|
||||
split_k = None
|
||||
pipeline_depth = None
|
||||
if "GPU" in config["pipeline"]:
|
||||
pipeline = (
|
||||
"LLVMGPUMatmulSimt"
|
||||
if config["pipeline"] == "GPU"
|
||||
@@ -113,24 +128,31 @@ def parse_config(config: Dict):
|
||||
)
|
||||
tile_sizes = [config["work_group_tile_sizes"]]
|
||||
workgroup_size = config["work_group_sizes"]
|
||||
try:
|
||||
if "pipeline_depth" in config.keys():
|
||||
pipeline_depth = config["pipeline_depth"]
|
||||
except:
|
||||
pipeline_depth = None
|
||||
try:
|
||||
if "split_k" in config.keys():
|
||||
split_k = config["split_k"]
|
||||
except:
|
||||
split_k = None
|
||||
else:
|
||||
elif "SPIRV" in config["pipeline"]:
|
||||
pipeline = config["pipeline"]
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"],
|
||||
config["l1_tile_sizes"],
|
||||
config["vector_tile_sizes"],
|
||||
config["parallel_tile_sizes"],
|
||||
config["reduction_tile_sizes"],
|
||||
]
|
||||
if "vector_tile_sizes" in config.keys():
|
||||
tile_sizes += [config["vector_tile_sizes"]]
|
||||
if "window_tile_sizes" in config.keys():
|
||||
tile_sizes += [config["window_tile_sizes"]]
|
||||
workgroup_size = config["work_group_sizes"]
|
||||
else:
|
||||
# For IREE CPU pipelines
|
||||
pipeline = config["pipeline"]
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"],
|
||||
config["parallel_tile_sizes"],
|
||||
config["reduction_tile_sizes"],
|
||||
]
|
||||
workgroup_size = []
|
||||
split_k = None
|
||||
pipeline_depth = None
|
||||
return tile_sizes, pipeline, workgroup_size, split_k, pipeline_depth
|
||||
|
||||
|
||||
|
||||
@@ -93,4 +93,23 @@ parser.add_argument(
|
||||
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dispatch_benchmarks",
|
||||
default=None,
|
||||
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="temp_dispatch_benchmarks",
|
||||
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable_conv_transform",
|
||||
default=False,
|
||||
action="store",
|
||||
help="Enables the --iree-flow-enable-conv-nchw-to-nhwc-transform flag.",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
@@ -39,29 +39,54 @@ class OnnxFusionOptions(object):
|
||||
self.no_attention_mask = False
|
||||
|
||||
|
||||
def check_requirements(frontend):
|
||||
import importlib
|
||||
|
||||
has_pkgs = False
|
||||
if frontend == "torch":
|
||||
tv_spec = importlib.util.find_spec("torchvision")
|
||||
has_pkgs = tv_spec is not None
|
||||
|
||||
elif frontend in ["tensorflow", "tf"]:
|
||||
keras_spec = importlib.util.find_spec("keras")
|
||||
tf_spec = importlib.util.find_spec("tensorflow")
|
||||
has_pkgs = keras_spec is not None and tf_spec is not None
|
||||
|
||||
return has_pkgs
|
||||
|
||||
|
||||
class SharkBenchmarkRunner(SharkRunner):
|
||||
# SharkRunner derived class with Benchmarking capabilities.
|
||||
def __init__(
|
||||
self,
|
||||
mlir_module: str,
|
||||
mlir_module: bytes,
|
||||
function_name: str = "forward",
|
||||
device: str = "none",
|
||||
mlir_dialect: str = "linalg",
|
||||
extra_args: list = [],
|
||||
):
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.frontend_model = None
|
||||
self.vmfb_file = None
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
SharkRunner.__init__(
|
||||
self,
|
||||
mlir_module,
|
||||
function_name,
|
||||
device,
|
||||
self.mlir_dialect,
|
||||
self.extra_args,
|
||||
compile_vmfb=True,
|
||||
)
|
||||
if self.vmfb_file == None:
|
||||
self.vmfb_file = export_iree_module_to_vmfb(
|
||||
mlir_module, device, shark_args.repro_dir, self.mlir_dialect
|
||||
mlir_module,
|
||||
device,
|
||||
shark_args.repro_dir,
|
||||
self.mlir_dialect,
|
||||
function_name,
|
||||
extra_args=self.extra_args,
|
||||
)
|
||||
|
||||
def setup_cl(self, input_tensors):
|
||||
@@ -71,7 +96,6 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
input_tensors,
|
||||
mlir_dialect=self.mlir_dialect,
|
||||
)
|
||||
print(self.benchmark_cl)
|
||||
|
||||
def benchmark_frontend(self, modelname):
|
||||
if self.mlir_dialect in ["linalg", "torch"]:
|
||||
@@ -116,30 +140,31 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
import tensorflow as tf
|
||||
from tank.model_utils_tf import get_tf_model
|
||||
|
||||
model, input, = get_tf_model(
|
||||
modelname
|
||||
)[:2]
|
||||
frontend_model = model
|
||||
tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
|
||||
with tf.device(tf_device):
|
||||
model, input, = get_tf_model(
|
||||
modelname
|
||||
)[:2]
|
||||
frontend_model = model
|
||||
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(*input)
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(*input)
|
||||
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(*input)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
print(
|
||||
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
]
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(*input)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
print(
|
||||
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
]
|
||||
|
||||
def benchmark_c(self):
|
||||
print(self.benchmark_cl)
|
||||
result = run_benchmark_module(self.benchmark_cl)
|
||||
print(f"Shark-IREE-C benchmark:{result} iter/second")
|
||||
return [f"{result}", f"{1000/result}"]
|
||||
@@ -249,19 +274,15 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
return [param_count, model_tags, model_notes]
|
||||
|
||||
def compare_bench_results(self, baseline: str, result: str):
|
||||
# Takes two numbers represented as strings and returns "<n>x slower/faster", as in "result is <n>x slower than baseline".
|
||||
a = float(baseline)
|
||||
b = float(result)
|
||||
if a < b:
|
||||
# result slower than baseline
|
||||
comparison = (b - a) / a
|
||||
comp_str = f"{round(comparison, 2)}x slower"
|
||||
elif a > b:
|
||||
# result faster than baseline
|
||||
if baseline is not None:
|
||||
# Takes a baseline and a result string and calculates a comparison, e.g. "1.04x baseline".
|
||||
a = float(baseline)
|
||||
b = float(result)
|
||||
comparison = a / b
|
||||
comp_str = f"{round(comparison, 2)}x faster"
|
||||
comp_str = f"{round(comparison, 2)}x baseline"
|
||||
else:
|
||||
comp_str = "equal"
|
||||
comp_str = "N/A"
|
||||
|
||||
return comp_str
|
||||
|
||||
def benchmark_all_csv(
|
||||
@@ -311,17 +332,21 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
) = ["", "", ""]
|
||||
if e == "frontend":
|
||||
bench_result["engine"] = frontend
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
) = self.benchmark_frontend(modelname)
|
||||
self.frontend_result = bench_result["ms/iter"]
|
||||
bench_result["vs. PyTorch/TF"] = "="
|
||||
(
|
||||
bench_result["param_count"],
|
||||
bench_result["tags"],
|
||||
bench_result["notes"],
|
||||
) = self.get_metadata(modelname)
|
||||
if check_requirements(frontend):
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
) = self.benchmark_frontend(modelname)
|
||||
self.frontend_result = bench_result["ms/iter"]
|
||||
bench_result["vs. PyTorch/TF"] = "baseline"
|
||||
(
|
||||
bench_result["param_count"],
|
||||
bench_result["tags"],
|
||||
bench_result["notes"],
|
||||
) = self.get_metadata(modelname)
|
||||
else:
|
||||
self.frontend_result = None
|
||||
continue
|
||||
|
||||
elif e == "shark_python":
|
||||
bench_result["engine"] = "shark_python"
|
||||
|
||||
@@ -14,11 +14,51 @@
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import urllib.request
|
||||
import json
|
||||
import hashlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from shark.parser import shark_args
|
||||
from google.cloud import storage
|
||||
|
||||
|
||||
def download_public_file(
|
||||
full_gs_url, destination_folder_name, single_file=False
|
||||
):
|
||||
"""Downloads a public blob from the bucket."""
|
||||
# bucket_name = "gs://your-bucket-name/path/to/file"
|
||||
# destination_file_name = "local/path/to/file"
|
||||
|
||||
storage_client = storage.Client.create_anonymous_client()
|
||||
bucket_name = full_gs_url.split("/")[2]
|
||||
source_blob_name = None
|
||||
dest_filename = None
|
||||
desired_file = None
|
||||
if single_file:
|
||||
|
||||
desired_file = full_gs_url.split("/")[-1]
|
||||
source_blob_name = "/".join(full_gs_url.split("/")[3:-1])
|
||||
destination_folder_name, dest_filename = os.path.split(
|
||||
destination_folder_name
|
||||
)
|
||||
else:
|
||||
source_blob_name = "/".join(full_gs_url.split("/")[3:])
|
||||
bucket = storage_client.bucket(bucket_name)
|
||||
blobs = bucket.list_blobs(prefix=source_blob_name)
|
||||
if not os.path.exists(destination_folder_name):
|
||||
os.mkdir(destination_folder_name)
|
||||
for blob in blobs:
|
||||
blob_name = blob.name.split("/")[-1]
|
||||
if single_file:
|
||||
if blob_name == desired_file:
|
||||
destination_filename = os.path.join(
|
||||
destination_folder_name, dest_filename
|
||||
)
|
||||
blob.download_to_filename(destination_filename)
|
||||
else:
|
||||
continue
|
||||
|
||||
destination_filename = os.path.join(destination_folder_name, blob_name)
|
||||
blob.download_to_filename(destination_filename)
|
||||
|
||||
|
||||
input_type_to_np_dtype = {
|
||||
"float32": np.float32,
|
||||
@@ -50,8 +90,7 @@ if custom_path:
|
||||
else:
|
||||
WORKDIR = os.path.join(home, ".local/shark_tank/")
|
||||
print(
|
||||
f"shark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache="
|
||||
" pytest flag"
|
||||
f"shark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache= flag"
|
||||
)
|
||||
|
||||
# Checks whether the directory and files exists.
|
||||
@@ -88,185 +127,65 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
|
||||
|
||||
|
||||
# Downloads the torch model from gs://shark_tank dir.
|
||||
def download_torch_model(
|
||||
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
|
||||
def download_model(
|
||||
model_name,
|
||||
dynamic=False,
|
||||
tank_url="gs://shark_tank/latest",
|
||||
frontend=None,
|
||||
tuned=None,
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_torch"
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp -r '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ " "
|
||||
+ WORKDIR
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
|
||||
if not check_dir_exists(model_dir_name, frontend="torch", dynamic=dyn_str):
|
||||
gs_download_model()
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
if shark_args.update_tank == True:
|
||||
gs_download_model()
|
||||
else:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir_name = model_name + "_" + frontend
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(
|
||||
os.path.join(model_dir, model_name + dyn_str + "_torch.mlir")
|
||||
) as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
|
||||
|
||||
# Downloads the tflite model from gs://shark_tank dir.
|
||||
def download_tflite_model(
|
||||
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
|
||||
):
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_tflite"
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp -r '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ " "
|
||||
+ WORKDIR
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
|
||||
|
||||
if not check_dir_exists(
|
||||
model_dir_name, frontend="tflite", dynamic=dyn_str
|
||||
model_dir_name, frontend=frontend, dynamic=dyn_str
|
||||
):
|
||||
gs_download_model()
|
||||
print(f"Downloading artifacts for model {model_name}...")
|
||||
download_public_file(full_gs_url, model_dir)
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
if shark_args.update_tank == True:
|
||||
gs_download_model()
|
||||
else:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
if not _internet_connected():
|
||||
print(
|
||||
"No internet connection. Using the model already present in the tank."
|
||||
)
|
||||
else:
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash_url = (
|
||||
tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy"
|
||||
)
|
||||
download_public_file(
|
||||
gs_hash_url,
|
||||
os.path.join(model_dir, "upstream_hash.npy"),
|
||||
single_file=True,
|
||||
)
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
if shark_args.update_tank == True:
|
||||
print(f"Updating artifacts for model {model_name}...")
|
||||
download_public_file(full_gs_url, WORKDIR)
|
||||
else:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(
|
||||
os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir")
|
||||
) as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
|
||||
|
||||
def download_tf_model(
|
||||
model_name, tuned=None, tank_url="gs://shark_tank/latest"
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_tf"
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp -r '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ " "
|
||||
+ WORKDIR
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
|
||||
if not check_dir_exists(model_dir_name, frontend="tf"):
|
||||
gs_download_model()
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp '
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
if shark_args.update_tank == True:
|
||||
gs_download_model()
|
||||
else:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
suffix = "_tf.mlir" if tuned is None else "_tf_" + tuned + ".mlir"
|
||||
suffix = (
|
||||
"_" + frontend + ".mlir"
|
||||
if tuned is None
|
||||
else "_" + frontend + "_" + tuned + ".mlir"
|
||||
)
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
if not os.path.isfile(filename):
|
||||
filename = os.path.join(model_dir, model_name + "_tf.mlir")
|
||||
filename = os.path.join(
|
||||
model_dir, model_name + "_" + frontend + ".mlir"
|
||||
)
|
||||
|
||||
with open(filename) as f:
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
@@ -276,3 +195,13 @@ def download_tf_model(
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
|
||||
|
||||
def _internet_connected():
|
||||
import requests as req
|
||||
|
||||
try:
|
||||
req.get("http://1.1.1.1")
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
@@ -75,21 +75,24 @@ class SharkImporter:
|
||||
self.module, self.inputs, is_dynamic, tracing_required
|
||||
)
|
||||
|
||||
def _tf_mlir(self, func_name):
|
||||
def _tf_mlir(self, func_name, save_dir="./shark_tmp/"):
|
||||
from iree.compiler import tf as tfc
|
||||
|
||||
return tfc.compile_module(
|
||||
self.module, exported_names=[func_name], import_only=True
|
||||
self.module,
|
||||
exported_names=[func_name],
|
||||
import_only=True,
|
||||
output_file=save_dir,
|
||||
)
|
||||
|
||||
def _tflite_mlir(self, func_name):
|
||||
def _tflite_mlir(self, func_name, save_dir="./shark_tmp/"):
|
||||
from iree.compiler import tflite as tflitec
|
||||
from shark.iree_utils._common import IREE_TARGET_MAP
|
||||
|
||||
self.mlir_model = tflitec.compile_file(
|
||||
self.raw_model_file, # in tflite, it is a path to .tflite file, not a tflite interpreter
|
||||
input_type="tosa",
|
||||
import_only=True,
|
||||
output_file=save_dir,
|
||||
)
|
||||
return self.mlir_model
|
||||
|
||||
@@ -99,6 +102,7 @@ class SharkImporter:
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
func_name="forward",
|
||||
save_dir="./shark_tmp/",
|
||||
):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
if self.inputs == None:
|
||||
@@ -108,15 +112,15 @@ class SharkImporter:
|
||||
sys.exit(1)
|
||||
return self._torch_mlir(is_dynamic, tracing_required), func_name
|
||||
if self.frontend in ["tf", "tensorflow"]:
|
||||
return self._tf_mlir(func_name), func_name
|
||||
return self._tf_mlir(func_name, save_dir), func_name
|
||||
if self.frontend in ["tflite", "tf-lite"]:
|
||||
func_name = "main"
|
||||
return self._tflite_mlir(func_name), func_name
|
||||
return self._tflite_mlir(func_name, save_dir), func_name
|
||||
|
||||
# Converts the frontend specific tensors into np array.
|
||||
def convert_to_numpy(self, array_tuple: tuple):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
return [x.detach().numpy() for x in array_tuple]
|
||||
return [x.detach().cpu().numpy() for x in array_tuple]
|
||||
if self.frontend in ["tf", "tensorflow"]:
|
||||
return [x.numpy() for x in array_tuple]
|
||||
|
||||
@@ -130,19 +134,20 @@ class SharkImporter:
|
||||
outputs_name = "golden_out.npz"
|
||||
func_file_name = "function_name"
|
||||
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
|
||||
try:
|
||||
inputs = [x.cpu().detach() for x in inputs]
|
||||
except AttributeError:
|
||||
try:
|
||||
inputs = [x.numpy() for x in inputs]
|
||||
except AttributeError:
|
||||
inputs = [x for x in inputs]
|
||||
np.savez(os.path.join(dir, inputs_name), *inputs)
|
||||
np.savez(os.path.join(dir, outputs_name), *outputs)
|
||||
np.save(os.path.join(dir, func_file_name), np.array(func_name))
|
||||
|
||||
mlir_str = mlir_data
|
||||
if self.frontend == "torch":
|
||||
mlir_str = mlir_data.operation.get_asm()
|
||||
elif self.frontend == "tf":
|
||||
mlir_str = mlir_data.decode("utf-8")
|
||||
elif self.frontend == "tflite":
|
||||
mlir_str = mlir_data.decode("utf-8")
|
||||
with open(os.path.join(dir, model_name_mlir), "w") as mlir_file:
|
||||
mlir_file.write(mlir_str)
|
||||
with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
|
||||
mlir_file.write(mlir_data)
|
||||
|
||||
return
|
||||
|
||||
@@ -159,9 +164,13 @@ class SharkImporter:
|
||||
f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
|
||||
artifact_path = os.path.join(dir, model_name_mlir)
|
||||
imported_mlir = self.import_mlir(
|
||||
is_dynamic, tracing_required, func_name
|
||||
is_dynamic,
|
||||
tracing_required,
|
||||
func_name,
|
||||
save_dir=artifact_path,
|
||||
)
|
||||
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
|
||||
# TODO: Check for multiple outputs.
|
||||
@@ -171,7 +180,7 @@ class SharkImporter:
|
||||
golden_out = self.module(*self.inputs)
|
||||
if torch.is_tensor(golden_out):
|
||||
golden_out = tuple(
|
||||
golden_out.detach().numpy(),
|
||||
golden_out.detach().cpu().numpy(),
|
||||
)
|
||||
else:
|
||||
golden_out = self.convert_to_numpy(golden_out)
|
||||
@@ -234,3 +243,58 @@ class SharkImporter:
|
||||
self.inputs,
|
||||
golden_out,
|
||||
)
|
||||
|
||||
|
||||
# Applies fx conversion to the model and imports the mlir.
|
||||
def import_with_fx(model, inputs, debug=False):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
|
||||
# TODO: Control the decompositions.
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(*inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
fx_g,
|
||||
inputs,
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
if debug:
|
||||
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
|
||||
return mlir_module, func_name
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir()
|
||||
|
||||
return mlir_module, func_name
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
from shark.iree_utils.compile_utils import (
|
||||
export_iree_module_to_vmfb,
|
||||
load_flatbuffer,
|
||||
create_dispatch_dirs,
|
||||
compile_benchmark_dirs,
|
||||
)
|
||||
import os
|
||||
from shark.shark_runner import SharkRunner
|
||||
@@ -37,7 +39,7 @@ class SharkInference:
|
||||
Attributes
|
||||
----------
|
||||
mlir_module : str
|
||||
mlir_module represented in string.
|
||||
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
|
||||
function_name : str
|
||||
function to execute in the given mlir_module.
|
||||
device : str
|
||||
@@ -63,21 +65,48 @@ class SharkInference:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mlir_module: str,
|
||||
mlir_module: bytes,
|
||||
function_name: str = "forward",
|
||||
device: str = "none",
|
||||
mlir_dialect: str = "linalg",
|
||||
is_benchmark: bool = False,
|
||||
dispatch_benchmark: str = None,
|
||||
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
self.function_name = function_name
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.is_benchmark = is_benchmark
|
||||
self.dispatch_benchmarks = (
|
||||
shark_args.dispatch_benchmarks
|
||||
if dispatch_benchmark is None
|
||||
else dispatch_benchmark
|
||||
)
|
||||
self.dispatch_benchmarks_dir = (
|
||||
shark_args.dispatch_benchmarks_dir
|
||||
if dispatch_benchmark_dir == "temp_dispatch_benchmarks"
|
||||
else dispatch_benchmark_dir
|
||||
)
|
||||
|
||||
self.shark_runner = None
|
||||
|
||||
def compile(self):
|
||||
def compile(self, extra_args=[]):
|
||||
|
||||
if self.dispatch_benchmarks is not None:
|
||||
extra_args.append(
|
||||
f"--iree-hal-dump-executable-sources-to={self.dispatch_benchmarks_dir}"
|
||||
)
|
||||
extra_args.append(
|
||||
f"--iree-hal-dump-executable-binaries-to={self.dispatch_benchmarks_dir}"
|
||||
)
|
||||
temp_dir = self.dispatch_benchmarks_dir.split("/")
|
||||
temp_dir[-1] = "temp_" + temp_dir[-1]
|
||||
temp_dir = "/".join(temp_dir)
|
||||
self.temp_dispatch_benchmarks_dir = temp_dir
|
||||
extra_args.append(
|
||||
f"--iree-hal-dump-executable-benchmarks-to={self.temp_dispatch_benchmarks_dir}"
|
||||
)
|
||||
|
||||
if self.is_benchmark == True:
|
||||
from shark.shark_benchmark_runner import SharkBenchmarkRunner
|
||||
@@ -87,6 +116,7 @@ class SharkInference:
|
||||
self.function_name,
|
||||
self.device,
|
||||
self.mlir_dialect,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -95,8 +125,18 @@ class SharkInference:
|
||||
self.function_name,
|
||||
self.device,
|
||||
self.mlir_dialect,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
if self.dispatch_benchmarks is not None:
|
||||
create_dispatch_dirs(self.dispatch_benchmarks_dir, self.device)
|
||||
compile_benchmark_dirs(
|
||||
self.dispatch_benchmarks_dir,
|
||||
self.device,
|
||||
self.dispatch_benchmarks,
|
||||
)
|
||||
os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")
|
||||
|
||||
# inputs are considered to be tuple of np.array.
|
||||
def forward(self, inputs: tuple):
|
||||
return self.shark_runner.run(inputs)
|
||||
@@ -144,21 +184,24 @@ class SharkInference:
|
||||
|
||||
# TODO: Instead of passing directory and having names decided by the module
|
||||
# , user may want to save the module with manual names.
|
||||
def save_module(self, dir=os.getcwd()):
|
||||
def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
|
||||
return export_iree_module_to_vmfb(
|
||||
self.mlir_module,
|
||||
self.device,
|
||||
dir,
|
||||
self.mlir_dialect,
|
||||
self.function_name,
|
||||
module_name=module_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
# load and return the module.
|
||||
def load_module(self, path):
|
||||
def load_module(self, path, extra_args=[]):
|
||||
self.shark_runner = SharkRunner(
|
||||
function_name=self.function_name,
|
||||
device=self.device,
|
||||
compile_vmfb=False,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
(
|
||||
self.shark_runner.iree_compilation_module,
|
||||
|
||||
@@ -61,19 +61,21 @@ class SharkRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mlir_module: str = "none",
|
||||
mlir_module: bytes = None,
|
||||
function_name: str = "forward",
|
||||
device: str = "none",
|
||||
mlir_dialect: str = "linalg",
|
||||
extra_args: list = [],
|
||||
compile_vmfb: bool = True,
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
self.function_name = function_name
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
|
||||
if check_device_drivers(self.device):
|
||||
device_driver_info(self.device)
|
||||
print(device_driver_info(self.device))
|
||||
sys.exit(1)
|
||||
|
||||
if compile_vmfb == True:
|
||||
@@ -86,6 +88,7 @@ class SharkRunner:
|
||||
self.device,
|
||||
self.mlir_dialect,
|
||||
func_name=self.function_name,
|
||||
extra_args=self.extra_args,
|
||||
)
|
||||
|
||||
def run(self, inputs: tuple):
|
||||
|
||||
@@ -17,6 +17,7 @@ import torch_mlir
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
import tempfile
|
||||
from shark.parser import shark_args
|
||||
import io
|
||||
|
||||
|
||||
def get_module_name_for_asm_dump(module):
|
||||
@@ -55,9 +56,8 @@ def get_torch_mlir_module(
|
||||
input: tuple,
|
||||
dynamic: bool,
|
||||
jit_trace: bool,
|
||||
from_torchscript: bool = False,
|
||||
):
|
||||
"""Get the MLIR's linalg-on-tensors module from torchscipt module."""
|
||||
"""Get the MLIR's linalg-on-tensors module from the torchscipt module."""
|
||||
ignore_traced_shapes = False
|
||||
if dynamic:
|
||||
input = create_dynamic_placeholders(input)
|
||||
@@ -66,11 +66,14 @@ def get_torch_mlir_module(
|
||||
|
||||
tempfile.tempdir = shark_args.repro_dir
|
||||
|
||||
module = torch_mlir.compile(
|
||||
mlir_module = torch_mlir.compile(
|
||||
module,
|
||||
input,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=jit_trace,
|
||||
ignore_traced_shapes=ignore_traced_shapes,
|
||||
)
|
||||
return module
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
return bytecode
|
||||
|
||||
210
tank/README.md
210
tank/README.md
@@ -1,3 +1,211 @@
|
||||
## Supported and Validated Models
|
||||
|
||||
### PyTorch HuggingFace Models
|
||||
|
||||
| PyTorch Language Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Albert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| BigBird | :green_heart: (AOT) | | | |
|
||||
| dbmdz/ConvBERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :broken_heart: (JIT) | | | |
|
||||
| GPT2 | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| MobileBert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| microsoft/beit | :green_heart: | :green_heart: | :broken_heart: | :broken_heart: |
|
||||
| facebook/deit | :green_heart: | :green_heart: | :broken_heart: | :broken_heart: |
|
||||
| facebook/convnext | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
### Torchvision Models
|
||||
|
||||
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|--------------------|----------------------|----------|----------|-------------|
|
||||
| AlexNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| MobileNetV2 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| MobileNetV3 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Unet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnet18 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnet50 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnet101 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnext50_32x4d | :green_heart: (Script) | | | |
|
||||
| SqueezeNet | :green_heart: (Script) | :green_heart: | :broken_heart: | :broken_heart: |
|
||||
| EfficientNet | :green_heart: (Script) | | | |
|
||||
| Regnet | :green_heart: (Script) | | | |
|
||||
| Resnest | :broken_heart: (Script) | | | |
|
||||
| Vision Transformer | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| VGG 16 | :green_heart: (Script) | :green_heart: | :green_heart: | |
|
||||
| Wide Resnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| RAFT | :broken_heart: (JIT) | | | |
|
||||
|
||||
For more information refer to [MODEL TRACKING SHEET](https://docs.google.com/spreadsheets/d/15PcjKeHZIrB5LfDyuw7DGEEE8XnQEX2aX8lm8qbxV8A/edit#gid=0)
|
||||
|
||||
### Tensorflow Models (Inference)
|
||||
|
||||
| Hugging Face Models | tf-mhlo lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| MiniLM | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| albert-base-v2 | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| CamemBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ConvBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Deberta | | | | |
|
||||
| electra | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| funnel | | | | |
|
||||
| layoutlm | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| longformer | | | | |
|
||||
| mobile-bert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| rembert | | | | |
|
||||
| tapas | | | | |
|
||||
| flaubert | :broken_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| xlm-roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| mpnet | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
### PyTorch Training Models
|
||||
|
||||
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
|
||||
### JAX Models
|
||||
|
||||
| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| DALL-E | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
|
||||
<details>
|
||||
<summary>TFLite Models</summary>
|
||||
|
||||
### TFLite Models
|
||||
|
||||
| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
| albert | :green_heart: | :green_heart: | | |
|
||||
| asr_conformer | :green_heart: | :green_heart: | | |
|
||||
| bird_classifier | :green_heart: | :green_heart: | | |
|
||||
| cartoon_gan | :green_heart: | :green_heart: | | |
|
||||
| craft_text | :green_heart: | :green_heart: | | |
|
||||
| deeplab_v3 | :green_heart: | :green_heart: | | |
|
||||
| densenet | :green_heart: | :green_heart: | | |
|
||||
| east_text_detector | :green_heart: | :green_heart: | | |
|
||||
| efficientnet_lite0_int8 | :green_heart: | :green_heart: | | |
|
||||
| efficientnet | :green_heart: | :green_heart: | | |
|
||||
| gpt2 | :green_heart: | :green_heart: | | |
|
||||
| image_stylization | :green_heart: | :green_heart: | | |
|
||||
| inception_v4 | :green_heart: | :green_heart: | | |
|
||||
| inception_v4_uint8 | :green_heart: | :green_heart: | | |
|
||||
| lightning_fp16 | :green_heart: | :green_heart: | | |
|
||||
| lightning_i8 | :green_heart: | :green_heart: | | |
|
||||
| lightning | :green_heart: | :green_heart: | | |
|
||||
| magenta | :green_heart: | :green_heart: | | |
|
||||
| midas | :green_heart: | :green_heart: | | |
|
||||
| mirnet | :green_heart: | :green_heart: | | |
|
||||
| mnasnet | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_edgetpu_s_float | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_edgetpu_s_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilebert | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_tf2_float | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_tf2_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_ssd_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v1 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v3-large | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v3-large_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v35-int8 | :green_heart: | :green_heart: | | |
|
||||
| nasnet | :green_heart: | :green_heart: | | |
|
||||
| person_detect | :green_heart: | :green_heart: | | |
|
||||
| posenet | :green_heart: | :green_heart: | | |
|
||||
| resnet_50_int8 | :green_heart: | :green_heart: | | |
|
||||
| rosetta | :green_heart: | :green_heart: | | |
|
||||
| spice | :green_heart: | :green_heart: | | |
|
||||
| squeezenet | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v1 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_fpnlite | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_fpnlite_uint8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2 | :green_heart: | :green_heart: | | |
|
||||
| ssd_spaghettinet_large | :green_heart: | :green_heart: | | |
|
||||
| ssd_spaghettinet_large_uint8 | :green_heart: | :green_heart: | | |
|
||||
| visual_wake_words_i8 | :green_heart: | :green_heart: | | |
|
||||
|
||||
</details>
|
||||
|
||||
## Testing and Benchmarks
|
||||
|
||||
### Run all model tests on CPU/GPU/VULKAN/Metal
|
||||
|
||||
For a list of models included in our pytest model suite, see https://github.com/nod-ai/SHARK/blob/main/tank/all_models.csv
|
||||
|
||||
```shell
|
||||
pytest tank/test_models.py
|
||||
|
||||
# Models included in the pytest suite can be found listed in all_models.csv.
|
||||
|
||||
# If on Linux for multithreading on CPU (faster results):
|
||||
pytest tank/test_models.py -n auto
|
||||
```
|
||||
|
||||
### Running specific tests
|
||||
```shell
|
||||
|
||||
# Search for test cases by including a keyword that matches all or part of the test case's name;
|
||||
pytest tank/test_models.py -k "keyword"
|
||||
|
||||
# Test cases are named uniformly by format test_module_<model_name_underscores_only>_<torch/tf>_<static/dynamic>_<device>.
|
||||
|
||||
# Example: Test all models on nvidia gpu:
|
||||
pytest tank/test_models.py -k "cuda"
|
||||
|
||||
# Example: Test all tensorflow resnet models on Vulkan backend:
|
||||
pytest tank/test_models.py -k "resnet and tf and vulkan"
|
||||
|
||||
# Exclude a test case:
|
||||
pytest tank/test_models.py -k "not ..."
|
||||
|
||||
### Run benchmarks on SHARK tank pytests and generate bench_results.csv with results.
|
||||
|
||||
(the following requires source installation with `IMPORTER=1 ./setup_venv.sh`)
|
||||
|
||||
```shell
|
||||
pytest --benchmark tank/test_models.py
|
||||
|
||||
# Just do static GPU benchmarks for PyTorch tests:
|
||||
pytest --benchmark tank/test_models.py -k "pytorch and static and cuda"
|
||||
|
||||
```
|
||||
|
||||
### Benchmark Resnet50, MiniLM on CPU
|
||||
|
||||
(requires source installation with `IMPORTER=1 ./setup_venv.sh`)
|
||||
|
||||
```shell
|
||||
# We suggest running the following commands as root before running benchmarks on CPU:
|
||||
|
||||
cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list | awk -F, '{print $2}' | sort -n | uniq | ( while read X ; do echo $X ; echo 0 > /sys/devices/system/cpu/cpu$X/online ; done )
|
||||
echo 1 > /sys/devices/system/cpu/intel_pstate/no_turbo
|
||||
|
||||
# Benchmark canonical Resnet50 on CPU via pytest
|
||||
pytest --benchmark tank/test_models.py -k "resnet50 and tf_static_cpu"
|
||||
|
||||
# Benchmark canonical MiniLM on CPU via pytest
|
||||
pytest --benchmark tank/test_models.py -k "MiniLM and cpu"
|
||||
|
||||
# Benchmark MiniLM on CPU via transformer-benchmarks:
|
||||
git clone --recursive https://github.com/nod-ai/transformer-benchmarks.git
|
||||
cd transformer-benchmarks
|
||||
./perf-ci.sh -n
|
||||
# Check detail.csv for MLIR/IREE results.
|
||||
|
||||
```
|
||||
|
||||
To run the fine tuning example, from the root SHARK directory, run:
|
||||
|
||||
```shell
|
||||
@@ -11,3 +219,5 @@ if running from a google vm, you can view jupyter notebooks on your local system
|
||||
gcloud compute ssh <YOUR_INSTANCE_DETAILS> --ssh-flag="-N -L localhost:8888:localhost:8888"
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,34 +1,35 @@
|
||||
resnet50,mhlo,tf,1e-02,1e-3,default
|
||||
albert-base-v2,mhlo,tf,1e-02,1e-3,default
|
||||
roberta-base,mhlo,tf,1e-02,1e-3,default
|
||||
bert-base-uncased,mhlo,tf,1e-2,1e-3,default
|
||||
camembert-base,mhlo,tf,1e-2,1e-3,default
|
||||
dbmdz/convbert-base-turkish-cased,mhlo,tf,1e-2,1e-3,default
|
||||
distilbert-base-uncased,mhlo,tf,1e-2,1e-3,default
|
||||
facebook/convnext-tiny-224,mhlo,tf,1e-2,1e-3,tf_vit
|
||||
funnel-transformer/small,mhlo,tf,1e-2,1e-3,default
|
||||
google/electra-small-discriminator,mhlo,tf,1e-2,1e-3,default
|
||||
google/mobilebert-uncased,mhlo,tf,1e-2,1e-3,default
|
||||
google/vit-base-patch16-224,mhlo,tf,1e-2,1e-3,tf_vit
|
||||
hf-internal-testing/tiny-random-flaubert,mhlo,tf,1e-2,1e-3,default
|
||||
microsoft/MiniLM-L12-H384-uncased,mhlo,tf,1e-2,1e-3,tf_hf
|
||||
microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default
|
||||
microsoft/mpnet-base,mhlo,tf,1e-2,1e-3,default
|
||||
albert-base-v2,linalg,torch,1e-2,1e-3,default
|
||||
alexnet,linalg,torch,1e-2,1e-3,default
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default
|
||||
distilbert-base-uncased,linalg,torch,1e-2,1e-3,default
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default
|
||||
microsoft/resnet-50,linalg,torch,1e-2,1e-3,default
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default
|
||||
mobilenet_v3_small,linalg,torch,1e-2,1e-3,default
|
||||
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default
|
||||
resnet101,linalg,torch,1e-2,1e-3,default
|
||||
resnet18,linalg,torch,1e-2,1e-3,default
|
||||
resnet50,linalg,torch,1e-2,1e-3,default
|
||||
squeezenet1_0,linalg,torch,1e-2,1e-3,default
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default
|
||||
resnet50,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc
|
||||
albert-base-v2,mhlo,tf,1e-2,1e-2,default,None
|
||||
roberta-base,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc
|
||||
bert-base-uncased,mhlo,tf,1e-2,1e-3,default,None
|
||||
camembert-base,mhlo,tf,1e-2,1e-3,default,None
|
||||
dbmdz/convbert-base-turkish-cased,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc
|
||||
distilbert-base-uncased,mhlo,tf,1e-2,1e-3,default,None
|
||||
facebook/convnext-tiny-224,mhlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,
|
||||
funnel-transformer/small,mhlo,tf,1e-2,1e-3,default,None
|
||||
google/electra-small-discriminator,mhlo,tf,1e-2,1e-3,default,None
|
||||
google/mobilebert-uncased,mhlo,tf,1e-2,1e-3,default,None
|
||||
google/vit-base-patch16-224,mhlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc
|
||||
hf-internal-testing/tiny-random-flaubert,mhlo,tf,1e-2,1e-3,default,None
|
||||
microsoft/MiniLM-L12-H384-uncased,mhlo,tf,1e-2,1e-3,tf_hf,None
|
||||
microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default,None
|
||||
microsoft/mpnet-base,mhlo,tf,1e-2,1e-2,default,None
|
||||
albert-base-v2,linalg,torch,1e-2,1e-3,default,None
|
||||
alexnet,linalg,torch,1e-2,1e-3,default,None
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None
|
||||
microsoft/resnet-50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None
|
||||
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc
|
||||
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None
|
||||
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc
|
||||
resnet18,linalg,torch,1e-2,1e-3,default,None
|
||||
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc
|
||||
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc
|
||||
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc
|
||||
|
||||
|
@@ -32,7 +32,7 @@ class BertModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def predict(self, input_word_ids, input_mask, segment_ids):
|
||||
return self.m.predict(input_word_ids, input_mask, segment_ids)
|
||||
|
||||
@@ -33,7 +33,7 @@ class BertModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def predict(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
@@ -52,7 +52,7 @@ class SeqClassification(tf.Module):
|
||||
)
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
|
||||
|
||||
@tf.function(input_signature=inputs_signature)
|
||||
@tf.function(input_signature=inputs_signature, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return tf.math.softmax(
|
||||
self.m.predict(input_ids, attention_mask), axis=-1
|
||||
@@ -1,8 +1,9 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"bert-base-uncased_tosa"
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"bert-base-uncased_tosa",
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
@@ -72,7 +72,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
@@ -60,7 +60,8 @@ class BertModule(tf.Module):
|
||||
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
|
||||
), # input2: segment_ids
|
||||
tf.TensorSpec([BATCH_SIZE], tf.int32), # input3: labels
|
||||
]
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def learn(self, input_word_ids, input_mask, segment_ids, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
@@ -75,7 +76,7 @@ class BertModule(tf.Module):
|
||||
self.optimizer.apply_gradients(zip(gradients, variables))
|
||||
return loss
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def predict(self, input_word_ids, input_mask, segment_ids):
|
||||
inputs = [input_word_ids, input_mask, segment_ids]
|
||||
return self.m.predict(inputs)
|
||||
@@ -57,7 +57,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
@@ -50,7 +50,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
@@ -57,7 +57,8 @@ class BertModule(tf.Module):
|
||||
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
|
||||
), # input2: segment_ids
|
||||
tf.TensorSpec([BATCH_SIZE], tf.int32), # input3: labels
|
||||
]
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def learn(self, input_word_ids, input_mask, segment_ids, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
@@ -72,7 +73,7 @@ class BertModule(tf.Module):
|
||||
self.optimizer.apply_gradients(zip(gradients, variables))
|
||||
return loss
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def predict(self, input_word_ids, input_mask, segment_ids):
|
||||
inputs = [input_word_ids, input_mask, segment_ids]
|
||||
return self.m.predict(inputs)
|
||||
@@ -53,7 +53,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
@@ -46,7 +46,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
83
tank/examples/bert_tf/seq_classification.py
Executable file
83
tank/examples/bert_tf/seq_classification.py
Executable file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python
|
||||
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
import argparse
|
||||
|
||||
|
||||
seq_parser = argparse.ArgumentParser(
|
||||
description="Shark Sequence Classification."
|
||||
)
|
||||
seq_parser.add_argument(
|
||||
"--hf_model_name",
|
||||
type=str,
|
||||
default="bert-base-uncased",
|
||||
help="Hugging face model to run sequence classification.",
|
||||
)
|
||||
|
||||
seq_args, unknown = seq_parser.parse_known_args()
|
||||
|
||||
|
||||
BATCH_SIZE = 1
|
||||
MAX_SEQUENCE_LENGTH = 16
|
||||
|
||||
# Create a set of input signature.
|
||||
inputs_signature = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForSequenceClassification
|
||||
|
||||
|
||||
def preprocess_input(text="This is just used to compile the model"):
|
||||
tokenizer = AutoTokenizer.from_pretrained(seq_args.hf_model_name)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
return_tensors="tf",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
class SeqClassification(tf.Module):
|
||||
def __init__(self, model_name):
|
||||
super(SeqClassification, self).__init__()
|
||||
self.m = TFAutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, output_attentions=False, num_labels=2
|
||||
)
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
|
||||
|
||||
@tf.function(input_signature=inputs_signature, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return tf.math.softmax(
|
||||
self.m.predict(input_ids, attention_mask), axis=-1
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
inputs = preprocess_input()
|
||||
shark_module = SharkInference(
|
||||
SeqClassification(seq_args.hf_model_name),
|
||||
(inputs["input_ids"], inputs["attention_mask"]),
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
print(f"Model has been successfully compiled on {shark_args.device}")
|
||||
|
||||
while True:
|
||||
input_text = input(
|
||||
"Enter the text to classify (press q or nothing to exit): "
|
||||
)
|
||||
if not input_text or input_text == "q":
|
||||
break
|
||||
inputs = preprocess_input(input_text)
|
||||
print(
|
||||
shark_module.forward(
|
||||
(inputs["input_ids"], inputs["attention_mask"])
|
||||
)
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.parser import shark_args
|
||||
from tank.test_utils import get_valid_test_params, shark_test_name_func
|
||||
from parameterized import parameterized
|
||||
@@ -21,8 +21,8 @@ class DebertaBaseModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, func_name, inputs, golden_out = download_tf_model(
|
||||
"microsoft/deberta-base"
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
"microsoft/deberta-base", frontend="tf"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
@@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from shark.shark_downloader import download_tflite_model
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.shark_inference import SharkInference
|
||||
import pytest
|
||||
import unittest
|
||||
@@ -58,8 +58,8 @@ class GptTfliteModuleTester:
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
|
||||
# Preprocess to get SharkImporter input args
|
||||
mlir_model, func_name, inputs, tflite_results = download_tflite_model(
|
||||
model_name="gpt2-64"
|
||||
mlir_model, func_name, inputs, tflite_results = download_model(
|
||||
model_name="gpt2-64", backend="tflite"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module=mlir_model,
|
||||
@@ -20,10 +20,6 @@ class OPTModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device, model_name):
|
||||
# model_mlir, func_name, input, act_out = download_torch_model(
|
||||
# "opt", dynamic
|
||||
# )
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
||||
# config = OPTConfig()
|
||||
# opt_model = OPTModel(config)
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
from shark.shark_downloader import download_model
|
||||
from tank.test_utils import get_valid_test_params, shark_test_name_func
|
||||
from parameterized import parameterized
|
||||
|
||||
@@ -18,8 +18,8 @@ class RemBertModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, func_name, inputs, golden_out = download_tf_model(
|
||||
"google/rembert"
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
"google/rembert", frontend="tf"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
import iree.compiler as ireec
|
||||
import unittest
|
||||
@@ -16,8 +16,9 @@ class TapasBaseModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, func_name, inputs, golden_out = download_tf_model(
|
||||
"google/tapas-base"
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
"google/tapas-base",
|
||||
frontend="tf",
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
@@ -15,7 +15,7 @@ from torchvision.transforms import functional as TF
|
||||
from tqdm import trange
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
import numpy as np
|
||||
|
||||
import sys
|
||||
@@ -191,7 +191,9 @@ x_in = x[0:min_batch_size, :, :, :]
|
||||
ts = x_in.new_ones([x_in.shape[0]])
|
||||
t_in = t[0] * ts
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"v_diffusion", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.runtime_device, mlir_dialect="linalg"
|
||||
@@ -27,3 +27,5 @@ microsoft/mpnet-base,False,False,-,-,-
|
||||
roberta-base,False,False,-,-,-
|
||||
xlm-roberta-base,False,False,-,-,-
|
||||
facebook/convnext-tiny-224,False,False,-,-,-
|
||||
efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv"
|
||||
mnasnet1_0,False,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
|
||||
|
@@ -15,6 +15,7 @@ vision_models = [
|
||||
"squeezenet1_0",
|
||||
"wide_resnet50_2",
|
||||
"mobilenet_v3_small",
|
||||
"mnasnet1_0",
|
||||
]
|
||||
hf_img_cls_models = [
|
||||
"google/vit-base-patch16-224",
|
||||
@@ -142,13 +143,14 @@ def get_vision_model(torch_model):
|
||||
import torchvision.models as models
|
||||
|
||||
vision_models_dict = {
|
||||
"alexnet": models.alexnet(pretrained=True),
|
||||
"resnet18": models.resnet18(pretrained=True),
|
||||
"resnet50": models.resnet50(pretrained=True),
|
||||
"resnet101": models.resnet101(pretrained=True),
|
||||
"squeezenet1_0": models.squeezenet1_0(pretrained=True),
|
||||
"wide_resnet50_2": models.wide_resnet50_2(pretrained=True),
|
||||
"mobilenet_v3_small": models.mobilenet_v3_small(pretrained=True),
|
||||
"alexnet": models.alexnet(weights="DEFAULT"),
|
||||
"resnet18": models.resnet18(weights="DEFAULT"),
|
||||
"resnet50": models.resnet50(weights="DEFAULT"),
|
||||
"resnet101": models.resnet101(weights="DEFAULT"),
|
||||
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
|
||||
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
|
||||
"mobilenet_v3_small": models.mobilenet_v3_small(weights="DEFAULT"),
|
||||
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
|
||||
}
|
||||
if isinstance(torch_model, str):
|
||||
torch_model = vision_models_dict[torch_model]
|
||||
@@ -160,6 +162,8 @@ def get_vision_model(torch_model):
|
||||
|
||||
################################################################################
|
||||
|
||||
####################### Other PyTorch HF Models ###############################
|
||||
|
||||
# Utility function for comparing two tensors (torch).
|
||||
def compare_tensors(torch_tensor, numpy_tensor, rtol=1e-02, atol=1e-03):
|
||||
# torch_to_numpy = torch_tensor.detach().numpy()
|
||||
|
||||
@@ -6,24 +6,12 @@ from transformers import (
|
||||
TFBertModel,
|
||||
)
|
||||
|
||||
visible_default = tf.config.list_physical_devices("GPU")
|
||||
try:
|
||||
tf.config.set_visible_devices([], "GPU")
|
||||
visible_devices = tf.config.get_visible_devices()
|
||||
for device in visible_devices:
|
||||
assert device.device_type != "GPU"
|
||||
except:
|
||||
# Invalid device or cannot modify virtual devices once initialized.
|
||||
pass
|
||||
|
||||
BATCH_SIZE = 1
|
||||
MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
################################## MHLO/TF models #########################################
|
||||
# TODO : Generate these lists or fetch model source from tank/tf/tf_model_list.csv
|
||||
keras_models = [
|
||||
"resnet50",
|
||||
]
|
||||
keras_models = ["resnet50", "efficientnet-v2-s"]
|
||||
maskedlm_models = [
|
||||
"albert-base-v2",
|
||||
"bert-base-uncased",
|
||||
@@ -87,7 +75,7 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=tf_bert_input)
|
||||
@tf.function(input_signature=tf_bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
@@ -162,7 +150,7 @@ class MaskedLM(tf.Module):
|
||||
)
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
|
||||
|
||||
@tf.function(input_signature=input_signature_maskedlm)
|
||||
@tf.function(input_signature=input_signature_maskedlm, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
@@ -178,42 +166,81 @@ def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
##################### TensorFlow Keras Resnet Models #########################################################
|
||||
# Static shape, including batch size (1).
|
||||
# Can be dynamic once dynamic shape support is ready.
|
||||
INPUT_SHAPE = [1, 224, 224, 3]
|
||||
|
||||
tf_model = tf.keras.applications.resnet50.ResNet50(
|
||||
weights="imagenet", include_top=True, input_shape=tuple(INPUT_SHAPE[1:])
|
||||
)
|
||||
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
|
||||
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
|
||||
|
||||
|
||||
class ResNetModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(ResNetModule, self).__init__()
|
||||
self.m = tf_model
|
||||
self.m = tf.keras.applications.resnet50.ResNet50(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(RESNET_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])
|
||||
@tf.function(
|
||||
input_signature=[tf.TensorSpec(RESNET_INPUT_SHAPE, tf.float32)],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
def input_shape(self):
|
||||
return RESNET_INPUT_SHAPE
|
||||
|
||||
def load_image(path_to_image):
|
||||
def preprocess_input(self, image):
|
||||
return tf.keras.applications.resnet50.preprocess_input(image)
|
||||
|
||||
|
||||
class EfficientNetModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(EfficientNetModule, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[tf.TensorSpec(EFFICIENTNET_INPUT_SHAPE, tf.float32)],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
def input_shape(self):
|
||||
return EFFICIENTNET_INPUT_SHAPE
|
||||
|
||||
def preprocess_input(self, image):
|
||||
return tf.keras.applications.efficientnet_v2.preprocess_input(image)
|
||||
|
||||
|
||||
def load_image(path_to_image, width, height, channels):
|
||||
image = tf.io.read_file(path_to_image)
|
||||
image = tf.image.decode_image(image, channels=3)
|
||||
image = tf.image.resize(image, (224, 224))
|
||||
image = tf.image.decode_image(image, channels=channels)
|
||||
image = tf.image.resize(image, (width, height))
|
||||
image = image[tf.newaxis, :]
|
||||
return image
|
||||
|
||||
|
||||
def get_keras_model(modelname):
|
||||
model = ResNetModule()
|
||||
if modelname == "efficientnet-v2-s":
|
||||
model = EfficientNetModule()
|
||||
else:
|
||||
model = ResNetModule()
|
||||
|
||||
content_path = tf.keras.utils.get_file(
|
||||
"YellowLabradorLooking_new.jpg",
|
||||
"https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg",
|
||||
)
|
||||
content_image = load_image(content_path)
|
||||
input_tensor = tf.keras.applications.resnet50.preprocess_input(
|
||||
content_image
|
||||
input_shape = model.input_shape()
|
||||
content_image = load_image(
|
||||
content_path, input_shape[1], input_shape[2], input_shape[3]
|
||||
)
|
||||
input_tensor = model.preprocess_input(content_image)
|
||||
input_data = tf.expand_dims(input_tensor, 0)
|
||||
actual_out = model.forward(*input_data)
|
||||
return model, input_data, actual_out
|
||||
@@ -240,7 +267,7 @@ class AutoModelImageClassfication(tf.Module):
|
||||
)
|
||||
self.m.predict = lambda x: self.m(x)
|
||||
|
||||
@tf.function(input_signature=input_signature_img_cls)
|
||||
@tf.function(input_signature=input_signature_img_cls, jit_compile=True)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
from shark.iree_utils._common import (
|
||||
check_device_drivers,
|
||||
device_driver_info,
|
||||
IREE_DEVICE_MAP,
|
||||
get_supported_device_list,
|
||||
)
|
||||
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
|
||||
from parameterized import parameterized
|
||||
from shark.shark_downloader import (
|
||||
download_tf_model,
|
||||
download_torch_model,
|
||||
download_tflite_model,
|
||||
)
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
import iree.compiler as ireec
|
||||
@@ -20,6 +16,7 @@ import csv
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
import multiprocessing
|
||||
|
||||
|
||||
def load_csv_and_convert(filename, gen=False):
|
||||
@@ -41,6 +38,7 @@ def load_csv_and_convert(filename, gen=False):
|
||||
"rtol": float(row[3]),
|
||||
"atol": float(row[4]),
|
||||
"out_type": row[5],
|
||||
"flags": row[6],
|
||||
}
|
||||
)
|
||||
# This is a pytest workaround
|
||||
@@ -59,7 +57,7 @@ def get_valid_test_params():
|
||||
"""
|
||||
device_list = [
|
||||
device
|
||||
for device in IREE_DEVICE_MAP.keys()
|
||||
for device in get_supported_device_list()
|
||||
if not check_device_drivers(device)
|
||||
]
|
||||
dynamic_list = (True, False)
|
||||
@@ -130,24 +128,19 @@ class SharkModuleTester:
|
||||
self.config = config
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
|
||||
shark_args.local_tank_cache = self.local_tank_cache
|
||||
if self.config["framework"] == "tf":
|
||||
model, func_name, inputs, golden_out = download_tf_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
)
|
||||
elif self.config["framework"] == "torch":
|
||||
model, func_name, inputs, golden_out = download_torch_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
)
|
||||
elif self.config["framework"] == "tflite":
|
||||
model, func_name, inputs, golden_out = download_tflite_model(
|
||||
model_name=self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
)
|
||||
else:
|
||||
model, func_name, inputs, golden_out = None, None, None, None
|
||||
shark_args.update_tank = self.update_tank
|
||||
if "nhcw-nhwc" in self.config["flags"] and not os.path.isfile(
|
||||
".use-iree"
|
||||
):
|
||||
shark_args.enable_conv_transform = True
|
||||
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
frontend=self.config["framework"],
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
@@ -240,6 +233,16 @@ class SharkModuleTester:
|
||||
return expected, logits
|
||||
|
||||
|
||||
def run_test(module_tester, dynamic, device):
|
||||
tempdir = tempfile.TemporaryDirectory(
|
||||
prefix=module_tester.tmp_prefix, dir="./shark_tmp/"
|
||||
)
|
||||
module_tester.temp_dir = tempdir.name
|
||||
|
||||
with ireec.tools.TempFileSaver(tempdir.name):
|
||||
module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
class SharkModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
@@ -266,21 +269,18 @@ class SharkModuleTest(unittest.TestCase):
|
||||
self.module_tester.local_tank_cache = self.pytestconfig.getoption(
|
||||
"local_tank_cache"
|
||||
)
|
||||
self.module_tester.update_tank = self.pytestconfig.getoption(
|
||||
"update_tank"
|
||||
)
|
||||
self.module_tester.tank_url = self.pytestconfig.getoption("tank_url")
|
||||
if (
|
||||
config["model_name"] == "distilbert-base-uncased"
|
||||
and config["framework"] == "torch"
|
||||
):
|
||||
pytest.xfail(reason="https://github.com/nod-ai/SHARK/issues/354")
|
||||
if (
|
||||
config["model_name"] == "facebook/convnext-tiny-224"
|
||||
and device == "cuda"
|
||||
):
|
||||
pytest.xfail(reason="https://github.com/nod-ai/SHARK/issues/311")
|
||||
if (
|
||||
config["model_name"] == "google/vit-base-patch16-224"
|
||||
and device == "cuda"
|
||||
):
|
||||
if config["model_name"] == "efficientnet-v2-s" and device in [
|
||||
"metal",
|
||||
"vulkan",
|
||||
]:
|
||||
pytest.xfail(reason="https://github.com/nod-ai/SHARK/issues/575")
|
||||
if config[
|
||||
"model_name"
|
||||
] == "google/vit-base-patch16-224" and device in ["cuda"]:
|
||||
pytest.xfail(reason="https://github.com/nod-ai/SHARK/issues/311")
|
||||
if config["model_name"] == "resnet50" and device in [
|
||||
"metal",
|
||||
@@ -291,8 +291,6 @@ class SharkModuleTest(unittest.TestCase):
|
||||
pytest.xfail(
|
||||
reason="M2: Assert Error & M1: CompilerToolError"
|
||||
)
|
||||
if config["model_name"] == "google/rembert":
|
||||
pytest.skip(reason="Model too large to convert.")
|
||||
if config[
|
||||
"model_name"
|
||||
] == "dbmdz/convbert-base-turkish-cased" and device in [
|
||||
@@ -311,7 +309,6 @@ class SharkModuleTest(unittest.TestCase):
|
||||
reason="https://github.com/nod-ai/SHARK/issues/311, https://github.com/nod-ai/SHARK/issues/342"
|
||||
)
|
||||
if config["model_name"] == "funnel-transformer/small" and device in [
|
||||
"cpu",
|
||||
"cuda",
|
||||
"metal",
|
||||
"vulkan",
|
||||
@@ -319,13 +316,6 @@ class SharkModuleTest(unittest.TestCase):
|
||||
pytest.xfail(
|
||||
reason="failing in the iree-compiler passes, see https://github.com/nod-ai/SHARK/issues/201"
|
||||
)
|
||||
if (
|
||||
config["model_name"] == "google/vit-base-patch16-224"
|
||||
and device == "cuda"
|
||||
):
|
||||
pytest.xfail(reason="https://github.com/nod-ai/SHARK/issues/311")
|
||||
if config["model_name"] == "microsoft/mpnet-base":
|
||||
pytest.xfail(reason="https://github.com/nod-ai/SHARK/issues/203")
|
||||
if config["model_name"] == "nvidia/mit-b0":
|
||||
pytest.xfail(reason="https://github.com/nod-ai/SHARK/issues/343")
|
||||
if (
|
||||
@@ -356,13 +346,74 @@ class SharkModuleTest(unittest.TestCase):
|
||||
pytest.xfail(
|
||||
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
|
||||
)
|
||||
if config["model_name"] == "mobilenet_v3_small" and device in [
|
||||
"metal",
|
||||
"vulkan",
|
||||
if config["model_name"] == "mobilenet_v3_small" and device not in [
|
||||
"cpu"
|
||||
]:
|
||||
pytest.xfail(
|
||||
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
|
||||
)
|
||||
if config["model_name"] == "mnasnet1_0" and device not in [
|
||||
"cpu",
|
||||
"cuda",
|
||||
]:
|
||||
pytest.xfail(
|
||||
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
|
||||
)
|
||||
if config["model_name"] == "hf-internal-testing/tiny-random-flaubert":
|
||||
pytest.xfail(reason="Transformers API mismatch")
|
||||
if config["model_name"] == "alexnet" and device in ["metal", "vulkan"]:
|
||||
pytest.xfail(reason="Assertion Error: Zeros Output")
|
||||
if (
|
||||
config["model_name"] == "camembert-base"
|
||||
and dynamic == False
|
||||
and device in ["metal", "vulkan"]
|
||||
):
|
||||
pytest.xfail(
|
||||
reason="chlo.broadcast_compare failed to satify constraint"
|
||||
)
|
||||
if (
|
||||
config["model_name"] == "roberta-base"
|
||||
and dynamic == False
|
||||
and device in ["metal", "vulkan"]
|
||||
):
|
||||
pytest.xfail(
|
||||
reason="chlo.broadcast_compare failed to satify constraint"
|
||||
)
|
||||
if config["model_name"] in [
|
||||
"microsoft/MiniLM-L12-H384-uncased",
|
||||
"wide_resnet50_2",
|
||||
"resnet50",
|
||||
"resnet18",
|
||||
"resnet101",
|
||||
"microsoft/resnet-50",
|
||||
] and device in ["metal", "vulkan"]:
|
||||
pytest.xfail(reason="Vulkan Numerical Error (mostly conv)")
|
||||
if config[
|
||||
"model_name"
|
||||
] == "dbmdz/convbert-base-turkish-cased" and device in ["cuda", "cpu"]:
|
||||
pytest.xfail(reason="https://github.com/nod-ai/SHARK/issues/463")
|
||||
if (
|
||||
config["model_name"]
|
||||
in [
|
||||
"facebook/convnext-tiny-224",
|
||||
"squeezenet1_0",
|
||||
]
|
||||
and device == "rocm"
|
||||
):
|
||||
pytest.xfail(
|
||||
reason="iree-compile buffer limit issue: https://github.com/nod-ai/SHARK/issues/475"
|
||||
)
|
||||
if (
|
||||
config["model_name"]
|
||||
in [
|
||||
"funnel-transformer/small",
|
||||
"mobilenet_v3_small",
|
||||
]
|
||||
and device == "rocm"
|
||||
):
|
||||
pytest.xfail(
|
||||
reason="Numerics issues: https://github.com/nod-ai/SHARK/issues/476"
|
||||
)
|
||||
if config["framework"] == "tf" and dynamic == True:
|
||||
pytest.skip(
|
||||
reason="Dynamic shapes not supported for this framework."
|
||||
@@ -376,10 +427,11 @@ class SharkModuleTest(unittest.TestCase):
|
||||
if not os.path.isdir("./shark_tmp/"):
|
||||
os.mkdir("./shark_tmp/")
|
||||
|
||||
tempdir = tempfile.TemporaryDirectory(
|
||||
prefix=self.module_tester.tmp_prefix, dir="./shark_tmp/"
|
||||
# We must create a new process each time we benchmark a model to allow
|
||||
# for Tensorflow to release GPU resources. Using the same process to
|
||||
# benchmark multiple models leads to OOM.
|
||||
p = multiprocessing.Process(
|
||||
target=run_test, args=(self.module_tester, dynamic, device)
|
||||
)
|
||||
self.module_tester.temp_dir = tempdir.name
|
||||
|
||||
with ireec.tools.TempFileSaver(tempdir.name):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
## Running SharkInference on CPUs, GPUs and MAC.
|
||||
|
||||
|
||||
### Run the binary sequence_classification.
|
||||
#### The models supported are: [hugging face sequence classification](https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForSequenceClassification)
|
||||
```shell
|
||||
./seq_classification.py --hf_model_name="hf_model" --device="cpu" # Use gpu | vulkan
|
||||
```
|
||||
|
||||
Once the model is compiled to run on the device mentioned, we can pass in text and
|
||||
get the logits.
|
||||
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user