Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-11 14:58:11 -05:00)

Compare commits: 20230707.8...main (337 commits)
[Commit table: the compare view listed the 337 commit SHA1s, but the Author and Date columns were not captured in this mirror.]
.flake8 (2 changed lines)

@@ -2,4 +2,4 @@
 count = 1
 show-source = 1
 select = E9,F63,F7,F82
-exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py
+exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
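The CI jobs further down run flake8 with exactly this error-code selection and exclusion list. As a quick local equivalent, here is a minimal sketch (assuming flake8 is installed in the active environment; the subprocess invocation below is illustrative and not part of the repository):

```python
# Minimal sketch: run the same flake8 gate locally that the .flake8 config and
# the CI lint steps describe. Assumes `flake8` is importable as a module in the
# current environment.
import subprocess
import sys

EXCLUDES = ",".join([
    "lit.cfg.py",
    "apps/language_models/scripts/vicuna.py",
    "apps/language_models/src/pipelines/minigpt4_pipeline.py",
    "apps/language_models/langchain/h2oai_pipeline.py",
])

result = subprocess.run(
    [
        sys.executable, "-m", "flake8", ".",
        "--count", "--show-source", "--statistics",
        "--select=E9,F63,F7,F82",
        f"--exclude={EXCLUDES}",
    ],
    check=False,  # a non-zero exit code only means findings were reported
)
print("flake8 exit code:", result.returncode)
```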
.github/workflows/gh-pages-releases.yml (8 changed lines, vendored)

@@ -10,15 +10,15 @@ jobs:
     runs-on: ubuntu-latest

     # Don't run this in everyone's forks.
-    if: github.repository == 'nod-ai/SHARK'
+    if: github.repository == 'nod-ai/AMD-SHARK-Studio'

     steps:
     - name: Checking out repository
-      uses: actions/checkout@v2
+      uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
       with:
         token: ${{ secrets.NODAI_INVOCATION_TOKEN }}
     - name: Run scrape releases script
-      run: python ./build_tools/scrape_releases.py nod-ai SHARK > /tmp/index.html
+      run: python ./build_tools/scrape_releases.py nod-ai AMD-SHARK-Studio > /tmp/index.html
       shell: bash
     - run: git fetch --all
     - run: git switch github-pages
@@ -31,7 +31,7 @@ jobs:
     - run: git diff --cached --exit-code || git commit -m "Update releases."

     - name: GitHub Push
-      uses: ad-m/github-push-action@v0.6.0
+      uses: ad-m/github-push-action@d91a481090679876dfc4178fef17f286781251df # v0.8.0
       with:
         github_token: ${{ secrets.NODAI_INVOCATION_TOKEN }}
         branch: github-pages
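The job above regenerates the release index by running `build_tools/scrape_releases.py` and pushing the resulting `index.html` to the `github-pages` branch. The script itself is not part of this diff; the sketch below is a hypothetical illustration of such a generator (the HTML layout and helper names are assumptions), built only on the public GitHub releases API:

```python
# Hypothetical sketch of a release scraper in the spirit of
# build_tools/scrape_releases.py (the real script is not shown in this diff).
# It turns GitHub release assets into a plain find-links style index page.
import json
import sys
import urllib.request


def fetch_releases(owner: str, repo: str) -> list:
    url = f"https://api.github.com/repos/{owner}/{repo}/releases?per_page=100"
    with urllib.request.urlopen(url) as resp:
        return json.load(resp)


def main() -> None:
    owner, repo = sys.argv[1], sys.argv[2]  # e.g. nod-ai AMD-SHARK-Studio
    lines = ["<html><body>"]
    for release in fetch_releases(owner, repo):
        for asset in release.get("assets", []):
            lines.append(
                f'<a href="{asset["browser_download_url"]}">{asset["name"]}</a><br>'
            )
    lines.append("</body></html>")
    print("\n".join(lines))  # the workflow redirects stdout to /tmp/index.html


if __name__ == "__main__":
    main()
```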
.github/workflows/nightly.yml (104 changed lines, vendored)
@@ -17,9 +17,9 @@ jobs:
|
||||
python-version: ["3.11"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
@@ -35,14 +35,14 @@ jobs:
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
uses: ncipollo/release-action@440c8c1cb0ed28b9f43e4d1d670870f059653174 # v1.16.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ env.tag_name }}
|
||||
release_name: nod.ai SHARK ${{ env.tag_name }}
|
||||
tag: ${{ env.tag_name }}
|
||||
name: nod.ai AMDSHARK ${{ env.tag_name }}
|
||||
body: |
|
||||
Automatic snapshot release of nod.ai SHARK.
|
||||
Automatic snapshot release of nod.ai AMDSHARK.
|
||||
draft: true
|
||||
prerelease: true
|
||||
|
||||
@@ -50,16 +50,17 @@ jobs:
|
||||
shell: powershell
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
$env:AMDSHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
pip install -e .
|
||||
pip freeze -l
|
||||
pyinstaller .\apps\amdshark_studio\amdshark_studio.spec
|
||||
mv ./dist/nodai_amdshark_studio.exe ./dist/nodai_amdshark_studio_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\amdshark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_amdshark_studio_${{ env.package_version_ }}.exe
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
uses: dwenegar/upload-release-assets@fe47e06814723c7b1bea3a7e46cf93d5f020d0c3 # v3
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
@@ -69,85 +70,8 @@ jobs:
|
||||
|
||||
- name: Publish Release
|
||||
id: publish_release
|
||||
uses: eregon/publish-release@v1
|
||||
uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
|
||||
linux-build:
|
||||
|
||||
runs-on: a100
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
backend: [IREE, SHARK]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Setup pip cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
|
||||
- name: Build and validate the IREE package
|
||||
if: ${{ matrix.backend == 'IREE' }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
|
||||
source iree.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
|
||||
tail -n 1 |
|
||||
tee -a pytest_results.txt
|
||||
if !(grep -Fxq " failed" pytest_results.txt)
|
||||
then
|
||||
export SHA=$(git log -1 --format='%h')
|
||||
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
|
||||
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
|
||||
fi
|
||||
rm -rf ./wheelhouse/nodai*
|
||||
|
||||
- name: Build and validate the SHARK Runtime package
|
||||
if: ${{ matrix.backend == 'SHARK' }}
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
|
||||
tail -n 1 |
|
||||
tee -a pytest_results.txt
|
||||
|
||||
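Both the Windows build step above and the wheel builds export a package-version environment variable (`AMDSHARK_PACKAGE_VERSION`, previously `SHARK_PACKAGE_VERSION`) derived from the date and run number. The packaging code that consumes it is not shown in this diff; the fragment below is a hypothetical `setup.py` sketch of how such a variable is typically picked up (the package name comes from the README further down, and the fallback version is an assumption):

```python
# Hypothetical setup.py fragment: version the nightly wheel from the variable
# exported by the workflow. The repository's real setup.py is not part of this
# diff; the fallback version below is an assumption.
import os
from setuptools import setup, find_packages

PACKAGE_VERSION = os.environ.get("AMDSHARK_PACKAGE_VERSION", "0.0.1.dev0")

setup(
    name="nodai-amdshark",
    version=PACKAGE_VERSION,  # e.g. 20240101.123 from the date and run number
    packages=find_packages(),
)
```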
.github/workflows/test-models.yml (162 changed lines, vendored; the file is deleted in this comparison)
@@ -1,162 +0,0 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Validate Models on Shark Runtime
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'shark/examples/**'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'shark/examples/**'
|
||||
workflow_dispatch:
|
||||
|
||||
# Ensure that only a single job or workflow using the same
|
||||
# concurrency group will run at a time. This would cancel
|
||||
# any in-progress jobs in the same github workflow and github
|
||||
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-validate:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
|
||||
suite: [cpu,cuda,vulkan]
|
||||
python-version: ["3.11"]
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
suite: lint
|
||||
- os: MacStudio
|
||||
suite: metal
|
||||
exclude:
|
||||
- os: ubuntu-latest
|
||||
suite: vulkan
|
||||
- os: ubuntu-latest
|
||||
suite: cuda
|
||||
- os: ubuntu-latest
|
||||
suite: cpu
|
||||
- os: MacStudio
|
||||
suite: cuda
|
||||
- os: MacStudio
|
||||
suite: cpu
|
||||
- os: MacStudio
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
suite: cuda
|
||||
- os: a100
|
||||
suite: cpu
|
||||
- os: 7950x
|
||||
suite: cpu
|
||||
- os: 7950x
|
||||
suite: cuda
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set Environment Variables
|
||||
if: matrix.os != '7950x'
|
||||
run: |
|
||||
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Python Version File ${{ matrix.python-version }}
|
||||
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
|
||||
run: |
|
||||
# See https://github.com/actions/setup-python/issues/433
|
||||
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '${{ matrix.python-version }}'
|
||||
#cache: 'pip'
|
||||
#cache-dependency-path: |
|
||||
# **/requirements-importer.txt
|
||||
# **/requirements.txt
|
||||
|
||||
- name: Install dependencies
|
||||
if: matrix.suite == 'lint'
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml black
|
||||
|
||||
- name: Lint with flake8
|
||||
if: matrix.suite == 'lint'
|
||||
run: |
|
||||
# black format check
|
||||
black --version
|
||||
black --check .
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --statistics
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
|
||||
--statistics --exclude lit.cfg.py
|
||||
|
||||
- name: Validate Models on CPU
|
||||
if: matrix.suite == 'cpu'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
|
||||
|
||||
- name: Validate Models on NVIDIA GPU
|
||||
if: matrix.suite == 'cuda'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
|
||||
# Disabled due to black image bug
|
||||
# python build_tools/stable_diffusion_testing.py --device=cuda
|
||||
|
||||
- name: Validate Vulkan Models (MacOS)
|
||||
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
|
||||
- name: Validate Vulkan Models (Windows)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
pytest -k vulkan -s --ci
|
||||
|
||||
- name: Validate Stable Diffusion Models (Windows)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
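The deleted workflow above drives its model validation entirely through custom pytest flags (`--ci`, `--ci_sha`, `--update_tank`, `--tank_url`, `--benchmark`, and so on). Those are not built-in pytest options; they would be registered in a `conftest.py`. The fragment below is a hedged sketch of how such options are commonly declared, not the repository's actual conftest (the option names are taken from the commands above; the defaults and the fixture are assumptions):

```python
# Hypothetical conftest.py fragment registering CI-only pytest flags like the
# ones used in the workflow above. The repository's real conftest.py is not
# shown in this diff.
import pytest


def pytest_addoption(parser):
    parser.addoption("--ci", action="store_true", default=False,
                     help="Enable CI-only behaviour such as artifact uploads.")
    parser.addoption("--ci_sha", action="store", default="latest",
                     help="Short git SHA used to tag benchmark artifacts.")
    parser.addoption("--update_tank", action="store_true", default=False,
                     help="Refresh cached model artifacts from the tank.")
    parser.addoption("--tank_url", action="store", default="gs://shark_tank/nightly/",
                     help="Location of the model tank to test against.")


@pytest.fixture
def ci_sha(request):
    # Tests depend on this fixture instead of parsing sys.argv themselves.
    return request.config.getoption("--ci_sha")
```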
.github/workflows/test-studio.yml (85 changed lines, vendored; new file)
@@ -0,0 +1,85 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Validate AMDShark Studio
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'amdshark/examples/**'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'amdshark/examples/**'
|
||||
workflow_dispatch:
|
||||
|
||||
# Ensure that only a single job or workflow using the same
|
||||
# concurrency group will run at a time. This would cancel
|
||||
# any in-progress jobs in the same github workflow and github
|
||||
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-validate:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
os: [nodai-ubuntu-builder-large]
|
||||
suite: [cpu] #,cuda,vulkan]
|
||||
python-version: ["3.11"]
|
||||
include:
|
||||
- os: nodai-ubuntu-builder-large
|
||||
suite: lint
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Set Environment Variables
|
||||
run: |
|
||||
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Python Version File ${{ matrix.python-version }}
|
||||
run: |
|
||||
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||
with:
|
||||
python-version: '${{ matrix.python-version }}'
|
||||
|
||||
- name: Install dependencies
|
||||
if: matrix.suite == 'lint'
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml black
|
||||
|
||||
- name: Lint with flake8
|
||||
if: matrix.suite == 'lint'
|
||||
run: |
|
||||
# black format check
|
||||
black --version
|
||||
black --check apps/amdshark_studio
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --statistics
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
|
||||
--statistics --exclude lit.cfg.py
|
||||
|
||||
- name: Validate Models on CPU
|
||||
if: matrix.suite == 'cpu'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
python${{ matrix.python-version }} -m venv amdshark.venv
|
||||
source amdshark.venv/bin/activate
|
||||
pip install -r requirements.txt --no-cache-dir
|
||||
pip install -e .
|
||||
# Disabled due to hang when exporting test llama2
|
||||
# python apps/amdshark_studio/tests/api_test.py
|
||||
.gitignore (26 changed lines, vendored)

@@ -164,14 +164,15 @@ cython_debug/
 # vscode related
 .vscode

-# Shark related artefacts
+# AMDShark related artifacts
 *venv/
-shark_tmp/
+amdshark_tmp/
 *.vmfb
 .use-iree
 tank/dict_configs.py
 *.csv
 reproducers/
+apps/amdshark_studio/web/configs

 # ORT related artefacts
 cache_models/
@@ -182,10 +183,29 @@ generated_imgs/

 # Custom model related artefacts
 variants.json
-models/
+/models/
+*.safetensors
+
+# models folder
+apps/stable_diffusion/web/models/
+
+# model artifacts (AMDSHARK)
+*.tempfile
+*.mlir
+*.vmfb

 # Stencil annotators.
 stencil_annotator/

 # For DocuChat
 apps/language_models/langchain/user_path/
+db_dir_UserData
+
+# Embeded browser cache and other
+apps/stable_diffusion/web/EBWebView/
+
+# Llama2 tokenizer configs
+llama2_tokenizer_configs/
+
+# Webview2 runtime artefacts
+EBWebView/
.gitmodules (6 changed lines, vendored)

@@ -1,4 +1,4 @@
-[submodule "inference/thirdparty/shark-runtime"]
-	path = inference/thirdparty/shark-runtime
-	url =https://github.com/nod-ai/SHARK-Runtime.git
+[submodule "inference/thirdparty/amdshark-runtime"]
+	path = inference/thirdparty/amdshark-runtime
+	url =https://github.com/nod-ai/SRT.git
 	branch = shark-06032022
README.md (174 changed lines)
@@ -1,19 +1,21 @@
|
||||
# SHARK
|
||||
# AMDSHARK
|
||||
|
||||
High Performance Machine Learning Distribution
|
||||
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
|
||||
<h2>NOTE: This project is not currently maintained.</h2>
|
||||
|
||||
*The latest versions of this project are developments towards a refactor on top of IREE-Turbine. Until further notice, make sure you use an .exe release or a checkout of the `AMDSHARK-1.0` branch, for a working AMDSHARK-Studio*
|
||||
|
||||
[](https://github.com/nod-ai/AMD-SHARK-Studio/actions/workflows/nightly.yml)
|
||||
|
||||
<details>
|
||||
<summary>Prerequisites - Drivers </summary>
|
||||
|
||||
|
||||
#### Install your Windows hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
|
||||
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
|
||||
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
|
||||
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
|
||||
#### Linux Drivers
|
||||
* MESA / RADV drivers wont work with FP16. Please use the latest AMGPU-PRO drivers (non-pro OSS drivers also wont work) or the latest NVidia Linux Drivers.
|
||||
|
||||
@@ -22,23 +24,23 @@ Other users please ensure you have your latest vendor drivers and Vulkan SDK fro
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
|
||||
|
||||
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
|
||||
### Quick Start for AMDSHARK Stable Diffusion for Windows 10/11 Users
|
||||
|
||||
Download the [stable release](https://github.com/nod-ai/shark/releases/latest)
|
||||
Install the Driver from [Prerequisites](https://github.com/nod-ai/AMD-SHARK-Studio#install-your-hardware-drivers) above
|
||||
|
||||
Double click the .exe and you should have the [UI](http://localhost:8080/) in the browser.
|
||||
Download the [stable release](https://github.com/nod-ai/AMD-SHARK-Studio/releases/latest) or the most recent [AMDSHARK 1.0 pre-release](https://github.com/nod-ai/AMD-SHARK-Studio/releases).
|
||||
|
||||
If you have custom models put them in a `models/` directory where the .exe is.
|
||||
Double click the .exe, or [run from the command line](#running) (recommended), and you should have the [UI](http://localhost:8080/) in the browser.
|
||||
|
||||
Enjoy.
|
||||
If you have custom models put them in a `models/` directory where the .exe is.
|
||||
|
||||
Enjoy.
|
||||
|
||||
<details>
|
||||
<summary>More installation notes</summary>
|
||||
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
|
||||
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
|
||||
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
|
||||
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
|
||||
|
||||
## Running
|
||||
|
||||
@@ -46,38 +48,51 @@ Enjoy.
|
||||
* The first run may take few minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
|
||||
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
|
||||
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.
|
||||
* If you prefer to always run in the browser, use the `--ui=web` command argument when running the EXE.
|
||||
|
||||
## Stopping
|
||||
|
||||
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
|
||||
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Advanced Installation (Only for developers)</summary>
|
||||
|
||||
|
||||
## Advanced Installation (Windows, Linux and macOS) for developers
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install Git for Windows from [here](https://git-scm.com/download/win) if you don't already have it.
|
||||
|
||||
## Check out the code
|
||||
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
cd SHARK
|
||||
git clone https://github.com/nod-ai/AMD-SHARK-Studio.git
|
||||
cd AMD-SHARK-Studio
|
||||
```
|
||||
|
||||
## Switch to the Correct Branch (IMPORTANT!)
|
||||
|
||||
Currently AMDSHARK is being rebuilt for [Turbine](https://github.com/iree-org/iree-turbine) on the `main` branch. For now you are strongly discouraged from using `main` unless you are working on the rebuild effort, and should not expect the code there to produce a working application for Image Generation, So for now you'll need switch over to the `AMDSHARK-1.0` branch and use the stable code.
|
||||
|
||||
```shell
|
||||
git checkout AMDSHARK-1.0
|
||||
```
|
||||
|
||||
The following setup instructions assume you are on this branch.
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)
|
||||
|
||||
* Install Git for Windows from [here](https://git-scm.com/download/win)
|
||||
|
||||
#### Allow the install script to run in Powershell
|
||||
```powershell
|
||||
set-executionpolicy remotesigned
|
||||
```
|
||||
|
||||
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
|
||||
#### Setup venv and install necessary packages (torch-mlir, nodLabs/AMDShark, ...)
|
||||
```powershell
|
||||
./setup_venv.ps1 #You can re-run this script to get the latest version
|
||||
```
|
||||
@@ -86,21 +101,20 @@ set-executionpolicy remotesigned
|
||||
|
||||
```shell
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
source amdshark1.venv/bin/activate
|
||||
```
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - WebUI
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
|
||||
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
|
||||
(amdshark1.venv) PS C:\g\amdshark> cd .\apps\stable_diffusion\web\
|
||||
(amdshark1.venv) PS C:\g\amdshark\apps\stable_diffusion\web> python .\index.py
|
||||
```
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
(shark.venv) > cd apps/stable_diffusion/web
|
||||
(shark.venv) > python index.py
|
||||
(amdshark1.venv) > cd apps/stable_diffusion/web
|
||||
(amdshark1.venv) > python index.py
|
||||
```
|
||||
|
||||
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
|
||||
@@ -114,7 +128,7 @@ source shark.venv/bin/activate
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
(amdshark1.venv) PS C:\g\amdshark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
#### Linux / macOS Users
|
||||
@@ -142,7 +156,7 @@ Here are some samples generated:
|
||||

|
||||
|
||||
|
||||
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
|
||||
Find us on [AMDSHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
|
||||
|
||||
|
||||
<details>
|
||||
@@ -154,8 +168,8 @@ This step sets up a new VirtualEnv for Python
|
||||
|
||||
```shell
|
||||
python --version #Check you have 3.11 on Linux, macOS or Windows Powershell
|
||||
python -m venv shark_venv
|
||||
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
|
||||
python -m venv amdshark_venv
|
||||
source amdshark_venv/bin/activate # Use amdshark_venv/Scripts/activate on Windows
|
||||
|
||||
# If you are using conda create and activate a new conda env
|
||||
|
||||
@@ -165,15 +179,15 @@ python -m pip install --upgrade pip
|
||||
|
||||
*macOS Metal* users please install https://sdk.lunarg.com/sdk/download/latest/mac/vulkan-sdk.dmg and enable "System wide install"
|
||||
|
||||
### Install SHARK
|
||||
### Install AMD-SHARK
|
||||
|
||||
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
|
||||
This step pip installs AMD-SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
|
||||
|
||||
```shell
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip install nodai-amdshark -f https://nod-ai.github.io/AMD-SHARK-Studio/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
```
|
||||
|
||||
### Run shark tank model tests.
|
||||
### Run amdshark tank model tests.
|
||||
```shell
|
||||
pytest tank/test_models.py
|
||||
```
|
||||
@@ -182,7 +196,7 @@ See tank/README.md for a more detailed walkthrough of our pytest suite and CLI.
|
||||
### Download and run Resnet50 sample
|
||||
|
||||
```shell
|
||||
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/resnet50_script.py
|
||||
curl -O https://raw.githubusercontent.com/nod-ai/AMD-SHARK-Studio/main/amdshark/examples/amdshark_inference/resnet50_script.py
|
||||
#Install deps for test script
|
||||
pip install --pre torch torchvision torchaudio tqdm pillow gsutil --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python ./resnet50_script.py --device="cpu" #use cuda or vulkan or metal
|
||||
@@ -190,7 +204,7 @@ python ./resnet50_script.py --device="cpu" #use cuda or vulkan or metal
|
||||
|
||||
### Download and run BERT (MiniLM) sample
|
||||
```shell
|
||||
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/minilm_jit.py
|
||||
curl -O https://raw.githubusercontent.com/nod-ai/AMD-SHARK-Studio/main/amdshark/examples/amdshark_inference/minilm_jit.py
|
||||
#Install deps for test script
|
||||
pip install transformers torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
|
||||
@@ -205,56 +219,55 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
|
||||
If you want to use Python3.11 and with TF Import tools you can use the environment variables like:
|
||||
Set `USE_IREE=1` to use upstream IREE
|
||||
```
|
||||
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
|
||||
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
|
||||
```
|
||||
|
||||
### Run any of the hundreds of SHARK tank models via the test framework
|
||||
### Run any of the hundreds of AMDSHARK tank models via the test framework
|
||||
```shell
|
||||
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
python -m amdshark.examples.amdshark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
# Or a pytest
|
||||
pytest tank/test_models.py -k "MiniLM"
|
||||
```
|
||||
|
||||
### How to use your locally built IREE / Torch-MLIR with SHARK
|
||||
|
||||
### How to use your locally built IREE / Torch-MLIR with AMDSHARK
|
||||
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
|
||||
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
|
||||
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
|
||||
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
|
||||
for Torch-MLIR.
|
||||
|
||||
How to use your locally built Torch-MLIR with SHARK:
|
||||
How to use your locally built Torch-MLIR with AMDSHARK:
|
||||
```shell
|
||||
1.) Run `./setup_venv.sh in SHARK` and activate `shark.venv` virtual env.
|
||||
1.) Run `./setup_venv.sh in AMDSHARK` and activate `amdshark.venv` virtual env.
|
||||
2.) Run `pip uninstall torch-mlir`.
|
||||
3.) Go to your local Torch-MLIR directory.
|
||||
4.) Activate mlir_venv virtual envirnoment.
|
||||
5.) Run `pip uninstall -r requirements.txt`.
|
||||
6.) Run `pip install -r requirements.txt`.
|
||||
7.) Build Torch-MLIR.
|
||||
8.) Activate shark.venv virtual environment from the Torch-MLIR directory.
|
||||
8.) Activate amdshark.venv virtual environment from the Torch-MLIR directory.
|
||||
8.) Run `export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples` in the Torch-MLIR directory.
|
||||
9.) Go to the SHARK directory.
|
||||
9.) Go to the AMDSHARK directory.
|
||||
```
|
||||
Now the SHARK will use your locally build Torch-MLIR repo.
|
||||
Now the AMDSHARK will use your locally build Torch-MLIR repo.
|
||||
|
||||
|
||||
## Benchmarking Dispatches
|
||||
|
||||
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
|
||||
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
|
||||
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`
|
||||
|
||||
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
|
||||
```
|
||||
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
|
||||
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
|
||||
```
|
||||
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
|
||||
|
||||
if you want to instead incorporate this into a python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `SharkInference`, and the benchmarks will be generated when compiled. E.G:
|
||||
if you want to instead incorporate this into a python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `AMDSharkInference`, and the benchmarks will be generated when compiled. E.G:
|
||||
|
||||
```
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_model,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
dispatch_benchmarks="all",
|
||||
@@ -265,41 +278,41 @@ shark_module = SharkInference(
|
||||
Output will include:
|
||||
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
|
||||
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
|
||||
- An .mlir file containing the dispatch benchmark
|
||||
- An .mlir file containing the dispatch benchmark
|
||||
- A compiled .vmfb file containing the dispatch benchmark
|
||||
- An .mlir file containing just the hal executable
|
||||
- A compiled .vmfb file of the hal executable
|
||||
- A .txt file containing benchmark output
|
||||
|
||||
|
||||
See tank/README.md for further instructions on how to run model tests and benchmarks from the SHARK tank.
|
||||
See tank/README.md for further instructions on how to run model tests and benchmarks from the AMDSHARK tank.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>API Reference</summary>
|
||||
|
||||
### Shark Inference API
|
||||
### AMDShark Inference API
|
||||
|
||||
```
|
||||
|
||||
from shark.shark_importer import SharkImporter
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
|
||||
# SharkImporter imports mlir file from the torch, tensorflow or tf-lite module.
|
||||
# AMDSharkImporter imports mlir file from the torch, tensorflow or tf-lite module.
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
mlir_importer = AMDSharkImporter(
|
||||
torch_module,
|
||||
(input),
|
||||
frontend="torch", #tf, #tf-lite
|
||||
)
|
||||
torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
|
||||
|
||||
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
|
||||
# AMDSharkInference accepts mlir in linalg, mhlo, and tosa dialect.
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input))
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
amdshark_module = AMDSharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward((input))
|
||||
|
||||
```
|
||||
|
||||
@@ -307,7 +320,7 @@ result = shark_module.forward((input))
|
||||
### Example demonstrating running MHLO IR.
|
||||
|
||||
```
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
import numpy as np
|
||||
|
||||
mhlo_ir = r"""builtin.module {
|
||||
@@ -320,17 +333,22 @@ mhlo_ir = r"""builtin.module {
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((arg0, arg1))
|
||||
amdshark_module = AMDSharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward((arg0, arg1))
|
||||
```
|
||||
</details>
|
||||
|
||||
## Examples Using the REST API
|
||||
|
||||
* [Setting up AMDSHARK for use with Blender](./docs/amdshark_sd_blender.md)
|
||||
* [Setting up AMDSHARK for use with Koboldcpp](./docs/amdshark_sd_koboldcpp.md)
|
||||
|
||||
## Supported and Validated Models
|
||||
|
||||
SHARK is maintained to support the latest innovations in ML Models:
|
||||
AMDSHARK is maintained to support the latest innovations in ML Models:
|
||||
|
||||
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
| TF HuggingFace Models | AMDSHARK-CPU | AMDSHARK-CUDA | AMDSHARK-METAL |
|
||||
|---------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
@@ -340,12 +358,12 @@ SHARK is maintained to support the latest innovations in ML Models:
|
||||
| Vision Transformer | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ResNet50 | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
For a complete list of the models supported in SHARK, please refer to [tank/README.md](https://github.com/nod-ai/SHARK/blob/main/tank/README.md).
|
||||
For a complete list of the models supported in AMDSHARK, please refer to [tank/README.md](https://github.com/nod-ai/AMD-SHARK-Studio/blob/main/tank/README.md).
|
||||
|
||||
## Communication Channels
|
||||
|
||||
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
|
||||
* [AMDSHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the AMDSHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/AMD-SHARK-Studio/issues): Feature requests, bugs etc
|
||||
|
||||
## Related Projects
|
||||
|
||||
@@ -354,7 +372,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
|
||||
|
||||
* [Upstream IREE issues](https://github.com/google/iree/issues): Feature requests,
|
||||
bugs, and other work tracking
|
||||
* [Upstream IREE Discord server](https://discord.gg/26P4xW4): Daily development
|
||||
* [Upstream IREE Discord server](https://discord.gg/wEWh6Z9nMU): Daily development
|
||||
discussions with the core team and collaborators
|
||||
* [iree-discuss email list](https://groups.google.com/forum/#!forum/iree-discuss):
|
||||
Announcements, general and low-priority discussion
|
||||
@@ -367,10 +385,10 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
|
||||
* Torch-MLIR Github issues [here](https://github.com/llvm/torch-mlir/issues)
|
||||
* [`torch-mlir` section](https://llvm.discourse.group/c/projects-that-want-to-become-official-llvm-projects/torch-mlir/41) of LLVM Discourse
|
||||
* Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information.
|
||||
* [MLIR topic within LLVM Discourse](https://llvm.discourse.group/c/llvm-project/mlir/31) SHARK and IREE is enabled by and heavily relies on [MLIR](https://mlir.llvm.org).
|
||||
* [MLIR topic within LLVM Discourse](https://llvm.discourse.group/c/llvm-project/mlir/31) AMDSHARK and IREE is enabled by and heavily relies on [MLIR](https://mlir.llvm.org).
|
||||
</details>
|
||||
|
||||
|
||||
## License
|
||||
|
||||
nod.ai SHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions.
|
||||
nod.ai AMDSHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions.
|
||||
See [LICENSE](LICENSE) for more information.
|
||||
|
||||
amdshark/__init__.py (28 lines; new file)

@@ -0,0 +1,28 @@
+import importlib
+import logging
+
+from torch._dynamo import register_backend
+
+log = logging.getLogger(__name__)
+
+
+@register_backend
+def amdshark(model, inputs, *, options):
+    try:
+        from amdshark.dynamo_backend.utils import AMDSharkBackend
+    except ImportError:
+        log.exception(
+            "Unable to import AMDSHARK - High Performance Machine Learning Distribution"
+            "Please install the right version of AMDSHARK that matches the PyTorch version being used. "
+            "Refer to https://github.com/nod-ai/AMD-SHARK-Studio/ for details."
+        )
+        raise
+    return AMDSharkBackend(model, inputs, options)
+
+
+def has_amdshark():
+    try:
+        importlib.import_module("amdshark")
+        return True
+    except ImportError:
+        return False
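Because this module decorates `amdshark()` with `torch._dynamo.register_backend`, importing the package makes an `"amdshark"` backend resolvable by name from `torch.compile`. A minimal usage sketch (assuming the `amdshark` package and a compatible PyTorch build are installed; the `options` keys shown are placeholders, since the keys `AMDSharkBackend` accepts are not part of this diff):

```python
# Minimal sketch: select the registered backend by name. Assumes the amdshark
# package is installed alongside a compatible PyTorch build.
import torch
import amdshark  # importing runs @register_backend, registering the "amdshark" name


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0


model = TinyModel()
if amdshark.has_amdshark():
    compiled = torch.compile(
        model,
        backend="amdshark",
        # torch.compile forwards `options` to the backend; the keys that
        # AMDSharkBackend actually accepts are not shown in this diff, so this
        # dict is purely a placeholder.
        options={"device": "cpu"},
    )
    print(compiled(torch.randn(4)))
```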
@@ -12,13 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.iree_utils.compile_utils import export_iree_module_to_vmfb
|
||||
from shark.iree_utils.benchmark_utils import (
|
||||
from amdshark.amdshark_runner import AMDSharkRunner
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
export_iree_module_to_vmfb,
|
||||
load_flatbuffer,
|
||||
get_iree_runtime_config,
|
||||
)
|
||||
from amdshark.iree_utils.benchmark_utils import (
|
||||
build_benchmark_args,
|
||||
run_benchmark_module,
|
||||
)
|
||||
from shark.parser import shark_args
|
||||
from amdshark.parser import amdshark_args
|
||||
from datetime import datetime
|
||||
import time
|
||||
from typing import Optional
|
||||
@@ -63,8 +67,8 @@ def check_requirements(frontend):
|
||||
return has_pkgs
|
||||
|
||||
|
||||
class SharkBenchmarkRunner(SharkRunner):
|
||||
# SharkRunner derived class with Benchmarking capabilities.
|
||||
class AMDSharkBenchmarkRunner(AMDSharkRunner):
|
||||
# AMDSharkRunner derived class with Benchmarking capabilities.
|
||||
def __init__(
|
||||
self,
|
||||
mlir_module: bytes,
|
||||
@@ -72,29 +76,46 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
mlir_dialect: str = "linalg",
|
||||
extra_args: list = [],
|
||||
):
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.enable_tf32 = shark_args.enable_tf32
|
||||
self.device = amdshark_args.device if device == "none" else device
|
||||
self.enable_tf32 = amdshark_args.enable_tf32
|
||||
self.frontend_model = None
|
||||
self.vmfb_file = None
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
self.import_args = {}
|
||||
SharkRunner.__init__(
|
||||
self.temp_file_to_unlink = None
|
||||
if not os.path.isfile(mlir_module):
|
||||
print(
|
||||
"Warning: Initializing AMDSharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize AMDSharkInference with a path to a MLIR module on your hard disk instead."
|
||||
)
|
||||
self.compile_str = True
|
||||
else:
|
||||
self.compile_str = False
|
||||
AMDSharkRunner.__init__(
|
||||
self,
|
||||
mlir_module,
|
||||
device,
|
||||
self.mlir_dialect,
|
||||
self.extra_args,
|
||||
compile_vmfb=True,
|
||||
compile_vmfb=False,
|
||||
)
|
||||
if self.vmfb_file == None:
|
||||
self.vmfb_file = export_iree_module_to_vmfb(
|
||||
mlir_module,
|
||||
device,
|
||||
".",
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
)
|
||||
self.vmfb_file = export_iree_module_to_vmfb(
|
||||
mlir_module,
|
||||
device,
|
||||
".",
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
compile_str=self.compile_str,
|
||||
)
|
||||
params = load_flatbuffer(
|
||||
self.vmfb_file,
|
||||
device,
|
||||
mmap=True,
|
||||
)
|
||||
self.iree_compilation_module = params["vmfb"]
|
||||
self.iree_config = params["config"]
|
||||
self.temp_file_to_unlink = params["temp_file_to_unlink"]
|
||||
del params
|
||||
|
||||
def setup_cl(self, input_tensors):
|
||||
self.benchmark_cl = build_benchmark_args(
|
||||
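The constructor changes above replace the old compile-inside-the-runner path with an explicit two-step flow: write the module out with `export_iree_module_to_vmfb`, then memory-map it back in with `load_flatbuffer`. Condensed into a standalone helper for clarity (argument values beyond those visible in the diff, such as the output directory default, are assumptions):

```python
# Sketch of the compile-then-mmap flow introduced above, using the helper names
# exactly as they appear in this diff. Defaults not visible in the diff are
# assumptions.
import os

from amdshark.iree_utils.compile_utils import (
    export_iree_module_to_vmfb,
    load_flatbuffer,
)


def compile_and_load(mlir_module, device, mlir_dialect="linalg", extra_args=None):
    # Passing a path keeps the model out of RAM at compile time; a raw
    # string/bytecode object triggers the warning printed by the runner.
    compile_str = not os.path.isfile(mlir_module)
    vmfb_file = export_iree_module_to_vmfb(
        mlir_module,
        device,
        ".",  # output directory, as in the diff
        mlir_dialect,
        extra_args=extra_args or [],
        compile_str=compile_str,
    )
    params = load_flatbuffer(vmfb_file, device, mmap=True)
    return params["vmfb"], params["config"], params["temp_file_to_unlink"]
```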
@@ -111,42 +132,41 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
elif self.mlir_dialect in ["mhlo", "tf"]:
|
||||
return self.benchmark_tf(modelname)
|
||||
|
||||
def benchmark_torch(self, modelname):
|
||||
def benchmark_torch(self, modelname, device="cpu"):
|
||||
import torch
|
||||
from tank.model_utils import get_torch_model
|
||||
|
||||
if self.device == "cuda":
|
||||
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||
if self.enable_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
# TODO: Pass this as an arg. currently the best way is to setup with BENCHMARK=1 if we want to use torch+cuda, else use cpu.
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if device == "cuda":
|
||||
torch.set_default_device("cuda:0")
|
||||
# if self.enable_tf32:
|
||||
# torch.backends.cuda.matmul.allow_tf32 = True
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
torch_device = torch.device(
|
||||
"cuda:0" if self.device == "cuda" else "cpu"
|
||||
)
|
||||
torch.set_default_dtype(torch.float32)
|
||||
torch.set_default_device("cpu")
|
||||
torch_device = torch.device("cuda:0" if device == "cuda" else "cpu")
|
||||
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
|
||||
frontend_model = HFmodel.model
|
||||
frontend_model.to(torch_device)
|
||||
input.to(torch_device)
|
||||
if device == "cuda":
|
||||
frontend_model.cuda()
|
||||
input.to(torch.device("cuda:0"))
|
||||
print(input)
|
||||
else:
|
||||
frontend_model.cpu()
|
||||
input.cpu()
|
||||
|
||||
# TODO: re-enable as soon as pytorch CUDA context issues are resolved
|
||||
try:
|
||||
frontend_model = torch.compile(
|
||||
frontend_model, mode="max-autotune", backend="inductor"
|
||||
)
|
||||
except RuntimeError:
|
||||
frontend_model = HFmodel.model
|
||||
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
for i in range(amdshark_args.num_warmup_iterations):
|
||||
frontend_model.forward(input)
|
||||
|
||||
if self.device == "cuda":
|
||||
if device == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
for i in range(amdshark_args.num_iterations):
|
||||
out = frontend_model.forward(input)
|
||||
end = time.time()
|
||||
if self.device == "cuda":
|
||||
if device == "cuda":
|
||||
stats = torch.cuda.memory_stats()
|
||||
device_peak_b = stats["allocated_bytes.all.peak"]
|
||||
frontend_model.to(torch.device("cpu"))
|
||||
@@ -156,14 +176,14 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
device_peak_b = None
|
||||
|
||||
print(
|
||||
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
f"Torch benchmark:{amdshark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{amdshark_args.num_iterations}"
|
||||
)
|
||||
if self.device == "cuda":
|
||||
if device == "cuda":
|
||||
# Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
|
||||
torch_device = torch.device("cpu")
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
f"{amdshark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/amdshark_args.num_iterations)*1000}",
|
||||
"", # host_peak_b (CPU usage) is not reported by PyTorch.
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
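Note: the hunk above reworks benchmark_torch to pick the torch device at runtime and to fall back to eager mode when torch.compile fails. Below is a minimal, self-contained sketch of that warmup/timed-loop pattern, not the project's code; the toy model, input shape, and iteration counts are made up, while torch.compile and the CUDA peak-memory query are the same public PyTorch APIs the hunk calls.

import time
import torch

# Hypothetical stand-ins for what get_torch_model would return.
model = torch.nn.Linear(128, 128)
example_input = torch.randn(4, 128)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
example_input = example_input.to(device)

# Best-effort torch.compile, falling back to the eager model on failure (as the hunk does).
try:
    model = torch.compile(model, mode="max-autotune", backend="inductor")
except RuntimeError:
    pass

num_warmup, num_iters = 5, 20
for _ in range(num_warmup):
    model(example_input)

if device == "cuda":
    torch.cuda.reset_peak_memory_stats()
begin = time.time()
for _ in range(num_iters):
    out = model(example_input)
end = time.time()

device_peak_b = (
    torch.cuda.memory_stats()["allocated_bytes.all.peak"] if device == "cuda" else None
)
print(f"{num_iters / (end - begin):.2f} iter/s, "
      f"{(end - begin) / num_iters * 1000:.2f} ms/iter, peak device bytes: {device_peak_b}")
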
@@ -197,13 +217,13 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
)[:2]
|
||||
frontend_model = model
|
||||
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
for i in range(amdshark_args.num_warmup_iterations):
|
||||
frontend_model.forward(*input)
|
||||
|
||||
if tf_device == TF_GPU_DEVICE:
|
||||
tf.config.experimental.reset_memory_stats(tf_device)
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
for i in range(amdshark_args.num_iterations):
|
||||
out = frontend_model.forward(*input)
|
||||
end = time.time()
|
||||
if tf_device == TF_GPU_DEVICE:
|
||||
@@ -215,11 +235,11 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
device_peak_b = None
|
||||
|
||||
print(
|
||||
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
f"TF benchmark:{amdshark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{amdshark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
f"{amdshark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/amdshark_args.num_iterations)*1000}",
|
||||
"", # host_peak_b (CPU usage) is not reported by TensorFlow.
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
@@ -228,7 +248,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
iter_per_second, host_peak_b, device_peak_b = run_benchmark_module(
|
||||
self.benchmark_cl
|
||||
)
|
||||
print(f"Shark-IREE-C benchmark:{iter_per_second} iter/second")
|
||||
print(f"AMDShark-IREE-C benchmark:{iter_per_second} iter/second")
|
||||
return [
|
||||
f"{iter_per_second}",
|
||||
f"{1000/iter_per_second}",
|
||||
@@ -238,25 +258,25 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
|
||||
def benchmark_python(self, inputs):
|
||||
input_list = [x for x in inputs]
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
for i in range(amdshark_args.num_warmup_iterations):
|
||||
self.run("forward", input_list)
|
||||
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
for i in range(amdshark_args.num_iterations):
|
||||
out = self.run("forward", input_list)
|
||||
end = time.time()
|
||||
print(
|
||||
f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
f"AMDShark-IREE Python benchmark:{amdshark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{amdshark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
f"{amdshark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/amdshark_args.num_iterations)*1000}",
|
||||
]
|
||||
|
||||
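Note: benchmark_python above times the already-compiled module through self.run("forward", ...). A hedged usage sketch of the same pattern from user code follows; the .mlir filename and shapes are placeholders, while the constructor, compile(), and the "forward" call with a tuple of numpy arrays match the AMDSharkInference API changed later in this diff.

import time
import numpy as np
from amdshark.amdshark_inference import AMDSharkInference

# Passing a path to an MLIR file on disk avoids duplicating the module in RAM.
module = AMDSharkInference("resnet50_torch_linalg.mlir", device="cpu", mlir_dialect="linalg")
module.compile(extra_args=[])

inputs = (np.random.rand(1, 3, 224, 224).astype(np.float32),)
num_warmup, num_iters = 2, 10
for _ in range(num_warmup):
    module("forward", inputs)

begin = time.time()
for _ in range(num_iters):
    out = module("forward", inputs)
end = time.time()
print(f"{num_iters / (end - begin):.2f} iter/s")
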
def benchmark_onnx(self, modelname, inputs):
|
||||
if self.device == "cuda":
|
||||
print(
|
||||
"Currently GPU benchmarking on ONNX is not supported in SHARK."
|
||||
"Currently GPU benchmarking on ONNX is not supported in AMDSHARK."
|
||||
)
|
||||
return ["N/A", "N/A"]
|
||||
else:
|
||||
@@ -305,7 +325,7 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
amdshark_args.num_iterations,
|
||||
input_counts,
|
||||
optimize_onnx,
|
||||
validate_onnx,
|
||||
@@ -320,7 +340,7 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
onnx_args,
|
||||
)
|
||||
print(
|
||||
f"ONNX ORT-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
f"ONNX ORT-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{amdshark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
result[0]["QPS"],
|
||||
@@ -388,13 +408,13 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
]
|
||||
# "frontend" must be the first element.
|
||||
if self.mode == "native":
|
||||
engines = ["shark_python", "shark_iree_c"]
|
||||
engines = ["amdshark_python", "amdshark_iree_c"]
|
||||
if self.mode == "baseline":
|
||||
engines = ["frontend"]
|
||||
if self.mode == "all":
|
||||
engines = ["frontend", "shark_python", "shark_iree_c"]
|
||||
engines = ["frontend", "amdshark_python", "amdshark_iree_c"]
|
||||
|
||||
if shark_args.onnx_bench == True:
|
||||
if amdshark_args.onnx_bench == True:
|
||||
engines.append("onnxruntime")
|
||||
|
||||
if not os.path.exists("bench_results.csv"):
|
||||
@@ -408,7 +428,7 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
bench_info["model"] = modelname
|
||||
bench_info["batch_size"] = str(import_args["batch_size"])
|
||||
bench_info["dialect"] = self.mlir_dialect
|
||||
bench_info["iterations"] = shark_args.num_iterations
|
||||
bench_info["iterations"] = amdshark_args.num_iterations
|
||||
if dynamic == True:
|
||||
bench_info["shape_type"] = "dynamic"
|
||||
else:
|
||||
@@ -442,8 +462,8 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
self.frontend_result = None
|
||||
continue
|
||||
|
||||
elif e == "shark_python":
|
||||
engine_result["engine"] = "shark_python"
|
||||
elif e == "amdshark_python":
|
||||
engine_result["engine"] = "amdshark_python"
|
||||
(
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
@@ -455,8 +475,8 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
self.frontend_result, engine_result["ms/iter"]
|
||||
)
|
||||
|
||||
elif e == "shark_iree_c":
|
||||
engine_result["engine"] = "shark_iree_c"
|
||||
elif e == "amdshark_iree_c":
|
||||
engine_result["engine"] = "amdshark_iree_c"
|
||||
(
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
241
amdshark/amdshark_compile.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import os
|
||||
import tempfile
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import import_with_fx, save_mlir
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
from typing import List, Tuple
|
||||
from io import BytesIO
|
||||
from brevitas_examples.common.generative.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
# fmt: off
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
    if len(lhs) == 3 and len(rhs) == 2:
        return [lhs[0], lhs[1], rhs[0]]
    elif len(lhs) == 2 and len(rhs) == 2:
        return [lhs[0], rhs[0]]
    else:
        raise ValueError("Input shapes not supported.")


def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
    # output dtype is the dtype of the lhs float input
    lhs_rank, lhs_dtype = lhs_rank_dtype
    return lhs_dtype


def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
    return


brevitas_matmul_rhs_group_quant_library = [
    quant〇matmul_rhs_group_quant〡shape,
    quant〇matmul_rhs_group_quant〡dtype,
    quant〇matmul_rhs_group_quant〡has_value_semantics]
# fmt: on

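Note: these refinement functions tell torch-mlir how to infer the result shape/dtype of the custom quant.matmul_rhs_group_quant op. Purely to make the two supported cases concrete (scale/zero-point shapes below are placeholders that the shape rule ignores):

# Batched case: (B, M, K) x (N, K) -> (B, M, N)
assert quant〇matmul_rhs_group_quant〡shape([4, 16, 32], [64, 32], [64, 1], [64, 1], 4, 128) == [4, 16, 64]
# 2-D case: (M, K) x (N, K) -> (M, N)
assert quant〇matmul_rhs_group_quant〡shape([16, 32], [64, 32], [64, 1], [64, 1], 4, 128) == [16, 64]
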
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
    vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
    amdshark_module = None
    if os.path.isfile(vmfb_path):
        amdshark_module = AMDSharkInference(
            None,
            device=device,
            mlir_dialect=mlir_dialect,
        )
        print(f"loading existing vmfb from: {vmfb_path}")
        amdshark_module.load_module(vmfb_path, extra_args=extra_args)
    return amdshark_module


def compile_module(
|
||||
amdshark_module, extended_model_name, generate_vmfb, extra_args=[]
|
||||
):
|
||||
if generate_vmfb:
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
|
||||
if os.path.isfile(vmfb_path):
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
amdshark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
else:
|
||||
print(
|
||||
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
|
||||
)
|
||||
path = amdshark_module.save_module(
|
||||
os.getcwd(), extended_model_name, extra_args
|
||||
)
|
||||
amdshark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
amdshark_module.compile(extra_args)
|
||||
return amdshark_module
|
||||
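Note: load_vmfb and compile_module together implement a simple compile-once cache keyed on <extended_model_name>.vmfb in the working directory. A hedged usage sketch (model name, MLIR path, and flags are made up):

from amdshark.amdshark_compile import load_vmfb, compile_module
from amdshark.amdshark_inference import AMDSharkInference

extended_model_name = "my_model_fp32_cpu"   # hypothetical cache key
extra_args = []

# Fast path: reuse an existing my_model_fp32_cpu.vmfb if one already sits in cwd.
module = load_vmfb(extended_model_name, device="cpu", mlir_dialect="linalg", extra_args=extra_args)

if module is None:
    # Slow path: compile the MLIR module and persist the vmfb for next time.
    module = AMDSharkInference("my_model_linalg.mlir", device="cpu", mlir_dialect="linalg")
    module = compile_module(module, extended_model_name, generate_vmfb=True, extra_args=extra_args)
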
|
||||
|
||||
def compile_int_precision(
|
||||
model, inputs, precision, device, generate_vmfb, extended_model_name
|
||||
):
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
weight_group_size = 128
|
||||
quantize_model(
|
||||
get_model_impl(model),
|
||||
dtype=torch.float32,
|
||||
weight_quant_type="asym",
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
input_bit_width=None,
|
||||
input_scale_type="float",
|
||||
input_param_method="stats",
|
||||
input_quant_type="asym",
|
||||
input_quant_granularity="per_tensor",
|
||||
quantize_input_zero_point=False,
|
||||
seqlen=2048,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
torchscript_module = import_with_fx(
|
||||
model,
|
||||
inputs,
|
||||
precision=precision,
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
mlir_module = torch_mlir.compile(
|
||||
torchscript_module,
|
||||
inputs,
|
||||
output_type="torch",
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
mlir_file_path = os.path.join(
|
||||
os.getcwd(), f"{extended_model_name}_linalg.mlir"
|
||||
)
|
||||
with open(mlir_file_path, "w") as f:
|
||||
with redirect_stdout(f):
|
||||
print(mlir_module.operation.get_asm())
|
||||
mlir_module = str(mlir_module)
|
||||
mlir_module = mlir_module.encode("UTF-8")
|
||||
mlir_module = BytesIO(mlir_module)
|
||||
bytecode = mlir_module.read()
|
||||
bytecode_path = os.path.join(
|
||||
os.getcwd(), f"{extended_model_name}_linalg.mlirbc"
|
||||
)
|
||||
with open(bytecode_path, "wb") as f:
|
||||
f.write(bytecode)
|
||||
del bytecode
|
||||
del mlir_module
|
||||
print(f"Elided IR written for {extended_model_name}")
|
||||
return bytecode_path
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_module=bytecode_path, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
extra_args = [
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
]
|
||||
return (
|
||||
compile_module(
|
||||
amdshark_module,
|
||||
extended_model_name=extended_model_name,
|
||||
generate_vmfb=generate_vmfb,
|
||||
extra_args=extra_args,
|
||||
),
|
||||
bytecode_path,
|
||||
)
|
||||
|
||||
|
||||
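Note: the torch_mlir.compile call in compile_int_precision keeps the custom quant op legal and registers the shape/dtype library defined above. A minimal sketch of that call shape follows; the scripted toy module and example tensor are stand-ins (real callers pass the TorchScript graph produced by import_with_fx(..., mlir_type="torchscript")), and note that the early `return bytecode_path` a few lines up appears to short-circuit the AMDSharkInference/compile_module block that follows it.

import torch
import torch_mlir
from amdshark.amdshark_compile import brevitas_matmul_rhs_group_quant_library

# Any module works for showing the call shape; this one does not actually use the quant op.
ts_module = torch.jit.script(torch.nn.Linear(8, 8))
example = torch.randn(2, 8)

mlir_module = torch_mlir.compile(
    ts_module,
    example,
    output_type="torch",
    backend_legal_ops=["quant.matmul_rhs_group_quant"],
    extra_library=brevitas_matmul_rhs_group_quant_library,
    use_tracing=False,
    verbose=False,
)
print(str(mlir_module)[:400])  # Torch-dialect IR, before lowering to linalg
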
def amdshark_compile_through_fx(
|
||||
model,
|
||||
inputs,
|
||||
extended_model_name,
|
||||
precision,
|
||||
f16_input_mask=None,
|
||||
save_dir=tempfile.gettempdir(),
|
||||
debug=False,
|
||||
generate_or_load_vmfb=True,
|
||||
extra_args=[],
|
||||
device=None,
|
||||
mlir_dialect="tm_tensor",
|
||||
):
|
||||
is_f16 = precision == "fp16"
|
||||
if generate_or_load_vmfb:
|
||||
amdshark_module = load_vmfb(
|
||||
extended_model_name=extended_model_name,
|
||||
device=device,
|
||||
mlir_dialect=mlir_dialect,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
if amdshark_module:
|
||||
return (
|
||||
amdshark_module,
|
||||
None,
|
||||
)
|
||||
|
||||
from amdshark.parser import amdshark_args
|
||||
|
||||
if "cuda" in device:
|
||||
amdshark_args.enable_tf32 = True
|
||||
|
||||
if precision in ["int4", "int8"]:
|
||||
mlir_module = compile_int_precision(
|
||||
model,
|
||||
inputs,
|
||||
precision,
|
||||
device,
|
||||
generate_or_load_vmfb,
|
||||
extended_model_name,
|
||||
)
|
||||
extra_args = [
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
]
|
||||
else:
|
||||
(
|
||||
bytecode,
|
||||
_,
|
||||
) = import_with_fx(
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
mlir_module = save_mlir(
|
||||
mlir_module=bytecode,
|
||||
model_name=extended_model_name,
|
||||
mlir_dialect=mlir_dialect,
|
||||
)
|
||||
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_module,
|
||||
device=device,
|
||||
mlir_dialect=mlir_dialect,
|
||||
)
|
||||
return (
|
||||
compile_module(
|
||||
amdshark_module,
|
||||
extended_model_name,
|
||||
generate_vmfb=generate_or_load_vmfb,
|
||||
extra_args=extra_args,
|
||||
),
|
||||
mlir_module,
|
||||
)
|
||||
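Note: amdshark_compile_through_fx is the one-call entry point added here: it reuses a cached vmfb when present, otherwise imports through FX, saves the MLIR to disk, and compiles. A hedged end-to-end sketch with placeholder names:

import torch
from amdshark.amdshark_compile import amdshark_compile_through_fx

# Toy model/input for illustration; extended_model_name doubles as the vmfb cache key.
model = torch.nn.Linear(32, 32)
inputs = (torch.randn(1, 32),)

amdshark_module, mlir_module = amdshark_compile_through_fx(
    model,
    inputs,
    extended_model_name="linear_fp32_cpu",
    precision="fp32",
    device="cpu",
    mlir_dialect="tm_tensor",
)
result = amdshark_module("forward", (inputs[0].numpy(),))
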
@@ -1,8 +1,8 @@
|
||||
# Lint as: python3
|
||||
"""SHARK Downloader"""
|
||||
# Requirements : Put shark_tank in SHARK directory
|
||||
# /SHARK
|
||||
# /gen_shark_tank
|
||||
"""AMDSHARK Downloader"""
|
||||
# Requirements : Put amdshark_tank in AMDSHARK directory
|
||||
# /AMDSHARK
|
||||
# /gen_amdshark_tank
|
||||
# /tflite
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
@@ -17,7 +17,7 @@ import os
|
||||
from tqdm.std import tqdm
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from shark.parser import shark_args
|
||||
from amdshark.parser import amdshark_args
|
||||
from google.cloud import storage
|
||||
|
||||
|
||||
@@ -83,8 +83,8 @@ input_type_to_np_dtype = {
|
||||
|
||||
# Save the model in the local home directory so it needn't be fetched every time in CI.
|
||||
home = str(Path.home())
|
||||
alt_path = os.path.join(os.path.dirname(__file__), "../gen_shark_tank/")
|
||||
custom_path = shark_args.local_tank_cache
|
||||
alt_path = os.path.join(os.path.dirname(__file__), "../gen_amdshark_tank/")
|
||||
custom_path = amdshark_args.local_tank_cache
|
||||
|
||||
if custom_path is not None:
|
||||
if not os.path.exists(custom_path):
|
||||
@@ -92,17 +92,17 @@ if custom_path is not None:
|
||||
|
||||
WORKDIR = custom_path
|
||||
|
||||
print(f"Using {WORKDIR} as local shark_tank cache directory.")
|
||||
print(f"Using {WORKDIR} as local amdshark_tank cache directory.")
|
||||
|
||||
elif os.path.exists(alt_path):
|
||||
WORKDIR = alt_path
|
||||
print(
|
||||
f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank."
|
||||
f"Using {WORKDIR} as amdshark_tank directory. Delete this directory if you aren't working from locally generated amdshark_tank."
|
||||
)
|
||||
else:
|
||||
WORKDIR = os.path.join(home, ".local/shark_tank/")
|
||||
WORKDIR = os.path.join(home, ".local/amdshark_tank/")
|
||||
print(
|
||||
f"shark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache= flag"
|
||||
f"amdshark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache= flag"
|
||||
)
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
|
||||
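Note: the module-level code above resolves the local tank cache in a fixed order. A small sketch that mirrors that order in one hypothetical helper (resolve_tank_cache is not part of the diff, just an illustration):

import os
from pathlib import Path

def resolve_tank_cache(custom_path=None):
    # Mirrors the resolution order above: explicit --local_tank_cache flag, then a locally
    # generated gen_amdshark_tank checkout next to the package, then ~/.local/amdshark_tank/.
    if custom_path is not None:
        os.makedirs(custom_path, exist_ok=True)
        return custom_path
    alt_path = os.path.join(os.path.dirname(__file__), "../gen_amdshark_tank/")
    if os.path.exists(alt_path):
        return alt_path
    workdir = os.path.join(str(Path.home()), ".local/amdshark_tank/")
    os.makedirs(workdir, exist_ok=True)
    return workdir
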
@@ -111,22 +111,20 @@ os.makedirs(WORKDIR, exist_ok=True)
|
||||
def check_dir_exists(model_name, frontend="torch", dynamic=""):
|
||||
model_dir = os.path.join(WORKDIR, model_name)
|
||||
|
||||
# Remove the _tf keyword from end.
|
||||
if frontend in ["tf", "tensorflow"]:
|
||||
model_name = model_name[:-3]
|
||||
elif frontend in ["tflite"]:
|
||||
model_name = model_name[:-7]
|
||||
elif frontend in ["torch", "pytorch"]:
|
||||
model_name = model_name[:-6]
|
||||
# Remove the _tf keyword from end only for non-SD models.
|
||||
if not any(model in model_name for model in ["clip", "unet", "vae"]):
|
||||
if frontend in ["tf", "tensorflow"]:
|
||||
model_name = model_name[:-3]
|
||||
elif frontend in ["tflite"]:
|
||||
model_name = model_name[:-7]
|
||||
elif frontend in ["torch", "pytorch"]:
|
||||
model_name = model_name[:-6]
|
||||
|
||||
model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir"
|
||||
|
||||
if os.path.isdir(model_dir):
|
||||
if (
|
||||
os.path.isfile(
|
||||
os.path.join(
|
||||
model_dir,
|
||||
model_name + dynamic + "_" + str(frontend) + ".mlir",
|
||||
)
|
||||
)
|
||||
os.path.isfile(os.path.join(model_dir, model_mlir_file_name))
|
||||
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
|
||||
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
|
||||
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
|
||||
@@ -152,8 +150,8 @@ def _internet_connected():
|
||||
def get_git_revision_short_hash() -> str:
|
||||
import subprocess
|
||||
|
||||
if shark_args.shark_prefix is not None:
|
||||
prefix_kw = shark_args.shark_prefix
|
||||
if amdshark_args.amdshark_prefix is not None:
|
||||
prefix_kw = amdshark_args.amdshark_prefix
|
||||
else:
|
||||
import json
|
||||
|
||||
@@ -162,11 +160,11 @@ def get_git_revision_short_hash() -> str:
|
||||
with open(src, "r") as f:
|
||||
data = json.loads(f.read())
|
||||
prefix_kw = data["version"]
|
||||
print(f"Checking for updates from gs://shark_tank/{prefix_kw}")
|
||||
print(f"Checking for updates from gs://amdshark_tank/{prefix_kw}")
|
||||
return prefix_kw
|
||||
|
||||
|
||||
def get_sharktank_prefix():
|
||||
def get_amdsharktank_prefix():
|
||||
tank_prefix = ""
|
||||
if not _internet_connected():
|
||||
print(
|
||||
@@ -176,7 +174,7 @@ def get_sharktank_prefix():
|
||||
else:
|
||||
desired_prefix = get_git_revision_short_hash()
|
||||
storage_client_a = storage.Client.create_anonymous_client()
|
||||
base_bucket_name = "shark_tank"
|
||||
base_bucket_name = "amdshark_tank"
|
||||
base_bucket = storage_client_a.bucket(base_bucket_name)
|
||||
dir_blobs = base_bucket.list_blobs(prefix=f"{desired_prefix}")
|
||||
for blob in dir_blobs:
|
||||
@@ -188,13 +186,13 @@ def get_sharktank_prefix():
|
||||
continue
|
||||
if tank_prefix == "":
|
||||
print(
|
||||
f"shark_tank bucket not found matching ({desired_prefix}). Defaulting to nightly."
|
||||
f"amdshark_tank bucket not found matching ({desired_prefix}). Defaulting to nightly."
|
||||
)
|
||||
tank_prefix = "nightly"
|
||||
return tank_prefix
|
||||
|
||||
|
||||
# Downloads the torch model from gs://shark_tank dir.
|
||||
# Downloads the torch model from gs://amdshark_tank dir.
|
||||
def download_model(
|
||||
model_name,
|
||||
dynamic=False,
|
||||
@@ -206,7 +204,7 @@ def download_model(
|
||||
model_name = model_name.replace("/", "_")
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
shark_args.shark_prefix = get_sharktank_prefix()
|
||||
amdshark_args.amdshark_prefix = get_amdsharktank_prefix()
|
||||
if import_args["batch_size"] and import_args["batch_size"] != 1:
|
||||
model_dir_name = (
|
||||
model_name
|
||||
@@ -223,7 +221,7 @@ def download_model(
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
|
||||
if not tank_url:
|
||||
tank_url = "gs://shark_tank/" + shark_args.shark_prefix
|
||||
tank_url = "gs://amdshark_tank/" + amdshark_args.amdshark_prefix
|
||||
|
||||
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
|
||||
if not check_dir_exists(
|
||||
@@ -234,7 +232,7 @@ def download_model(
|
||||
)
|
||||
download_public_file(full_gs_url, model_dir)
|
||||
|
||||
elif shark_args.force_update_tank == True:
|
||||
elif amdshark_args.force_update_tank == True:
|
||||
print(
|
||||
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
|
||||
)
|
||||
@@ -261,13 +259,13 @@ def download_model(
|
||||
except FileNotFoundError:
|
||||
print(f"Model artifact hash not found at {model_dir}.")
|
||||
upstream_hash = None
|
||||
if local_hash != upstream_hash and shark_args.update_tank == True:
|
||||
if local_hash != upstream_hash and amdshark_args.update_tank == True:
|
||||
print(f"Updating artifacts for model {model_name}...")
|
||||
download_public_file(full_gs_url, model_dir)
|
||||
|
||||
elif local_hash != upstream_hash:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
|
||||
"Hash does not match upstream in gs://amdshark_tank/. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
@@ -277,25 +275,23 @@ def download_model(
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
tuned_str = "" if tuned is None else "_" + tuned
|
||||
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
mlir_filename = os.path.join(model_dir, model_name + suffix)
|
||||
print(
|
||||
f"Verifying that model artifacts were downloaded successfully to {filename}..."
|
||||
f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
|
||||
)
|
||||
if not os.path.exists(filename):
|
||||
from tank.generate_sharktank import gen_shark_files
|
||||
if not os.path.exists(mlir_filename):
|
||||
from tank.generate_amdsharktank import gen_amdshark_files
|
||||
|
||||
print(
|
||||
"The model data was not found. Trying to generate artifacts locally."
|
||||
)
|
||||
gen_shark_files(model_name, frontend, WORKDIR, import_args)
|
||||
gen_amdshark_files(model_name, frontend, WORKDIR, import_args)
|
||||
|
||||
assert os.path.exists(filename), f"MLIR not found at {filename}"
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
return mlir_filename, function_name, inputs_tuple, golden_out_tuple
|
||||
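Note: per the hunk above, download_model now returns the MLIR file path (mlir_filename) rather than the file contents. A hedged usage sketch; the downloader module path, model name, and abridged argument list (including import_args) are assumptions:

from amdshark.amdshark_downloader import download_model
from amdshark.amdshark_inference import AMDSharkInference

mlir_path, func_name, inputs, golden_out = download_model(
    "resnet50", dynamic=False, frontend="torch", import_args={"batch_size": 1}
)
module = AMDSharkInference(mlir_path, device="cpu", mlir_dialect="linalg")
module.compile()
outputs = module(func_name, inputs)   # compare against golden_out for validation
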
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
from shark.shark_importer import import_with_fx
|
||||
from amdshark.amdshark_importer import import_with_fx, save_mlir
|
||||
import torchvision.models as models
|
||||
import copy
|
||||
import io
|
||||
@@ -13,22 +13,28 @@ from typing import Dict
|
||||
import torch_mlir
|
||||
|
||||
|
||||
def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
|
||||
def amdshark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
|
||||
mlir_module = torch_mlir.compile(
|
||||
fx_g, inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
from shark.shark_inference import SharkInference
|
||||
bytecode_path = save_mlir(
|
||||
bytecode,
|
||||
model_name="amdshark_eager_module",
|
||||
frontend="torch",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_module=bytecode_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile(extra_args=[])
|
||||
return shark_module
|
||||
amdshark_module.compile(extra_args=[])
|
||||
return amdshark_module
|
||||
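Note: amdshark_backend above takes an FX GraphModule plus example inputs, lowers it through torch-mlir, saves the bytecode with save_mlir, and returns a compiled AMDSharkInference module. A minimal sketch of driving it directly (the toy model is a stand-in; the per-op splitter below calls it on single-node graphs):

import torch
from torch.fx import symbolic_trace

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

fx_graph = symbolic_trace(TinyModel())
example = (torch.randn(4, 8),)

compiled = amdshark_backend(fx_graph, example, device="cpu")
out = compiled("forward", (example[0].numpy(),))
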
|
||||
|
||||
def _make_single_op_gm(node, captured_val, compiled_graph):
|
||||
@@ -49,7 +55,7 @@ def _make_single_op_gm(node, captured_val, compiled_graph):
|
||||
g.output(call)
|
||||
g.lint()
|
||||
single_node = torch.fx.GraphModule(torch.nn.Module(), g)
|
||||
compiled_module = shark_backend(single_node, inputs)
|
||||
compiled_module = amdshark_backend(single_node, inputs)
|
||||
compiled_graph[node.name] = {
|
||||
"module": compiled_module,
|
||||
"inputs": [i for i in env],
|
||||
@@ -166,41 +172,41 @@ shape_prop = ShapeProp(fx_graph)
|
||||
|
||||
x = shape_prop.propagate(input[0])
|
||||
|
||||
shark_graph = compiled_graph(fx_graph, x)
|
||||
amdshark_graph = compiled_graph(fx_graph, x)
|
||||
|
||||
|
||||
for key in shark_graph:
|
||||
for key in amdshark_graph:
|
||||
if key.startswith("getitem"):
|
||||
input_val = shark_graph[key]["input"]
|
||||
pos = shark_graph[key]["pos"]
|
||||
if input_val not in shark_graph:
|
||||
shark_graph[key]["result"] = x[input_val][pos].detach()
|
||||
input_val = amdshark_graph[key]["input"]
|
||||
pos = amdshark_graph[key]["pos"]
|
||||
if input_val not in amdshark_graph:
|
||||
amdshark_graph[key]["result"] = x[input_val][pos].detach()
|
||||
else:
|
||||
shark_graph[key]["result"] = shark_graph[input_val]["result"][
|
||||
amdshark_graph[key]["result"] = amdshark_graph[input_val]["result"][
|
||||
pos
|
||||
].detach()
|
||||
elif key.startswith("empty"):
|
||||
operator = shark_graph[key]["target"]
|
||||
args = shark_graph[key]["args"]
|
||||
kwargs = shark_graph[key]["kwargs"]
|
||||
shark_graph[key]["result"] = operator(*args, **kwargs).detach()
|
||||
operator = amdshark_graph[key]["target"]
|
||||
args = amdshark_graph[key]["args"]
|
||||
kwargs = amdshark_graph[key]["kwargs"]
|
||||
amdshark_graph[key]["result"] = operator(*args, **kwargs).detach()
|
||||
else:
|
||||
input_val = shark_graph[key]["inputs"]
|
||||
input_val = amdshark_graph[key]["inputs"]
|
||||
input_tensors = []
|
||||
for input in input_val:
|
||||
if input not in shark_graph:
|
||||
if input not in amdshark_graph:
|
||||
input_tensors.append(x[input].detach())
|
||||
else:
|
||||
input_tensors.append(shark_graph[input]["result"])
|
||||
input_tensors.append(amdshark_graph[input]["result"])
|
||||
|
||||
val = shark_graph[key]["module"]("forward", input_tensors)
|
||||
val = amdshark_graph[key]["module"]("forward", input_tensors)
|
||||
if isinstance(val, (tuple, list)):
|
||||
list_val = []
|
||||
for v in val:
|
||||
list_val.append(torch.from_numpy(v))
|
||||
shark_graph[key]["result"] = list_val
|
||||
amdshark_graph[key]["result"] = list_val
|
||||
else:
|
||||
shark_graph[key]["result"] = torch.from_numpy(val)
|
||||
amdshark_graph[key]["result"] = torch.from_numpy(val)
|
||||
|
||||
|
||||
print(shark_graph)
|
||||
print(amdshark_graph)
|
||||
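Note: the execution loop above depends on ShapeProp having run over the traced graph so each single-op submodule sees concrete shapes. A small illustrative sketch of what ShapeProp records (toy model, not the Vicuna graph used here):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

# ShapeProp executes the graph once with a real input and stores shape/dtype metadata
# on each node under node.meta["tensor_meta"].
gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()))
ShapeProp(gm).propagate(torch.randn(2, 8))

for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta", None)
    if meta is not None:
        print(node.name, tuple(meta.shape), meta.dtype)
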
@@ -1,8 +1,10 @@
|
||||
import re
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
import torch_mlir
|
||||
from iree.compiler import compile_str
|
||||
from shark.shark_importer import import_with_fx, get_f16_inputs
|
||||
from iree.compiler import compile_file
|
||||
from amdshark.amdshark_importer import import_with_fx, get_f16_inputs, save_mlir
|
||||
|
||||
|
||||
class GenerateConfigFile:
|
||||
@@ -11,6 +13,7 @@ class GenerateConfigFile:
|
||||
model,
|
||||
num_sharding_stages: int,
|
||||
sharding_stages_id: list[str],
|
||||
units_in_each_stage: list[int],
|
||||
model_input=None,
|
||||
config_file_path="model_config.json",
|
||||
):
|
||||
@@ -22,13 +25,16 @@ class GenerateConfigFile:
|
||||
), "Number of sharding stages should be equal to the list of their ID"
|
||||
self.model_input = model_input
|
||||
self.config_file_path = config_file_path
|
||||
# (Nithin) this is a quick fix - revisit and rewrite
|
||||
self.units_in_each_stage = np.array(units_in_each_stage)
|
||||
self.track_loop = np.zeros(len(self.sharding_stages_id)).astype(int)
|
||||
|
||||
def split_into_dispatches(
|
||||
self,
|
||||
backend,
|
||||
fx_tracing_required=True,
|
||||
fx_tracing_required=False,
|
||||
f16_model=False,
|
||||
torch_mlir_tracing=False,
|
||||
torch_mlir_tracing=True,
|
||||
):
|
||||
graph_for_compilation = self.model
|
||||
if fx_tracing_required:
|
||||
@@ -48,9 +54,15 @@ class GenerateConfigFile:
|
||||
verbose=False,
|
||||
)
|
||||
module = module.operation.get_asm(large_elements_limit=4)
|
||||
module_file = save_mlir(
|
||||
module,
|
||||
model_name="module_pre_split",
|
||||
frontend="torch",
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
compiled_module_str = str(
|
||||
compile_str(
|
||||
str(module),
|
||||
compile_file(
|
||||
module_file,
|
||||
target_backends=[backend],
|
||||
extra_args=[
|
||||
"--compile-to=flow",
|
||||
@@ -95,7 +107,17 @@ class GenerateConfigFile:
|
||||
if substring_before_final_period in model_dictionary:
|
||||
del model_dictionary[substring_before_final_period]
|
||||
|
||||
layer_dict = {n: "None" for n in self.sharding_stages_id}
|
||||
# layer_dict = {n: "None" for n in self.sharding_stages_id}
|
||||
|
||||
# By default embed increasing device id's for each layer
|
||||
increasing_wraparound_idx_list = (
|
||||
self.track_loop % self.units_in_each_stage
|
||||
)
|
||||
layer_dict = {
|
||||
n: int(increasing_wraparound_idx_list[idx][0][0])
|
||||
for idx, n in enumerate(self.sharding_stages_id)
|
||||
}
|
||||
self.track_loop += 1
|
||||
model_dictionary[name] = layer_dict
|
||||
|
||||
self.generate_json(model_dictionary)
|
||||
@@ -103,3 +125,29 @@ class GenerateConfigFile:
|
||||
def generate_json(self, artifacts):
|
||||
with open(self.config_file_path, "w") as outfile:
|
||||
json.dump(artifacts, outfile)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
|
||||
[1, 19]
|
||||
)
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna7B,
|
||||
CombinedModel,
|
||||
)
|
||||
|
||||
model = CombinedModel()
|
||||
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
|
||||
c.split_into_layers()
|
||||
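Note: the constructor now takes units_in_each_stage between sharding_stages_id and model_input, but the __main__ block above still passes four positional arguments, so firstVicunaCompileInput appears to land in the units_in_each_stage slot. A call matching the new signature is sketched below; the [1] device count (and its element shape) is an assumption.

c = GenerateConfigFile(
    model,
    num_sharding_stages=1,
    sharding_stages_id=["gpu_id"],
    units_in_each_stage=[1],              # assumed: one device per sharding stage
    model_input=firstVicunaCompileInput,
)
c.split_into_layers()
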
@@ -1,11 +1,12 @@
|
||||
# Lint as: python3
|
||||
"""SHARK Importer"""
|
||||
"""AMDSHARK Importer"""
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
def create_hash(file_name):
|
||||
with open(file_name, "rb") as f:
|
||||
@@ -27,9 +28,9 @@ supported_frontends = {
|
||||
}
|
||||
|
||||
|
||||
class SharkImporter:
|
||||
class AMDSharkImporter:
|
||||
"""
|
||||
SharkImporter converts frontend modules into a
|
||||
AMDSharkImporter converts frontend modules into a
|
||||
mlir_module. The supported frameworks are tensorflow,
|
||||
pytorch, and tf-lite.
|
||||
|
||||
@@ -82,7 +83,7 @@ class SharkImporter:
|
||||
# NOTE: The default function for torch is "forward" and tf-lite is "main".
|
||||
|
||||
def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module
|
||||
from amdshark.torch_mlir_utils import get_torch_mlir_module
|
||||
|
||||
return get_torch_mlir_module(
|
||||
self.module,
|
||||
@@ -120,7 +121,7 @@ class SharkImporter:
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
func_name="forward",
|
||||
save_dir="./shark_tmp/",
|
||||
save_dir=cmd_opts.tmp_dir, #"./amdshark_tmp/",
|
||||
mlir_type="linalg",
|
||||
):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
@@ -451,6 +452,108 @@ def transform_fx(fx_g, quantized=False):
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
def gptq_transforms(fx_g):
|
||||
import torch
|
||||
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.arange,
|
||||
torch.ops.aten.empty,
|
||||
torch.ops.aten.ones,
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("device") == torch.device(device="cuda:0"):
|
||||
updated_kwargs = node.kwargs.copy()
|
||||
updated_kwargs["device"] = torch.device(device="cpu")
|
||||
node.kwargs = updated_kwargs
|
||||
|
||||
if node.target in [
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("dtype") == torch.bfloat16:
|
||||
updated_kwargs = node.kwargs.copy()
|
||||
updated_kwargs["dtype"] = torch.float16
|
||||
node.kwargs = updated_kwargs
|
||||
|
||||
# Inputs of aten.native_layer_norm should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.native_layer_norm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (
|
||||
new_node_arg0,
|
||||
node.args[1],
|
||||
node.args[2],
|
||||
node.args[3],
|
||||
node.args[4],
|
||||
)
|
||||
|
||||
# Inputs of aten.mm should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.mm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
new_node_arg1 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[1], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node_arg0, new_node_arg1)
|
||||
|
||||
# Outputs of aten.mm should be downcasted to fp16.
|
||||
if type(node.args[0]) == torch.fx.node.Node and node.args[
|
||||
0
|
||||
].target in [torch.ops.aten.mm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
tmp = node.args[0]
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node.args[0],),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.args[0].append(new_node)
|
||||
node.args[0].replace_all_uses_with(new_node)
|
||||
new_node.args = (tmp,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
# Inputs of aten._softmax should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten._softmax]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node_arg0, node.args[1], node.args[2])
|
||||
|
||||
# Outputs of aten._softmax should be downcasted to fp16.
|
||||
if (
|
||||
type(node.args[0]) == torch.fx.node.Node
|
||||
and node.args[0].target in [torch.ops.aten._softmax]
|
||||
and node.target in [torch.ops.aten.expand]
|
||||
):
|
||||
with fx_g.graph.inserting_before(node):
|
||||
tmp = node.args[0]
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node.args[0],),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.args[0].append(new_node)
|
||||
node.args[0].replace_all_uses_with(new_node)
|
||||
new_node.args = (tmp,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
fx_g.graph.lint()
|
||||
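Note: gptq_transforms is an FX graph rewrite applied when is_gptq is set: it pins tensor-factory kwargs to CPU, swaps bf16 copies for fp16, and brackets layer_norm/mm/softmax with up/downcasts. A hedged sketch of how such a pass is driven, mirroring the is_gptq branch in import_with_fx; whether the individual rewrites fire on this toy graph depends on how the tracer records op overloads, and importing gptq_transforms from amdshark.amdshark_importer is an assumption about where it lives.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from amdshark.amdshark_importer import gptq_transforms

def fn(x, w):
    # toy graph containing an aten.mm followed by an expand
    return torch.mm(x, w).expand(4, 4)

fx_g = make_fx(fn)(torch.randn(4, 8), torch.randn(8, 4))
gptq_transforms(fx_g)   # rewrite kwargs/dtypes in place, as import_with_fx does when is_gptq
fx_g.recompile()
print(fx_g.code)
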
|
||||
|
||||
# Doesn't replace the None type.
|
||||
def change_fx_graph_return_to_tuple(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
@@ -488,7 +591,7 @@ def flatten_training_input(inputs):
|
||||
return tuple(flattened_input)
|
||||
|
||||
|
||||
# TODO: get rid of is_f16 by using precision
|
||||
# TODO: Remove is_f16 and fix all calls to use precision instead
|
||||
# Applies fx conversion to the model and imports the mlir.
|
||||
def import_with_fx(
|
||||
model,
|
||||
@@ -504,27 +607,12 @@ def import_with_fx(
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
precision="fp32",
|
||||
is_gptq=False,
|
||||
):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
block_quant_layer_level_manager,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
brevitas_layer_export_mode,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
LinearWeightBlockQuantHandlerFwd,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
matmul_rhs_group_quant_placeholder,
|
||||
)
|
||||
from brevitas.backport.fx.experimental.proxy_tensor import (
|
||||
make_fx as brevitas_make_fx,
|
||||
)
|
||||
|
||||
golden_values = None
|
||||
if debug:
|
||||
@@ -596,8 +684,30 @@ def import_with_fx(
|
||||
torch.ops.aten.native_layer_norm,
|
||||
torch.ops.aten.masked_fill.Tensor,
|
||||
torch.ops.aten.masked_fill.Scalar,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
||||
torch.ops.aten.index_add,
|
||||
torch.ops.aten.index_add_,
|
||||
]
|
||||
if precision in ["int4", "int8"]:
|
||||
if precision in ["int4", "int8"] and not is_gptq:
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
block_quant_layer_level_manager,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
brevitas_layer_export_mode,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
LinearWeightBlockQuantHandlerFwd,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
replace_call_fn_target,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
matmul_rhs_group_quant_placeholder,
|
||||
)
|
||||
from brevitas.backport.fx.experimental.proxy_tensor import (
|
||||
make_fx as brevitas_make_fx,
|
||||
)
|
||||
|
||||
export_context_manager = brevitas_layer_export_mode
|
||||
export_class = block_quant_layer_level_manager(
|
||||
export_handlers=[LinearWeightBlockQuantHandlerFwd]
|
||||
@@ -612,7 +722,7 @@ def import_with_fx(
|
||||
replace_call_fn_target(
|
||||
fx_g,
|
||||
src=matmul_rhs_group_quant_placeholder,
|
||||
target=torch.ops.brevitas.matmul_rhs_group_quant,
|
||||
target=torch.ops.quant.matmul_rhs_group_quant,
|
||||
)
|
||||
|
||||
fx_g.recompile()
|
||||
@@ -647,6 +757,10 @@ def import_with_fx(
|
||||
add_upcast(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if is_gptq:
|
||||
gptq_transforms(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if mlir_type == "fx":
|
||||
return fx_g
|
||||
|
||||
@@ -659,7 +773,7 @@ def import_with_fx(
|
||||
return ts_graph
|
||||
|
||||
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
|
||||
mlir_importer = SharkImporter(
|
||||
mlir_importer = AMDSharkImporter(
|
||||
ts_graph,
|
||||
inputs,
|
||||
frontend="torch",
|
||||
@@ -677,5 +791,29 @@ def import_with_fx(
|
||||
)
|
||||
return mlir_module, func_name
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir()
|
||||
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
|
||||
return mlir_module, func_name
|
||||
|
||||
|
||||
# Saves a .mlir module python object to the directory 'dir' with 'model_name' and returns a path to the saved file.
|
||||
def save_mlir(
|
||||
mlir_module,
|
||||
model_name,
|
||||
mlir_dialect="linalg",
|
||||
frontend="torch",
|
||||
dir="",
|
||||
):
|
||||
model_name_mlir = (
|
||||
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
|
||||
)
|
||||
if dir == "":
|
||||
dir = cmd_opts.tmp_dir  # os.path.join(".", "amdshark_tmp")
|
||||
mlir_path = os.path.join(dir, model_name_mlir)
|
||||
print(f"saving {model_name_mlir} to {dir}")
|
||||
if not os.path.exists(dir):
|
||||
os.makedirs(dir)
|
||||
if frontend == "torch":
|
||||
with open(mlir_path, "wb") as mlir_file:
|
||||
mlir_file.write(mlir_module)
|
||||
|
||||
return mlir_path
|
||||
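Note: save_mlir persists a serialized MLIR module under cmd_opts.tmp_dir (or an explicit dir) using the <model>_<frontend>_<dialect>.mlir naming scheme, and several call sites in this diff now route through it. A hedged usage sketch; the bytes payload and output directory are placeholders.

from amdshark.amdshark_importer import save_mlir

bytecode = b"module {}"  # stand-in for real serialized MLIR from import_with_fx / torch_mlir
path = save_mlir(
    bytecode,
    model_name="resnet50",
    frontend="torch",
    mlir_dialect="linalg",
    dir="/tmp/amdshark_artifacts",
)
# -> /tmp/amdshark_artifacts/resnet50_torch_linalg.mlir
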
@@ -9,15 +9,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from shark.iree_utils.compile_utils import (
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
export_iree_module_to_vmfb,
|
||||
load_flatbuffer,
|
||||
create_dispatch_dirs,
|
||||
compile_benchmark_dirs,
|
||||
)
|
||||
import os
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.parser import shark_args
|
||||
from amdshark.amdshark_runner import AMDSharkRunner
|
||||
from amdshark.parser import amdshark_args
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ dtype_to_np_dtype = {
|
||||
}
|
||||
|
||||
|
||||
class SharkInference:
|
||||
class AMDSharkInference:
|
||||
"""
|
||||
Runs prediction or inference on mlir_module.
|
||||
|
||||
@@ -39,7 +39,7 @@ class SharkInference:
|
||||
Attributes
|
||||
----------
|
||||
mlir_module : str
|
||||
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
|
||||
mlir_module or path represented in string; modules from torch-mlir are serialized in bytecode format.
|
||||
device : str
|
||||
device to execute the mlir_module on.
|
||||
currently supports cpu, cuda, vulkan, and metal backends.
|
||||
@@ -47,7 +47,7 @@ class SharkInference:
|
||||
The dialect in which the given mlir_module is in.
|
||||
Refer to {https://mlir.llvm.org/docs/Dialects/}
|
||||
is_benchmark: bool
|
||||
Whether this SharkInference module should be benchmark-enabled.
|
||||
Whether this AMDSharkInference module should be benchmark-enabled.
|
||||
mmap: bool
|
||||
Whether to load/run vmfb using mmap. It's `True` by default.
|
||||
|
||||
@@ -65,7 +65,7 @@ class SharkInference:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mlir_module: bytes,
|
||||
mlir_module,
|
||||
device: str = "none",
|
||||
mlir_dialect: str = "linalg",
|
||||
is_benchmark: bool = False,
|
||||
@@ -73,25 +73,35 @@ class SharkInference:
|
||||
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
|
||||
device_idx: int = None,
|
||||
mmap: bool = True,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
if mlir_module is not None:
|
||||
if mlir_module and not os.path.isfile(mlir_module):
|
||||
print(
|
||||
"Warning: Initializing AMDSharkInference with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize AMDSharkInference with a path to a MLIR module on your hard disk instead."
|
||||
)
|
||||
self.compile_str = True
|
||||
else:
|
||||
self.compile_str = False
|
||||
self.device = amdshark_args.device if device == "none" else device
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.is_benchmark = is_benchmark
|
||||
self.device_idx = device_idx
|
||||
self.dispatch_benchmarks = (
|
||||
shark_args.dispatch_benchmarks
|
||||
amdshark_args.dispatch_benchmarks
|
||||
if dispatch_benchmark is None
|
||||
else dispatch_benchmark
|
||||
)
|
||||
self.dispatch_benchmarks_dir = (
|
||||
shark_args.dispatch_benchmarks_dir
|
||||
amdshark_args.dispatch_benchmarks_dir
|
||||
if dispatch_benchmark_dir == "temp_dispatch_benchmarks"
|
||||
else dispatch_benchmark_dir
|
||||
)
|
||||
|
||||
self.shark_runner = None
|
||||
self.amdshark_runner = None
|
||||
self.mmap = mmap
|
||||
self.rt_flags = rt_flags
|
||||
|
||||
def compile(self, extra_args=[]):
|
||||
if self.dispatch_benchmarks is not None:
|
||||
@@ -110,9 +120,9 @@ class SharkInference:
|
||||
)
|
||||
|
||||
if self.is_benchmark == True:
|
||||
from shark.shark_benchmark_runner import SharkBenchmarkRunner
|
||||
from amdshark.amdshark_benchmark_runner import AMDSharkBenchmarkRunner
|
||||
|
||||
self.shark_runner = SharkBenchmarkRunner(
|
||||
self.amdshark_runner = AMDSharkBenchmarkRunner(
|
||||
self.mlir_module,
|
||||
self.device,
|
||||
self.mlir_dialect,
|
||||
@@ -120,12 +130,13 @@ class SharkInference:
|
||||
)
|
||||
|
||||
else:
|
||||
self.shark_runner = SharkRunner(
|
||||
self.amdshark_runner = AMDSharkRunner(
|
||||
self.mlir_module,
|
||||
self.device,
|
||||
self.mlir_dialect,
|
||||
extra_args=extra_args,
|
||||
device_idx=self.device_idx,
|
||||
rt_flags=self.rt_flags,
|
||||
)
|
||||
|
||||
if self.dispatch_benchmarks is not None:
|
||||
@@ -139,11 +150,19 @@ class SharkInference:
|
||||
|
||||
# inputs are considered to be tuple of np.array.
|
||||
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run(function_name, inputs, send_to_host)
|
||||
return self.amdshark_runner.run(
|
||||
function_name, inputs, send_to_host, device=self.device
|
||||
)
|
||||
|
||||
# forward function.
|
||||
def forward(self, inputs: tuple, send_to_host=True):
|
||||
return self.amdshark_runner.run(
|
||||
"forward", inputs, send_to_host, device=self.device
|
||||
)
|
||||
|
||||
# Get all function names defined within the compiled module.
|
||||
def get_functions_in_module(self):
|
||||
return self.shark_runner.get_functions_in_module()
|
||||
return self.amdshark_runner.get_functions_in_module()
|
||||
|
||||
# Captures the static input information from the mlir_module.
|
||||
# TODO(pashu123): Generate the input information for dynamic shapes.
|
||||
@@ -188,7 +207,9 @@ class SharkInference:
|
||||
|
||||
# TODO: Instead of passing directory and having names decided by the module
|
||||
# , user may want to save the module with manual names.
|
||||
def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
|
||||
def save_module(
|
||||
self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False
|
||||
):
|
||||
return export_iree_module_to_vmfb(
|
||||
self.mlir_module,
|
||||
self.device,
|
||||
@@ -196,23 +217,27 @@ class SharkInference:
|
||||
self.mlir_dialect,
|
||||
module_name=module_name,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=self.compile_str,
|
||||
)
|
||||
|
||||
# load and return the module.
|
||||
def load_module(self, path, extra_args=[]):
|
||||
self.shark_runner = SharkRunner(
|
||||
self.amdshark_runner = AMDSharkRunner(
|
||||
device=self.device,
|
||||
compile_vmfb=False,
|
||||
extra_args=extra_args,
|
||||
rt_flags=self.rt_flags,
|
||||
)
|
||||
params = load_flatbuffer(
|
||||
path,
|
||||
self.device,
|
||||
self.device_idx,
|
||||
mmap=self.mmap,
|
||||
rt_flags=self.rt_flags,
|
||||
)
|
||||
self.shark_runner.iree_compilation_module = params["vmfb"]
|
||||
self.shark_runner.iree_config = params["config"]
|
||||
self.shark_runner.temp_file_to_unlink = params["temp_file_to_unlink"]
|
||||
self.amdshark_runner.iree_compilation_module = params["vmfb"]
|
||||
self.amdshark_runner.iree_config = params["config"]
|
||||
self.amdshark_runner.temp_file_to_unlink = params["temp_file_to_unlink"]
|
||||
del params
|
||||
return
|
||||
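Note: the changes above route compile/save/load through the renamed AMDSharkRunner and add a forward() convenience method. A hedged end-to-end sketch of the class as changed here; the MLIR path, module name, and input shape are illustrative, and passing a path (rather than bytes) keeps compile_str False so the new RAM-duplication warning is not triggered.

import os
import numpy as np
from amdshark.amdshark_inference import AMDSharkInference

module = AMDSharkInference(
    "vicuna_torch_tm_tensor.mlir",   # hypothetical MLIR file on disk
    device="cpu",
    mlir_dialect="tm_tensor",
    mmap=True,
)
module.compile(extra_args=[])

# Persist the compiled flatbuffer, then reload it later without recompiling.
vmfb_path = module.save_module(os.getcwd(), module_name="vicuna_cpu")
module.load_module(vmfb_path, extra_args=[])

inputs = (np.ones((1, 19), dtype=np.int64),)
logits = module("forward", inputs)   # equivalently: module.forward(inputs)
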
@@ -12,19 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from shark.iree_utils.compile_utils import (
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
get_iree_compiled_module,
|
||||
get_results,
|
||||
export_iree_module_to_vmfb,
|
||||
load_flatbuffer,
|
||||
)
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.parser import shark_args
|
||||
from amdshark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from amdshark.parser import amdshark_args
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
# supported dialects by the shark-runtime.
|
||||
# supported dialects by the amdshark-runtime.
|
||||
supported_dialects = {
|
||||
"linalg",
|
||||
"auto",
|
||||
@@ -35,9 +35,9 @@ supported_dialects = {
|
||||
}
|
||||
|
||||
|
||||
class SharkRunner:
|
||||
class AMDSharkRunner:
|
||||
"""
|
||||
Base class for SharkInference and SharkTrainer
|
||||
Base class for AMDSharkInference and AMDSharkTrainer
|
||||
used to execute an mlir_module.
|
||||
|
||||
...
|
||||
@@ -45,7 +45,7 @@ class SharkRunner:
|
||||
Attributes
|
||||
----------
|
||||
mlir_module : str
|
||||
mlir_module represented in string.
|
||||
mlir_module path, string, or bytecode.
|
||||
device : str
|
||||
device to execute the mlir_module on.
|
||||
currently supports cpu, cuda, vulkan, and metal backends.
|
||||
@@ -72,12 +72,22 @@ class SharkRunner:
|
||||
extra_args: list = [],
|
||||
compile_vmfb: bool = True,
|
||||
device_idx: int = None,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
if self.mlir_module is not None:
|
||||
if not os.path.isfile(mlir_module):
|
||||
print(
|
||||
"Warning: Initializing AMDSharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize AMDSharkInference with a path to a MLIR module on your hard disk instead."
|
||||
)
|
||||
self.compile_str = True
|
||||
else:
|
||||
self.compile_str = False
|
||||
self.device = amdshark_args.device if device == "none" else device
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
self.device_idx = device_idx
|
||||
self.rt_flags = rt_flags
|
||||
|
||||
if check_device_drivers(self.device):
|
||||
print(device_driver_info(self.device))
|
||||
@@ -91,13 +101,17 @@ class SharkRunner:
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
device_idx=self.device_idx,
|
||||
rt_flags=self.rt_flags,
|
||||
compile_str=self.compile_str,
|
||||
)
|
||||
self.iree_compilation_module = params["vmfb"]
|
||||
self.iree_config = params["config"]
|
||||
self.temp_file_to_unlink = params["temp_file_to_unlink"]
|
||||
del params
|
||||
|
||||
def run(self, function_name, inputs: tuple, send_to_host=False):
|
||||
def run(
|
||||
self, function_name, inputs: tuple, send_to_host=False, device=None
|
||||
):
|
||||
return get_results(
|
||||
self.iree_compilation_module,
|
||||
function_name,
|
||||
@@ -105,6 +119,7 @@ class SharkRunner:
|
||||
self.iree_config,
|
||||
self.mlir_dialect,
|
||||
send_to_host,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Get all function names defined within the compiled module.
|
||||
@@ -12,10 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from shark.parser import shark_args
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.backward_makefx import MakeFxModule
|
||||
from shark.shark_importer import import_with_fx
|
||||
from amdshark.parser import amdshark_args
|
||||
from amdshark.amdshark_runner import AMDSharkRunner
|
||||
from amdshark.backward_makefx import MakeFxModule
|
||||
from amdshark.amdshark_importer import import_with_fx, save_mlir
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
@@ -26,8 +26,8 @@ def print_err(*a):
|
||||
print(*a, file=sys.stderr)
|
||||
|
||||
|
||||
class SharkTrainer:
|
||||
"""Training pytorch, tensorflow module on shark runtime."""
|
||||
class AMDSharkTrainer:
|
||||
"""Training pytorch, tensorflow module on amdshark runtime."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -48,9 +48,9 @@ class SharkTrainer:
|
||||
|
||||
# By default it's the torch frontend.
|
||||
self.frontend = "pytorch"
|
||||
self.device = device if device is not None else shark_args.device
|
||||
self.device = device if device is not None else amdshark_args.device
|
||||
|
||||
self.shark_runner = None
|
||||
self.amdshark_runner = None
|
||||
|
||||
# Sets the frontend i.e `pytorch` or `tensorflow`.
|
||||
def set_frontend(self, frontend: str):
|
||||
@@ -69,7 +69,7 @@ class SharkTrainer:
|
||||
self.frontend = frontend
|
||||
|
||||
# Training function is needed in the case of torch_fn.
|
||||
def compile(self, training_fn=None, extra_args=[]):
|
||||
def compile(self, training_fn=None, mlir_type="linalg", extra_args=[]):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
packed_inputs = (
|
||||
dict(self.model.named_parameters()),
|
||||
@@ -77,16 +77,27 @@ class SharkTrainer:
|
||||
tuple(self.input),
|
||||
)
|
||||
mlir_module, func_name = import_with_fx(
|
||||
training_fn, packed_inputs, False, [], training=True
|
||||
training_fn,
|
||||
packed_inputs,
|
||||
False,
|
||||
[],
|
||||
training=True,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
self.shark_runner = SharkRunner(
|
||||
mlir_module = save_mlir(
|
||||
mlir_module,
|
||||
model_name="amdshark_model",
|
||||
frontend="torch",
|
||||
mlir_dialect=mlir_type,
|
||||
)
|
||||
self.amdshark_runner = AMDSharkRunner(
|
||||
mlir_module,
|
||||
self.device,
|
||||
"tm_tensor",
|
||||
extra_args=extra_args,
|
||||
)
|
||||
elif self.frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
self.shark_runner = SharkRunner(
|
||||
self.amdshark_runner = AMDSharkRunner(
|
||||
self.model,
|
||||
self.input,
|
||||
self.dynamic,
|
||||
@@ -112,7 +123,7 @@ class SharkTrainer:
|
||||
params = [x.numpy() for x in params]
|
||||
print(f"Training started for {num_iters} iterations:")
|
||||
for i in tqdm(range(num_iters)):
|
||||
params = self.shark_runner.run(
|
||||
params = self.amdshark_runner.run(
|
||||
"forward", params + self.input, self.frontend
|
||||
)
|
||||
|
||||
@@ -120,7 +131,7 @@ class SharkTrainer:
|
||||
|
||||
# Function to train tensorflow module.
|
||||
# Output final loss.
|
||||
# TODO(raikonenfnu): Save updated weight/states in SHARK.
|
||||
# TODO(raikonenfnu): Save updated weight/states in AMDSHARK.
|
||||
def _train_tf(self, num_iters):
|
||||
input_list = []
|
||||
for x in self.input:
|
||||
@@ -139,7 +150,7 @@ class SharkTrainer:
|
||||
|
||||
print(f"Training started for {num_iters} iterations:")
|
||||
for i in tqdm(range(num_iters)):
|
||||
outputs = self.shark_runner.forward(input_list, self.frontend)
|
||||
outputs = self.amdshark_runner.forward(input_list, self.frontend)
|
||||
return outputs
|
||||
|
||||
def train(self, num_iters=1):
|
||||
@@ -15,7 +15,7 @@
|
||||
import torch
|
||||
from torch._decomp import get_decompositions
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.nn.utils import _stateless
|
||||
from torch.nn.utils import stateless
|
||||
|
||||
from torch import fx
|
||||
import tempfile
|
||||
@@ -71,7 +71,7 @@ class MakeFxModule:
|
||||
fx_g = self.change_fx_graph_return_to_tuple(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
temp = tempfile.NamedTemporaryFile(
|
||||
suffix="_shark_ts", prefix="temp_ts_"
|
||||
suffix="_amdshark_ts", prefix="temp_ts_"
|
||||
)
|
||||
ts_g.save(temp.name)
|
||||
new_ts = torch.jit.load(temp.name)
|
||||
@@ -3,7 +3,7 @@ from typing import List, Optional
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from torch._decomp import get_decompositions
|
||||
from torch.func import functionalize
|
||||
import io
|
||||
@@ -93,13 +93,13 @@ def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
class SharkBackend:
|
||||
class AMDSharkBackend:
|
||||
def __init__(
|
||||
self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
|
||||
):
|
||||
self.fx_g = fx_g
|
||||
self.inputs = inputs
|
||||
self.shark_module = None
|
||||
self.amdshark_module = None
|
||||
self.device: str = options.get("device", "cpu")
|
||||
self.was_unwrapped: bool = False
|
||||
self.none_indices: list = []
|
||||
@@ -125,19 +125,19 @@ class SharkBackend:
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=self.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile(extra_args=[])
|
||||
self.shark_module = shark_module
|
||||
amdshark_module.compile(extra_args=[])
|
||||
self.amdshark_module = amdshark_module
|
||||
|
||||
def __call__(self, *inputs):
|
||||
np_inputs = [x.contiguous().detach().cpu().numpy() for x in inputs]
|
||||
np_outs = self.shark_module("forward", np_inputs)
|
||||
np_outs = self.amdshark_module("forward", np_inputs)
|
||||
if self.was_unwrapped:
|
||||
np_outs = [
|
||||
np_outs,
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
import shark
|
||||
import amdshark
|
||||
|
||||
|
||||
def foo(x, a):
|
||||
@@ -9,8 +9,8 @@ def foo(x, a):
|
||||
return x + 3
|
||||
|
||||
|
||||
shark_options = {"device": "cpu"}
|
||||
compiled = torch.compile(foo, backend="shark", options=shark_options)
|
||||
amdshark_options = {"device": "cpu"}
|
||||
compiled = torch.compile(foo, backend="amdshark", options=amdshark_options)
|
||||
|
||||
input = torch.ones(4)
|
||||
|
||||
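A minimal sketch of invoking the compiled function above, stated as an assumption rather than as part of the diff: the elided body of `foo` is presumed to accept a small tensor for `a`, and importing `amdshark` is presumed to register the `"amdshark"` torch.compile backend shown earlier.

```python
# Continuation of the script above (hedged sketch, not in the diff).
# The first call traces `foo` and hands the captured FX graph to the AMDShark
# backend; subsequent calls with the same shapes reuse the compiled module.
result = compiled(input, torch.tensor(1.0))
print(result)
```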
@@ -22,7 +22,7 @@
|
||||
"source": [
|
||||
"# standard imports\n",
|
||||
"import torch\n",
|
||||
"from shark.iree_utils import get_iree_compiled_module"
|
||||
"from amdshark.iree_utils import get_iree_compiled_module"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch_mlir import compile, OutputType
|
||||
|
||||
from shark.iree_utils import get_iree_compiled_module
|
||||
from amdshark.iree_utils import get_iree_compiled_module
|
||||
|
||||
try:
|
||||
import torchdynamo
|
||||
@@ -32,7 +32,7 @@
|
||||
"source": [
|
||||
"# eager mode imports\n",
|
||||
"from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor\n",
|
||||
"from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend"
|
||||
"from amdshark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
@@ -440,7 +440,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"There is a convenience class `SharkEagerMode` that will handle both the installation of the backend and the wrapping of `torch.Tensor`s:"
|
||||
"There is a convenience class `AMDSharkEagerMode` that will handle both the installation of the backend and the wrapping of `torch.Tensor`s:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
@@ -684,9 +684,9 @@
|
||||
],
|
||||
"source": [
|
||||
"# eager mode RAII\n",
|
||||
"from shark.shark_runner import SharkEagerMode\n",
|
||||
"from amdshark.amdshark_runner import AMDSharkEagerMode\n",
|
||||
"\n",
|
||||
"shark_eager_mode = SharkEagerMode(\"cpu\")\n",
|
||||
"amdshark_eager_mode = AMDSharkEagerMode(\"cpu\")\n",
|
||||
"\n",
|
||||
"t = torch.ones((10, 10))\n",
|
||||
"u = torch.ones((10, 10))\n",
|
||||
@@ -712,7 +712,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"The `SharkEagerMode` class is a hacky take on [RAII](https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization) that defines a \"deleter\" that runs when an instantiation (of `SharkEagerMode`) is garbage collected. Takeaway is that if you want to turn off `SharkEagerMode`, or switch backends, you need to `del` the instance:"
|
||||
"The `AMDSharkEagerMode` class is a hacky take on [RAII](https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization) that defines a \"deleter\" that runs when an instantiation (of `AMDSharkEagerMode`) is garbage collected. Takeaway is that if you want to turn off `AMDSharkEagerMode`, or switch backends, you need to `del` the instance:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
@@ -757,8 +757,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"del shark_eager_mode\n",
|
||||
"shark_eager_mode = SharkEagerMode(\"cuda\")\n",
|
||||
"del amdshark_eager_mode\n",
|
||||
"amdshark_eager_mode = AMDSharkEagerMode(\"cuda\")\n",
|
||||
"\n",
|
||||
"t = torch.ones((10, 10))\n",
|
||||
"u = torch.ones((10, 10))\n",
|
||||
@@ -17,8 +17,8 @@ from torch.utils.cpp_extension import load_inline, include_paths
|
||||
from torch_mlir.eager_mode import torch_mlir_tensor
|
||||
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
|
||||
|
||||
from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
|
||||
from shark.shark_runner import SharkEagerMode
|
||||
from amdshark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
|
||||
from amdshark.amdshark_runner import AMDSharkEagerMode
|
||||
|
||||
|
||||
def test_cpu():
|
||||
@@ -85,7 +85,7 @@ def test_gpu():
|
||||
|
||||
def test_python_mode_ref_backend():
|
||||
# hide this wherever you want?
|
||||
_ = SharkEagerMode("refbackend")
|
||||
_ = AMDSharkEagerMode("refbackend")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = torch.ones((10, 10), device="cpu")
|
||||
@@ -103,7 +103,7 @@ def test_python_mode_ref_backend():
|
||||
|
||||
def test_python_mode_iree_cpu():
|
||||
# hide this wherever you want?
|
||||
_ = SharkEagerMode("cpu")
|
||||
_ = AMDSharkEagerMode("cpu")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = torch.ones((10, 10), device="cpu")
|
||||
@@ -121,7 +121,7 @@ def test_python_mode_iree_cpu():
|
||||
|
||||
|
||||
def test_python_mode_iree_gpu():
|
||||
_ = SharkEagerMode("gpu")
|
||||
_ = AMDSharkEagerMode("gpu")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = torch.ones((10, 10), device="cpu")
|
||||
@@ -47,7 +47,7 @@ golden_probabilities = torch.nn.functional.softmax(
|
||||
|
||||
golden_confidences = golden_confidences.numpy()
|
||||
|
||||
from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor
|
||||
from amdshark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor
|
||||
|
||||
input_detached_clone = input_batch.clone()
|
||||
eager_input_batch = TorchMLIRLockstepTensor(input_detached_clone)
|
||||
@@ -62,7 +62,7 @@ probabilities = torch.nn.functional.softmax(
|
||||
torch.from_numpy(confidences), dim=0
|
||||
).numpy()
|
||||
|
||||
print("The obtained result via shark is: ", confidences)
|
||||
print("The obtained result via amdshark is: ", confidences)
|
||||
print("The golden result is:", golden_confidences)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
@@ -3,7 +3,7 @@ import requests
|
||||
|
||||
from transformers import CLIPProcessor, TFCLIPModel
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
# Create a set of inputs
|
||||
clip_vit_inputs = [
|
||||
@@ -43,7 +43,7 @@ if __name__ == "__main__":
|
||||
padding=True,
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
CLIPModule(),
|
||||
(
|
||||
inputs["input_ids"],
|
||||
@@ -51,11 +51,11 @@ if __name__ == "__main__":
|
||||
inputs["pixel_values"],
|
||||
),
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
amdshark_module.set_frontend("tensorflow")
|
||||
amdshark_module.compile()
|
||||
|
||||
print(
|
||||
shark_module.forward(
|
||||
amdshark_module.forward(
|
||||
(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import functools
|
||||
@@ -176,12 +176,12 @@ def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
|
||||
mlir_model = str(module)
|
||||
func_name = "forward"
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
amdshark_module.compile()
|
||||
|
||||
return shark_module
|
||||
return amdshark_module
|
||||
|
||||
|
||||
model_path = "models/RRDB_ESRGAN_x4.pth" # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
|
||||
@@ -213,22 +213,22 @@ if __name__ == "__main__":
|
||||
img_LR = img_LR.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
shark_module = compile_through_fx(inference, img_LR)
|
||||
shark_output = shark_module.forward((img_LR,))
|
||||
shark_output = torch.from_numpy(shark_output)
|
||||
shark_output = (
|
||||
shark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
amdshark_module = compile_through_fx(inference, img_LR)
|
||||
amdshark_output = amdshark_module.forward((img_LR,))
|
||||
amdshark_output = torch.from_numpy(amdshark_output)
|
||||
amdshark_output = (
|
||||
amdshark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
)
|
||||
esrgan_output = (
|
||||
model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
)
|
||||
# SHARK OUTPUT
|
||||
shark_output = np.transpose(shark_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
shark_output = (shark_output * 255.0).round()
|
||||
# AMDSHARK OUTPUT
|
||||
amdshark_output = np.transpose(amdshark_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
amdshark_output = (amdshark_output * 255.0).round()
|
||||
cv2.imwrite(
|
||||
"OutputImages/{:s}_rlt_shark_output.png".format(base), shark_output
|
||||
"OutputImages/{:s}_rlt_amdshark_output.png".format(base), amdshark_output
|
||||
)
|
||||
print("Generated SHARK's output")
|
||||
print("Generated AMDSHARK's output")
|
||||
# ESRGAN OUTPUT
|
||||
esrgan_output = np.transpose(esrgan_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
esrgan_output = (esrgan_output * 255.0).round()
|
||||
@@ -1,7 +1,7 @@
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
from iree.compiler import compile_str
|
||||
from iree import runtime as ireert
|
||||
import os
|
||||
@@ -35,7 +35,7 @@ if __name__ == "__main__":
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
|
||||
mlir_importer = SharkImporter(
|
||||
mlir_importer = AMDSharkImporter(
|
||||
AlbertModule(),
|
||||
inputs,
|
||||
frontend="torch",
|
||||
@@ -43,11 +43,9 @@ if __name__ == "__main__":
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
amdshark_module = AMDSharkInference(minilm_mlir)
|
||||
amdshark_module.compile()
|
||||
token_logits = torch.tensor(amdshark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
encoded_inputs["input_ids"] == tokenizer.mask_token_id
|
||||
)[1]
|
||||
@@ -71,7 +69,7 @@ if __name__ == "__main__":
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
token_logits = torch.tensor(amdshark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
encoded_inputs["input_ids"] == tokenizer.mask_token_id
|
||||
)[1]
|
||||
@@ -3,8 +3,8 @@ import requests
|
||||
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
from iree.compiler import tf as tfc
|
||||
from iree.compiler import compile_str
|
||||
from iree import runtime as ireert
|
||||
@@ -46,7 +46,7 @@ if __name__ == "__main__":
|
||||
return_tensors="tf",
|
||||
)
|
||||
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
|
||||
mlir_importer = SharkImporter(
|
||||
mlir_importer = AMDSharkImporter(
|
||||
AlbertModule(),
|
||||
inputs,
|
||||
frontend="tf",
|
||||
@@ -54,11 +54,11 @@ if __name__ == "__main__":
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
amdshark_module = AMDSharkInference(minilm_mlir, mlir_dialect="mhlo")
|
||||
amdshark_module.compile()
|
||||
output_idx = 0
|
||||
data_idx = 1
|
||||
token_logits = shark_module.forward(inputs)[output_idx][data_idx]
|
||||
token_logits = amdshark_module.forward(inputs)[output_idx][data_idx]
|
||||
mask_id = np.where(
|
||||
tf.squeeze(encoded_inputs["input_ids"]) == tokenizer.mask_token_id
|
||||
)
|
||||
@@ -82,7 +82,7 @@ if __name__ == "__main__":
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
token_logits = shark_module.forward(inputs)[output_idx][data_idx]
|
||||
token_logits = amdshark_module.forward(inputs)[output_idx][data_idx]
|
||||
mask_id = np.where(
|
||||
tf.squeeze(encoded_inputs["input_ids"])
|
||||
== tokenizer.mask_token_id
|
||||
amdshark/examples/amdshark_inference/bloom_tank.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from amdshark.amdshark_inference import AMDSharkInference
from amdshark.amdshark_downloader import download_model

mlir_model, func_name, inputs, golden_out = download_model(
    "bloom", frontend="torch"
)

amdshark_module = AMDSharkInference(
    mlir_model, device="cpu", mlir_dialect="tm_tensor"
)
amdshark_module.compile()
result = amdshark_module.forward(inputs)
print("The obtained result via amdshark is: ", result)
print("The golden result is:", golden_out)
@@ -3,7 +3,7 @@ import requests
|
||||
|
||||
from transformers import GPT2Tokenizer, TFGPT2Model
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
# Create a set of inputs
|
||||
gpt2_inputs = [
|
||||
@@ -30,11 +30,11 @@ if __name__ == "__main__":
|
||||
text = "I love the distilled version of models."
|
||||
|
||||
inputs = tokenizer(text, return_tensors="tf")
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
GPT2Module(), (inputs["input_ids"], inputs["attention_mask"])
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
amdshark_module.set_frontend("tensorflow")
|
||||
amdshark_module.compile()
|
||||
print(
|
||||
shark_module.forward((inputs["input_ids"], inputs["attention_mask"]))
|
||||
amdshark_module.forward((inputs["input_ids"], inputs["attention_mask"]))
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
# SHARK LLaMA
|
||||
# AMDSHARK LLaMA
|
||||
|
||||
## TORCH-MLIR Version
|
||||
|
||||
@@ -14,5 +14,5 @@ git clone https://github.com/nod-ai/llama.git
|
||||
Then in this repository
|
||||
```
|
||||
pip install -e .
|
||||
python llama/shark_model.py
|
||||
python llama/amdshark_model.py
|
||||
```
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_compile import shark_compile_through_fx
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_compile import amdshark_compile_through_fx
|
||||
from MEGABYTE_pytorch import MEGABYTE
|
||||
|
||||
import os
|
||||
@@ -38,10 +38,10 @@ inputs = [torch.randint(0, 16000, (1, 1024, 4))]
|
||||
|
||||
# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
|
||||
# 1. aten.alias
|
||||
shark_module, _ = shark_compile_through_fx(
|
||||
amdshark_module, _ = amdshark_compile_through_fx(
|
||||
model=megaModel,
|
||||
inputs=inputs,
|
||||
extended_model_name="mega_shark",
|
||||
extended_model_name="mega_amdshark",
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
save_dir=os.getcwd(),
|
||||
@@ -59,8 +59,8 @@ def print_output_info(output, msg):
|
||||
print("\n\t", output.shape)
|
||||
|
||||
|
||||
ans = shark_module("forward", inputs)
|
||||
print_output_info(torch.from_numpy(ans), "SHARK's output")
|
||||
ans = amdshark_module("forward", inputs)
|
||||
print_output_info(torch.from_numpy(ans), "AMDSHARK's output")
|
||||
|
||||
ans = megaModel.forward(*inputs)
|
||||
print_output_info(ans, "ORIGINAL Model's output")
|
||||
@@ -68,5 +68,5 @@ print_output_info(ans, "ORIGINAL Model's output")
|
||||
# and sample from the logits accordingly
|
||||
# or you can use the generate function
|
||||
|
||||
# NEED TO LOOK AT THIS LATER IF REQUIRED IN SHARK.
|
||||
# NEED TO LOOK AT THIS LATER IF REQUIRED IN AMDSHARK.
|
||||
# sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
|
||||
amdshark/examples/amdshark_inference/mhlo_example.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from amdshark.amdshark_inference import AMDSharkInference
import numpy as np

mhlo_ir = r"""builtin.module {
  func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
    %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
    %1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
    return %1 : tensor<4x4xf32>
  }
}"""

arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)

print("Running amdshark on cpu backend")
amdshark_module = AMDSharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")

# Generate the random inputs and feed into the graph.
x = amdshark_module.generate_random_inputs()
amdshark_module.compile()
print(amdshark_module.forward(x))

print("Running amdshark on cuda backend")
amdshark_module = AMDSharkInference(mhlo_ir, device="cuda", mlir_dialect="mhlo")
amdshark_module.compile()
print(amdshark_module.forward(x))

print("Running amdshark on vulkan backend")
amdshark_module = AMDSharkInference(mhlo_ir, device="vulkan", mlir_dialect="mhlo")
amdshark_module.compile()
print(amdshark_module.forward(x))
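A quick way to sanity-check the new example above: the script feeds random inputs from `generate_random_inputs()`, but if the hand-built `arg0`/`arg1` were passed instead, the MHLO function would produce a 4x4 tensor filled with 2.0 (ones broadcast-added to ones, then `abs`). A NumPy reference for that assumed case:

```python
import numpy as np

# Mirrors the MHLO body: broadcast_add(ones(1,4), ones(4,1)) -> 4x4 of 2.0; abs() is a no-op here.
expected = np.abs(np.ones((1, 4), np.float32) + np.ones((4, 1), np.float32))
assert expected.shape == (4, 4)
assert np.all(expected == 2.0)
```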
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
|
||||
@@ -23,13 +23,13 @@ class MiniLMSequenceClassification(torch.nn.Module):
|
||||
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
MiniLMSequenceClassification(),
|
||||
(test_input,),
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
|
||||
shark_module.compile()
|
||||
shark_module.forward((test_input,))
|
||||
shark_module.benchmark_all((test_input,))
|
||||
amdshark_module.compile()
|
||||
amdshark_module.forward((test_input,))
|
||||
amdshark_module.benchmark_all((test_input,))
|
||||
@@ -1,6 +1,6 @@
|
||||
import tensorflow as tf
|
||||
from transformers import BertModel, BertTokenizer, TFBertModel
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
@@ -53,9 +53,9 @@ if __name__ == "__main__":
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
BertModule(), test_input, benchmark_mode=True
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
shark_module.benchmark_all(test_input)
|
||||
amdshark_module.set_frontend("tensorflow")
|
||||
amdshark_module.compile()
|
||||
amdshark_module.benchmark_all(test_input)
|
||||
@@ -3,7 +3,7 @@ import torch
|
||||
import jax
|
||||
from typing import Union, Dict, List, Any
|
||||
import numpy as np
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
import io
|
||||
|
||||
NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
|
||||
@@ -60,11 +60,11 @@ jax_model = get_jax_model()
|
||||
mlir = export_jax_to_mlir(jax_model, sample_input)
|
||||
|
||||
# Compile and load module.
|
||||
shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
|
||||
shark_inference.compile()
|
||||
amdshark_inference = AMDSharkInference(mlir_module=mlir, mlir_dialect="mhlo")
|
||||
amdshark_inference.compile()
|
||||
|
||||
# Run main function.
|
||||
result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
|
||||
result = amdshark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
|
||||
|
||||
# Run JAX model.
|
||||
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]
|
||||
@@ -1,6 +1,6 @@
|
||||
flax
|
||||
jax[cpu]
|
||||
nodai-SHARK
|
||||
nodai-AMDSHARK
|
||||
orbax
|
||||
transformers
|
||||
torch
|
||||
amdshark/examples/amdshark_inference/minilm_jit.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from amdshark.amdshark_inference import AMDSharkInference
from amdshark.amdshark_downloader import download_model


mlir_model, func_name, inputs, golden_out = download_model(
    "microsoft/MiniLM-L12-H384-uncased",
    frontend="torch",
)


amdshark_module = AMDSharkInference(mlir_model, device="cpu", mlir_dialect="linalg")
amdshark_module.compile()
result = amdshark_module.forward(inputs)
print("The obtained result via amdshark is: ", result)
print("The golden result is:", golden_out)


# Let's generate random inputs, currently supported
# for static models.
rand_inputs = amdshark_module.generate_random_inputs()
rand_results = amdshark_module.forward(rand_inputs)

print("Running amdshark_module with random_inputs is: ", rand_results)
@@ -1,6 +1,6 @@
|
||||
import tensorflow as tf
|
||||
from transformers import BertModel, BertTokenizer, TFBertModel
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
@@ -48,7 +48,7 @@ if __name__ == "__main__":
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
BertModule(),
|
||||
(
|
||||
encoded_input["input_ids"],
|
||||
@@ -56,11 +56,11 @@ if __name__ == "__main__":
|
||||
encoded_input["token_type_ids"],
|
||||
),
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
amdshark_module.set_frontend("tensorflow")
|
||||
amdshark_module.compile()
|
||||
|
||||
print(
|
||||
shark_module.forward(
|
||||
amdshark_module.forward(
|
||||
(
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
|
||||
torch.hub.list("zhanghang1989/ResNeSt", force_reload=True)
|
||||
|
||||
@@ -21,7 +21,7 @@ class ResnestModule(torch.nn.Module):
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
mlir_importer = AMDSharkImporter(
|
||||
ResnestModule(),
|
||||
(input,),
|
||||
frontend="torch",
|
||||
@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
|
||||
|
||||
print(golden_out)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
amdshark_module = AMDSharkInference(vision_mlir, mlir_dialect="linalg")
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward((input,))
|
||||
print("Obtained result", result)
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.parser import amdshark_args
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -49,23 +49,21 @@ module = torch_mlir.compile(
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cuda", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
amdshark_module = AMDSharkInference(mlir_model, device="cuda", mlir_dialect="linalg")
|
||||
amdshark_module.compile()
|
||||
|
||||
|
||||
def shark_result(x):
|
||||
def amdshark_result(x):
|
||||
x_ny = x.cpu().detach().numpy()
|
||||
inputs = (x_ny,)
|
||||
result = shark_module.forward(inputs)
|
||||
result = amdshark_module.forward(inputs)
|
||||
return torch.from_numpy(result)
|
||||
|
||||
|
||||
observed_out = shark_result(test_input_fp16)
|
||||
observed_out = amdshark_result(test_input_fp16)
|
||||
|
||||
print("Golden result:", actual_out_fp16)
|
||||
print("SHARK result:", observed_out)
|
||||
print("AMDSHARK result:", observed_out)
|
||||
|
||||
actual_out_fp16 = actual_out_fp16.to(device=torch.device("cpu"))
|
||||
|
||||
@@ -4,8 +4,8 @@ import torch
|
||||
import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_downloader import download_model
|
||||
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
@@ -70,13 +70,13 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"resnet50", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(mlir_model, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
path = shark_module.save_module()
|
||||
shark_module.load_module(path)
|
||||
result = shark_module("forward", (img.detach().numpy(),))
|
||||
amdshark_module = AMDSharkInference(mlir_model, mlir_dialect="linalg")
|
||||
amdshark_module.compile()
|
||||
path = amdshark_module.save_module()
|
||||
amdshark_module.load_module(path)
|
||||
result = amdshark_module("forward", (img.detach().numpy(),))
|
||||
|
||||
print("The top 3 results obtained via shark_runner is:")
|
||||
print("The top 3 results obtained via amdshark_runner is:")
|
||||
print(top3_possibilities(torch.from_numpy(result)))
|
||||
|
||||
print()
|
||||
@@ -34,8 +34,8 @@ import subprocess
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_downloader import download_public_file
|
||||
from transformers import (
|
||||
BloomTokenizerFast,
|
||||
BloomForSequenceClassification,
|
||||
@@ -77,13 +77,13 @@ class ShardedBloom:
|
||||
module = f_.read()
|
||||
f_.close()
|
||||
module = bytes(module, "utf-8")
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
module,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
shark_module.save_module(
|
||||
amdshark_module.save_module(
|
||||
module_name=f"{self.src_folder}/{layer_name}",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
@@ -92,14 +92,14 @@ class ShardedBloom:
|
||||
],
|
||||
)
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
"",
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
return shark_module
|
||||
return amdshark_module
|
||||
|
||||
def init_layers(self, device, replace=False, device_idx=[0]):
|
||||
if device_idx is not None:
|
||||
@@ -311,7 +311,7 @@ def _prepare_attn_mask(
|
||||
|
||||
def download_model(destination_folder, model_name):
|
||||
download_public_file(
|
||||
f"gs://shark_tank/sharded_bloom/{model_name}/", destination_folder
|
||||
f"gs://amdshark_tank/sharded_bloom/{model_name}/", destination_folder
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import sys
|
||||
import os
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
|
||||
import re
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
@@ -142,7 +142,7 @@ if __name__ == "__main__":
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
@@ -150,7 +150,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
amdshark_module.save_module(
|
||||
module_name=f"{working_dir}/word_embeddings",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
@@ -159,8 +159,8 @@ if __name__ == "__main__":
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(f"{working_dir}/word_embeddings.vmfb")
|
||||
input_embeds = shark_module(
|
||||
amdshark_module.load_module(f"{working_dir}/word_embeddings.vmfb")
|
||||
input_embeds = amdshark_module(
|
||||
inputs=(input_ids,), function_name="forward"
|
||||
)
|
||||
input_embeds = torch.tensor(input_embeds).float()
|
||||
@@ -175,7 +175,7 @@ if __name__ == "__main__":
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
@@ -183,7 +183,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
amdshark_module.save_module(
|
||||
module_name=f"{working_dir}/word_embeddings_layernorm",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
@@ -192,10 +192,10 @@ if __name__ == "__main__":
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(
|
||||
amdshark_module.load_module(
|
||||
f"{working_dir}/word_embeddings_layernorm.vmfb"
|
||||
)
|
||||
hidden_states = shark_module(
|
||||
hidden_states = amdshark_module(
|
||||
inputs=(input_embeds,), function_name="forward"
|
||||
)
|
||||
hidden_states = torch.tensor(hidden_states).float()
|
||||
@@ -243,7 +243,7 @@ if __name__ == "__main__":
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_str,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
@@ -251,7 +251,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
amdshark_module.save_module(
|
||||
module_name=f"{working_dir}/bloom_block_{layer_name}",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
@@ -260,11 +260,11 @@ if __name__ == "__main__":
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(
|
||||
amdshark_module.load_module(
|
||||
f"{working_dir}/bloom_block_{layer_name}.vmfb"
|
||||
)
|
||||
|
||||
output = shark_module(
|
||||
output = amdshark_module(
|
||||
inputs=(
|
||||
hidden_states.detach().numpy(),
|
||||
alibi.detach().numpy(),
|
||||
@@ -290,7 +290,7 @@ if __name__ == "__main__":
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
@@ -298,7 +298,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
amdshark_module.save_module(
|
||||
module_name=f"{working_dir}/ln_f",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
@@ -307,11 +307,11 @@ if __name__ == "__main__":
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(f"{working_dir}/ln_f.vmfb")
|
||||
amdshark_module.load_module(f"{working_dir}/ln_f.vmfb")
|
||||
|
||||
hidden_states = torch.load(f"{working_dir}/hidden_states_{n_layer}.pt")
|
||||
|
||||
hidden_states = shark_module(
|
||||
hidden_states = amdshark_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
|
||||
@@ -347,7 +347,7 @@ if __name__ == "__main__":
|
||||
logits = lm_head(torch.tensor(hidden_states).float())
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
@@ -355,7 +355,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
amdshark_module.save_module(
|
||||
module_name=f"{working_dir}/lm_head",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
@@ -364,9 +364,9 @@ if __name__ == "__main__":
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(f"{working_dir}/lm_head.vmfb")
|
||||
amdshark_module.load_module(f"{working_dir}/lm_head.vmfb")
|
||||
|
||||
logits = shark_module(
|
||||
logits = amdshark_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
|
||||
@@ -52,8 +52,8 @@ import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
@@ -349,7 +349,7 @@ input_dlrm = (dense_inp, vs0, *vsi)
|
||||
|
||||
golden_output = dlrm_model(dense_inp, vs0, *vsi)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
mlir_importer = AMDSharkImporter(
|
||||
dlrm_model,
|
||||
input_dlrm,
|
||||
frontend="torch",
|
||||
@@ -359,11 +359,11 @@ mlir_importer = SharkImporter(
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
amdshark_module = AMDSharkInference(
|
||||
dlrm_mlir, device="vulkan", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(input_dlrm)
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward(input_dlrm)
|
||||
np.testing.assert_allclose(
|
||||
golden_output.detach().numpy(), result, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
@@ -15,8 +15,8 @@ from torchrec.models.dlrm import (
|
||||
SparseArch,
|
||||
OverArch,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(0)
|
||||
@@ -70,7 +70,7 @@ def to_list(key_jagged, combined_keys):
|
||||
return combined_list
|
||||
|
||||
|
||||
class SparseArchShark(nn.Module):
|
||||
class SparseArchAMDShark(nn.Module):
|
||||
def create_emb(self, embedding_dim, num_embeddings_list):
|
||||
embedding_list = nn.ModuleList()
|
||||
for i in range(0, num_embeddings_list.size):
|
||||
@@ -91,7 +91,7 @@ class SparseArchShark(nn.Module):
|
||||
total_features,
|
||||
num_embeddings_list,
|
||||
):
|
||||
super(SparseArchShark, self).__init__()
|
||||
super(SparseArchAMDShark, self).__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_features = total_features
|
||||
self.embedding_list = self.create_emb(
|
||||
@@ -150,7 +150,7 @@ def test_sparse_arch() -> None:
|
||||
),
|
||||
offsets=offsets,
|
||||
)
|
||||
sparse_archi = SparseArchShark(D, 3, np.array([10, 10]))
|
||||
sparse_archi = SparseArchAMDShark(D, 3, np.array([10, 10]))
|
||||
sparse_archi.embedding_list[0].weight = w1
|
||||
sparse_archi.embedding_list[1].weight = w2
|
||||
inputs = to_list(features, {"f1": 0, "f3": 0, "f2": 1})
|
||||
@@ -169,7 +169,7 @@ def test_sparse_arch() -> None:
|
||||
test_sparse_arch()
|
||||
|
||||
|
||||
class DLRMShark(nn.Module):
|
||||
class DLRMAMDShark(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
@@ -181,7 +181,7 @@ class DLRMShark(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.sparse_arch: SparseArchShark = SparseArchShark(
|
||||
self.sparse_arch: SparseArchAMDShark = SparseArchAMDShark(
|
||||
embedding_dim, total_features, num_embeddings_list
|
||||
)
|
||||
num_sparse_features: int = total_features
|
||||
@@ -250,7 +250,7 @@ def test_dlrm() -> None:
|
||||
dense_arch_layer_sizes=[20, D],
|
||||
over_arch_layer_sizes=[5, 1],
|
||||
)
|
||||
sparse_nn_nod = DLRMShark(
|
||||
sparse_nn_nod = DLRMAMDShark(
|
||||
embedding_dim=8,
|
||||
total_features=3,
|
||||
num_embeddings_list=np.array([100, 100]),
|
||||
@@ -283,7 +283,7 @@ def test_dlrm() -> None:
|
||||
# print(logits_nod)
|
||||
|
||||
# Import the module and print.
|
||||
mlir_importer = SharkImporter(
|
||||
mlir_importer = AMDSharkImporter(
|
||||
sparse_nn_nod,
|
||||
(dense_features, *x),
|
||||
frontend="torch",
|
||||
@@ -293,11 +293,11 @@ def test_dlrm() -> None:
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
amdshark_module = AMDSharkInference(
|
||||
dlrm_mlir, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward(inputs)
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
|
||||
torch.allclose(
|
||||
@@ -3,7 +3,7 @@ import requests
|
||||
|
||||
from transformers import T5Tokenizer, TFT5Model
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
# Create a set of inputs
|
||||
t5_inputs = [
|
||||
@@ -29,7 +29,7 @@ if __name__ == "__main__":
|
||||
text = "I love the distilled version of models."
|
||||
inputs = tokenizer(text, return_tensors="tf").input_ids
|
||||
|
||||
shark_module = SharkInference(T5Module(), (inputs, inputs))
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward((inputs, inputs)))
|
||||
amdshark_module = AMDSharkInference(T5Module(), (inputs, inputs))
|
||||
amdshark_module.set_frontend("tensorflow")
|
||||
amdshark_module.compile()
|
||||
print(amdshark_module.forward((inputs, inputs)))
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
@@ -35,9 +35,9 @@ vision_models_list = [
|
||||
]
|
||||
|
||||
for i, vision_model in enumerate(vision_models_list):
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
VisionModule(vision_model),
|
||||
(input,),
|
||||
)
|
||||
shark_module.compile()
|
||||
shark_module.forward((input,))
|
||||
amdshark_module.compile()
|
||||
amdshark_module.forward((input,))
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
|
||||
|
||||
class UnetModule(torch.nn.Module):
|
||||
@@ -23,7 +23,7 @@ class UnetModule(torch.nn.Module):
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
mlir_importer = AMDSharkImporter(
|
||||
UnetModule(),
|
||||
(input,),
|
||||
frontend="torch",
|
||||
@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
|
||||
tracing_required=False
|
||||
)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
amdshark_module = AMDSharkInference(vision_mlir, mlir_dialect="linalg")
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward((input,))
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
@@ -1,13 +1,13 @@
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from pipeline_shark_stable_diffusion_upscale import (
|
||||
SharkStableDiffusionUpscalePipeline,
|
||||
from pipeline_amdshark_stable_diffusion_upscale import (
|
||||
AMDSharkStableDiffusionUpscalePipeline,
|
||||
)
|
||||
import torch
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipeline = SharkStableDiffusionUpscalePipeline(model_id)
|
||||
pipeline = AMDSharkStableDiffusionUpscalePipeline(model_id)
|
||||
|
||||
# let's download an image
|
||||
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
|
||||
@@ -32,13 +32,13 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
shark_clip = compile_through_fx(
|
||||
amdshark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
model_input["clip"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_clip
|
||||
return amdshark_clip
|
||||
|
||||
|
||||
def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
@@ -55,13 +55,13 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
return x
|
||||
|
||||
vae = VaeModel()
|
||||
shark_vae = compile_through_fx(
|
||||
amdshark_vae = compile_through_fx(
|
||||
vae,
|
||||
model_input["vae"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
return amdshark_vae
|
||||
|
||||
|
||||
def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
@@ -87,7 +87,7 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
|
||||
unet = UnetModel()
|
||||
f16_input_mask = (True, True, True, False)
|
||||
shark_unet = compile_through_fx(
|
||||
amdshark_unet = compile_through_fx(
|
||||
unet,
|
||||
model_input["unet"],
|
||||
model_name=model_name,
|
||||
@@ -95,4 +95,4 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
f16_input_mask=f16_input_mask,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_unet
|
||||
return amdshark_unet
|
||||
@@ -5,7 +5,7 @@ from model_wrappers import (
|
||||
get_clip_mlir,
|
||||
)
|
||||
from upscaler_args import args
|
||||
from utils import get_shark_model
|
||||
from utils import get_amdshark_model
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
if BATCH_SIZE != 1:
|
||||
@@ -24,25 +24,25 @@ clip_flag = [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
|
||||
bucket = "gs://shark_tank/stable_diffusion/"
|
||||
bucket = "gs://amdshark_tank/stable_diffusion/"
|
||||
|
||||
|
||||
def get_unet():
|
||||
model_name = "upscaler_unet"
|
||||
if args.import_mlir:
|
||||
return get_unet_mlir(model_name, unet_flag)
|
||||
return get_shark_model(bucket, model_name, unet_flag)
|
||||
return get_amdshark_model(bucket, model_name, unet_flag)
|
||||
|
||||
|
||||
def get_vae():
|
||||
model_name = "upscaler_vae"
|
||||
if args.import_mlir:
|
||||
return get_vae_mlir(model_name, vae_flag)
|
||||
return get_shark_model(bucket, model_name, vae_flag)
|
||||
return get_amdshark_model(bucket, model_name, vae_flag)
|
||||
|
||||
|
||||
def get_clip():
|
||||
model_name = "upscaler_clip"
|
||||
if args.import_mlir:
|
||||
return get_clip_mlir(model_name, clip_flag)
|
||||
return get_shark_model(bucket, model_name, clip_flag)
|
||||
return get_amdshark_model(bucket, model_name, clip_flag)
|
||||
@@ -46,13 +46,13 @@ def preprocess(image):
|
||||
return image
|
||||
|
||||
|
||||
def shark_run_wrapper(model, *args):
|
||||
def amdshark_run_wrapper(model, *args):
|
||||
np_inputs = tuple([x.detach().numpy() for x in args])
|
||||
outputs = model("forward", np_inputs)
|
||||
return torch.from_numpy(outputs)
|
||||
|
||||
|
||||
class SharkStableDiffusionUpscalePipeline:
|
||||
class AMDSharkStableDiffusionUpscalePipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_id,
|
||||
@@ -131,7 +131,7 @@ class SharkStableDiffusionUpscalePipeline:
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
text_embeddings = shark_run_wrapper(
|
||||
text_embeddings = amdshark_run_wrapper(
|
||||
self.text_encoder, text_input_ids.to(device)
|
||||
)
|
||||
|
||||
@@ -180,7 +180,7 @@ class SharkStableDiffusionUpscalePipeline:
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
uncond_embeddings = shark_run_wrapper(
|
||||
uncond_embeddings = amdshark_run_wrapper(
|
||||
self.text_encoder,
|
||||
uncond_input.input_ids.to(device),
|
||||
)
|
||||
@@ -227,7 +227,7 @@ class SharkStableDiffusionUpscalePipeline:
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.08333 * latents
|
||||
image = shark_run_wrapper(self.vae, latents)
|
||||
image = amdshark_run_wrapper(self.vae, latents)
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
@@ -445,7 +445,7 @@ class SharkStableDiffusionUpscalePipeline:
|
||||
timestep = torch.tensor([t]).to(torch.float32)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = shark_run_wrapper(
|
||||
noise_pred = amdshark_run_wrapper(
|
||||
self.unet,
|
||||
latent_model_input.half(),
|
||||
timestep,
|
||||
@@ -59,7 +59,7 @@ p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
|
||||
help="imports the model from torch module to amdshark_module otherwise downloads the model from amdshark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -94,18 +94,5 @@ p.add_argument(
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="4147483648",
|
||||
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for disabling vulkan validation layers when benchmarking",
|
||||
)
|
||||
|
||||
|
||||
args = p.parse_args()
|
||||
@@ -1,15 +1,16 @@
|
||||
import os
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from upscaler_args import args
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
from amdshark.amdshark_importer import import_with_fx
|
||||
from amdshark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
def _compile_module(amdshark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
device = (
|
||||
args.device
|
||||
@@ -20,7 +21,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
amdshark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
else:
|
||||
if args.save_vmfb:
|
||||
print("Saving to {}".format(vmfb_path))
|
||||
@@ -30,55 +31,52 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
vmfb_path
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
path = amdshark_module.save_module(
|
||||
os.getcwd(), extended_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
amdshark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
amdshark_module.compile(extra_args)
|
||||
return amdshark_module
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.parser import shark_args
|
||||
# Downloads the model from amdshark_tank and returns the amdshark_module.
|
||||
def get_amdshark_model(tank_url, model_name, extra_args=[]):
|
||||
from amdshark.amdshark_downloader import download_model
|
||||
from amdshark.parser import amdshark_args
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
# shark_args.local_tank_cache = args.local_tank_cache
|
||||
# Set local amdshark_tank cache directory.
|
||||
# amdshark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=tank_url,
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
return _compile_module(amdshark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into a shark_module.
|
||||
# Converts the torch-module into a amdshark_module.
|
||||
def compile_through_fx(
|
||||
model, inputs, model_name, is_f16=False, f16_input_mask=None, extra_args=[]
|
||||
):
|
||||
mlir_module, func_name = import_with_fx(
|
||||
model, inputs, is_f16, f16_input_mask
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
return _compile_module(amdshark_module, model_name, extra_args)
|
||||
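For reference, a hypothetical caller of the renamed `compile_through_fx` helper might look like the sketch below. The model and input shape are made up, and the call pattern of passing a function name plus NumPy inputs mirrors `amdshark_run_wrapper` earlier in this diff.

```python
import torch

# Hypothetical module and inputs, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
inputs = (torch.randn(1, 64),)

compiled = compile_through_fx(model, inputs, model_name="toy_model")
# Compiled AMDShark modules are invoked with a function name and NumPy inputs.
out = compiled("forward", tuple(x.detach().numpy() for x in inputs))
```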
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
|
||||
]
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
@@ -114,7 +112,7 @@ def get_device_mapping(driver, key_combination=3):
|
||||
Returns:
|
||||
dict: map to possible device names user can input mapped to desired combination of name/path.
|
||||
"""
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
from amdshark.iree_utils._common import iree_device_map
|
||||
|
||||
driver = iree_device_map(driver)
|
||||
device_list = get_all_devices(driver)
|
||||
@@ -207,7 +205,7 @@ def set_init_device_flags():
|
||||
# Utility to get list of devices available.
|
||||
def get_available_devices():
|
||||
def get_devices_by_name(driver_name):
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
from amdshark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
amdshark/examples/amdshark_inference/v_diffusion.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from amdshark.amdshark_inference import AMDSharkInference
from amdshark.amdshark_downloader import download_model


mlir_model, func_name, inputs, golden_out = download_model(
    "v_diffusion", frontend="torch"
)

amdshark_module = AMDSharkInference(
    mlir_model, device="vulkan", mlir_dialect="linalg"
)
amdshark_module.compile()
result = amdshark_module.forward(inputs)
print("The obtained result via amdshark is: ", result)
print("The golden result is:", golden_out)
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch.nn.utils import _stateless
|
||||
from torch.nn.utils import stateless
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
from amdshark.amdshark_trainer import AMDSharkTrainer
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
@@ -33,7 +33,7 @@ inp = (torch.randint(2, (1, 128)),)
|
||||
|
||||
def forward(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
_stateless.functional_call(
|
||||
stateless.functional_call(
|
||||
mod, params_and_buffers, args, {}
|
||||
).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
@@ -42,7 +42,7 @@ def forward(params, buffers, args):
|
||||
return params, buffers
|
||||
|
||||
|
||||
shark_module = SharkTrainer(mod, inp)
|
||||
shark_module.compile(forward)
|
||||
|
||||
print(shark_module.train())
|
||||
amdshark_module = AMDSharkTrainer(mod, inp)
|
||||
amdshark_module.compile(forward)
|
||||
amdshark_module.train(num_iters=2)
|
||||
print("training done")
|
||||
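The import change at the top of this training example (the private `_stateless` module replaced by the public `torch.nn.utils.stateless` API) only changes where `functional_call` comes from. A minimal standalone sketch of that API, assuming a PyTorch version that ships it:

```python
import torch
from torch.nn.utils import stateless

# Run a module with an explicit parameter/buffer dictionary instead of its own state.
lin = torch.nn.Linear(4, 2)
params_and_buffers = {**dict(lin.named_parameters()), **dict(lin.named_buffers())}
out = stateless.functional_call(lin, params_and_buffers, (torch.ones(1, 4),))
print(out.shape)  # torch.Size([1, 2])
```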
@@ -3,8 +3,8 @@ import os
|
||||
import time
|
||||
import tensorflow as tf
|
||||
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
from shark.parser import parser
|
||||
from amdshark.amdshark_trainer import AMDSharkTrainer
|
||||
from amdshark.parser import parser
|
||||
from urllib import request
|
||||
|
||||
parser.add_argument(
|
||||
@@ -28,7 +28,7 @@ if __name__ == "__main__":
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
]
|
||||
file_link = "https://storage.googleapis.com/shark_tank/users/stanley/bert_tf_training.mlir"
|
||||
file_link = "https://storage.googleapis.com/amdshark_tank/users/stanley/bert_tf_training.mlir"
|
||||
response = request.urlretrieve(file_link, load_args.download_mlir_path)
|
||||
sample_input_tensors = [
|
||||
tf.convert_to_tensor(val, dtype=tf.int32)
|
||||
@@ -41,7 +41,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
with open(load_args.download_mlir_path, "rb") as input_file:
|
||||
bert_mlir = input_file.read()
|
||||
shark_module = SharkTrainer(
|
||||
amdshark_module = AMDSharkTrainer(
|
||||
bert_mlir,
|
||||
(
|
||||
sample_input_tensors,
|
||||
@@ -50,10 +50,10 @@ if __name__ == "__main__":
|
||||
),
|
||||
),
|
||||
)
|
||||
shark_module.set_frontend("mhlo")
|
||||
shark_module.compile()
|
||||
amdshark_module.set_frontend("mhlo")
|
||||
amdshark_module.compile()
|
||||
start = time.time()
|
||||
print(shark_module.train(num_iter))
|
||||
print(amdshark_module.train(num_iter))
|
||||
end = time.time()
|
||||
total_time = end - start
|
||||
print("time: " + str(total_time))
|
||||
@@ -8,7 +8,7 @@ from official.nlp.modeling import layers
|
||||
from official.nlp.modeling import networks
|
||||
from official.nlp.modeling.models import bert_classifier
|
||||
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
from amdshark.amdshark_trainer import AMDSharkTrainer
|
||||
|
||||
|
||||
tf.random.set_seed(0)
|
||||
@@ -79,7 +79,7 @@ if __name__ == "__main__":
|
||||
for val in predict_sample_input
|
||||
]
|
||||
num_iter = 10
|
||||
shark_module = SharkTrainer(
|
||||
amdshark_module = AMDSharkTrainer(
|
||||
BertModule(),
|
||||
(
|
||||
sample_input_tensors,
|
||||
@@ -88,10 +88,10 @@ if __name__ == "__main__":
|
||||
),
|
||||
),
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
amdshark_module.set_frontend("tensorflow")
|
||||
amdshark_module.compile()
|
||||
start = time.time()
|
||||
print(shark_module.train(num_iter))
|
||||
print(amdshark_module.train(num_iter))
|
||||
end = time.time()
|
||||
total_time = end - start
|
||||
print("time: " + str(total_time))
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from torch.nn.utils import _stateless
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
from amdshark.amdshark_trainer import AMDSharkTrainer
|
||||
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
@@ -37,8 +37,8 @@ def forward(params, buffers, args):
|
||||
|
||||
# fx_graph = forward(dict(mod.named_parameters()), dict(mod.named_buffers()), inp)
|
||||
|
||||
shark_module = SharkTrainer(mod, inp)
|
||||
amdshark_module = AMDSharkTrainer(mod, inp)
|
||||
# Pass the training function in case of torch
|
||||
shark_module.compile(training_fn=forward)
|
||||
amdshark_module.compile(training_fn=forward)
|
||||
|
||||
shark_module.train(num_iters=10)
|
||||
amdshark_module.train(num_iters=10)
|
||||
@@ -5,10 +5,10 @@
|
||||
<details>
|
||||
<summary>Installation (Linux)</summary>
|
||||
|
||||
### Activate shark.venv Virtual Environment
|
||||
### Activate amdshark.venv Virtual Environment
|
||||
|
||||
```shell
|
||||
source shark.venv/bin/activate
|
||||
source amdshark.venv/bin/activate
|
||||
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
python -m pip install --upgrade pip
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
## Installation (Linux)
|
||||
|
||||
### Activate shark.venv Virtual Environment
|
||||
### Activate amdshark.venv Virtual Environment
|
||||
|
||||
```shell
|
||||
source shark.venv/bin/activate
|
||||
source amdshark.venv/bin/activate
|
||||
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
python -m pip install --upgrade pip
|
||||
@@ -23,7 +23,7 @@ pip install accelerate transformers ftfy
|
||||
|
||||
Please cherry-pick this branch of torch-mlir: https://github.com/vivekkhandelwal1/torch-mlir/tree/sd-ops
|
||||
and build it locally. You can find the instructions for using a locally built Torch-MLIR,
|
||||
here: https://github.com/nod-ai/SHARK#how-to-use-your-locally-built-iree--torch-mlir-with-shark
|
||||
here: https://github.com/nod-ai/AMD-SHARK-Studio#how-to-use-your-locally-built-iree--torch-mlir-with-amdshark
|
||||
|
||||
## Run the Stable diffusion fine tuning
|
||||
|
||||
@@ -24,7 +24,7 @@ from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
import torch._dynamo as dynamo
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
torch._dynamo.config.verbose = True
|
||||
|
||||
@@ -476,8 +476,8 @@ class UnetModel(torch.nn.Module):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
|
||||
shark_vae = VaeModel()
|
||||
shark_unet = UnetModel()
|
||||
amdshark_vae = VaeModel()
|
||||
amdshark_unet = UnetModel()
|
||||
|
||||
####### Creating our training data ########
|
||||
|
||||
@@ -638,14 +638,14 @@ def refbackend_torchdynamo_backend(
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
amdshark_module.compile()
|
||||
|
||||
def compiled_callable(*inputs):
|
||||
inputs = [x.numpy() for x in inputs]
|
||||
result = shark_module("forward", inputs)
|
||||
result = amdshark_module("forward", inputs)
|
||||
if was_unwrapped:
|
||||
result = [
|
||||
result,
|
||||
@@ -709,7 +709,7 @@ optimizer = torch.optim.AdamW(
|
||||
# Training function
|
||||
def train_func(batch_pixel_values, batch_input_ids):
|
||||
# Convert images to latent space
|
||||
latents = shark_vae(batch_pixel_values).sample().detach()
|
||||
latents = amdshark_vae(batch_pixel_values).sample().detach()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
@@ -731,7 +731,7 @@ def train_func(batch_pixel_values, batch_input_ids):
|
||||
encoder_hidden_states = text_encoder(batch_input_ids)[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = shark_unet(
|
||||
noise_pred = amdshark_unet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
@@ -31,7 +31,7 @@ from torch_mlir_e2e_test.eager_backends.refbackend import (
|
||||
NUMPY_TO_TORCH_DTYPE_DICT,
|
||||
)
|
||||
|
||||
from shark.iree_utils.compile_utils import (
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
get_iree_compiled_module,
|
||||
IREE_DEVICE_MAP,
|
||||
)
|
||||
@@ -13,15 +13,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
## Common utilities to be shared by iree utilities.
|
||||
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
|
||||
def run_cmd(cmd, debug=False):
|
||||
def run_cmd(cmd, debug=False, raise_err=False):
|
||||
"""
|
||||
Inputs: cli command string.
|
||||
Inputs:
|
||||
cmd : cli command string.
|
||||
debug : if True, prints debug info
|
||||
raise_err : if True, raise exception to caller
|
||||
"""
|
||||
if debug:
|
||||
print("IREE run command: \n\n")
|
||||
@@ -39,8 +42,11 @@ def run_cmd(cmd, debug=False):
|
||||
stderr = result.stderr.decode()
|
||||
return stdout, stderr
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output)
|
||||
sys.exit(f"Exiting program due to error running {cmd}")
|
||||
if raise_err:
|
||||
raise Exception from e
|
||||
else:
|
||||
print(e.output)
|
||||
sys.exit(f"Exiting program due to error running {cmd}")
|
||||
|
||||
|
||||
def iree_device_map(device):
|
||||
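The `raise_err` path added above lets callers recover from a failed subprocess instead of having the whole program exit. A minimal usage sketch, assuming `run_cmd()` as defined in `amdshark/iree_utils/_common.py` (the fallback behaviour is illustrative, not part of this diff):

```python
from amdshark.iree_utils._common import run_cmd


def try_dump_devices(driver="rocm"):
    try:
        stdout, _stderr = run_cmd(
            f"iree-run-module --dump_devices={driver}", raise_err=True
        )
        return stdout
    except Exception as err:
        # With raise_err=True the error propagates here instead of sys.exit(),
        # so the caller can fall back gracefully.
        print(f"device dump failed: {err}")
        return None
```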
@@ -52,6 +58,8 @@ def iree_device_map(device):
|
||||
)
|
||||
if len(uri_parts) == 1:
|
||||
return iree_driver
|
||||
elif "rocm" in uri_parts:
|
||||
return "rocm"
|
||||
else:
|
||||
return f"{iree_driver}://{uri_parts[1]}"
|
||||
|
||||
@@ -63,12 +71,12 @@ def get_supported_device_list():
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"AMD-AIE": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"hip": "hip",
|
||||
"intel-gpu": "level_zero",
|
||||
}
|
||||
|
||||
@@ -82,47 +90,42 @@ def iree_target_map(device):
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cpu-task": "llvm-cpu",
|
||||
"AMD-AIE": "llvm-cpu",
|
||||
"cpu-sync": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"vulkan": "vulkan-spirv",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"hip": "rocm",
|
||||
"intel-gpu": "opencl-spirv",
|
||||
}
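A quick illustration of how the two lookup tables above behave after this change (values taken directly from the maps shown here):

```python
from amdshark.iree_utils._common import iree_device_map, iree_target_map

iree_device_map("cpu")         # -> "local-task"   (runtime HAL driver)
iree_device_map("vulkan://1")  # -> "vulkan://1"   (driver plus device index)
iree_device_map("rocm://0")    # -> "rocm"         (rocm drops the index here)
iree_target_map("vulkan")      # -> "vulkan-spirv" (iree-compile target backend)
iree_target_map("hip")         # -> "rocm"
```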
|
||||
|
||||
# Finds whether the required drivers are installed for the given device.
|
||||
@functools.cache
|
||||
def check_device_drivers(device):
|
||||
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||
"""
|
||||
Checks necessary drivers present for gpu and vulkan devices
|
||||
False => drivers present!
|
||||
"""
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
|
||||
if device == "cuda":
|
||||
try:
|
||||
subprocess.check_output("nvidia-smi")
|
||||
except Exception:
|
||||
return True
|
||||
elif device in ["vulkan"]:
|
||||
try:
|
||||
subprocess.check_output("vulkaninfo")
|
||||
except Exception:
|
||||
return True
|
||||
elif device == "metal":
|
||||
return False
|
||||
elif device in ["intel-gpu"]:
|
||||
try:
|
||||
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
|
||||
return False
|
||||
except Exception:
|
||||
return True
|
||||
elif device == "cpu":
|
||||
return False
|
||||
elif device == "rocm":
|
||||
try:
|
||||
subprocess.check_output("rocminfo")
|
||||
except Exception:
|
||||
return True
|
||||
from iree.runtime import get_driver
|
||||
|
||||
device_mapped = iree_device_map(device)
|
||||
|
||||
try:
|
||||
_ = get_driver(device_mapped)
|
||||
except ValueError as ve:
|
||||
print(
|
||||
f"[ERR] device `{device}` not registered with IREE. "
|
||||
"Ensure IREE is configured for use with this device.\n"
|
||||
f"Full Error: \n {repr(ve)}"
|
||||
)
|
||||
return True
|
||||
except RuntimeError as re:
|
||||
print(f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}")
|
||||
return True
|
||||
|
||||
# Unknown device. We assume drivers are installed.
|
||||
return False
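With the new driver-registry check, a caller can probe a device and fall back to the install hint defined below; a small sketch (return-value convention per the docstring: `False` means the drivers are present):

```python
from amdshark.iree_utils._common import check_device_drivers, device_driver_info

device = "rocm"
if check_device_drivers(device):
    # True => IREE could not create a driver for this device.
    print(device_driver_info(device))
else:
    print(f"{device} drivers are available")
```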
@@ -130,11 +133,32 @@ def check_device_drivers(device):
|
||||
|
||||
# Installation info for the missing device drivers.
|
||||
def device_driver_info(device):
|
||||
if device == "cuda":
|
||||
return "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
|
||||
elif device in ["metal", "vulkan"]:
|
||||
return "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
|
||||
elif device == "rocm":
|
||||
return "rocm info not found. Please install rocm"
|
||||
device_driver_err_map = {
|
||||
"cuda": {
|
||||
"debug": "Try `nvidia-smi` on system to check.",
|
||||
"solution": " from https://www.nvidia.in/Download/index.aspx?lang=en-in for your system.",
|
||||
},
|
||||
"vulkan": {
|
||||
"debug": "Try `vulkaninfo` on system to check.",
|
||||
"solution": " from https://vulkan.lunarg.com/sdk/home for your distribution.",
|
||||
},
|
||||
"metal": {
|
||||
"debug": "Check if Bare metal is supported and enabled on your system.",
|
||||
"solution": ".",
|
||||
},
|
||||
"rocm": {
|
||||
"debug": f"Try `{'hip' if sys.platform == 'win32' else 'rocm'}info` on system to check.",
|
||||
"solution": " from https://rocm.docs.amd.com/en/latest/rocm.html for your system.",
|
||||
},
|
||||
}
|
||||
|
||||
if device in device_driver_err_map:
|
||||
err_msg = (
|
||||
f"Required drivers for {device} not found. {device_driver_err_map[device]['debug']} "
|
||||
f"Please install the required drivers{device_driver_err_map[device]['solution']} "
|
||||
f"For further assistance please reach out to the community on discord [https://discord.com/invite/RUqY2h2s9u]"
|
||||
f" and/or file a bug at https://github.com/nod-ai/AMD-SHARK-Studio/issues"
|
||||
)
|
||||
return err_msg
|
||||
else:
|
||||
return f"{device} is not supported."
|
||||
@@ -12,9 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
|
||||
from shark.iree_utils._common import run_cmd, iree_device_map
|
||||
from shark.iree_utils.cpu_utils import get_cpu_count
|
||||
from amdshark.iree_utils._common import run_cmd, iree_device_map
|
||||
from amdshark.iree_utils.cpu_utils import get_cpu_count
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
@@ -62,16 +61,12 @@ def build_benchmark_args(
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
|
||||
if platform.system() == "Windows":
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module.exe"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
|
||||
time_extractor = None
|
||||
else:
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
|
||||
# TODO: The function name can be passed as one of the args.
|
||||
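For orientation, this is roughly the path the benchmarker now resolves to, assuming the app runs inside an activated virtual environment (the venv name is only an example):

```python
import os
import platform

path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
exe = (
    "iree-benchmark-module.exe"
    if platform.system() == "Windows"
    else "iree-benchmark-module"
)
benchmarker_path = os.path.join(path, exe)
# e.g. amdshark.venv/bin/iree-benchmark-module on Linux
```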
@@ -106,15 +101,13 @@ def build_benchmark_args_non_tensor_input(
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
|
||||
if platform.system() == "Windows":
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module.exe"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
|
||||
time_extractor = None
|
||||
else:
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
|
||||
# TODO: The function name can be passed as one of the args.
|
||||
if function_name:
|
||||
@@ -139,7 +132,7 @@ def run_benchmark_module(benchmark_cl):
|
||||
benchmark_path = benchmark_cl[0]
|
||||
assert os.path.exists(
|
||||
benchmark_path
|
||||
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
|
||||
), "Cannot find iree_benchmark_module, Please contact AMDSHARK maintainer on discord."
|
||||
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
|
||||
try:
|
||||
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
|
||||
@@ -11,68 +11,93 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.iree_utils._common import iree_device_map, iree_target_map
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
|
||||
from shark.iree_utils.benchmark_utils import *
|
||||
from shark.parser import shark_args
|
||||
import functools
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from amdshark.parser import amdshark_args
|
||||
|
||||
from .trace import DetailLogger
|
||||
from ._common import iree_device_map, iree_target_map
|
||||
from .cpu_utils import get_iree_cpu_rt_args
|
||||
from .benchmark_utils import *
|
||||
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
print("Configuring for device:" + device)
|
||||
device_uri = device.split("://")
|
||||
if len(device_uri) > 1:
|
||||
if device_uri[0] not in ["vulkan"]:
|
||||
print(
|
||||
f"Specific device selection only supported for vulkan now."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
device_num = device_uri[1]
|
||||
else:
|
||||
device_num = 0
|
||||
device, device_num = clean_device_info(device)
|
||||
|
||||
if device_uri[0] == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
if "cpu" in device:
|
||||
from amdshark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
|
||||
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
|
||||
u_kernel_flag = ["--iree-llvmcpu-enable-ukernels"]
|
||||
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
|
||||
|
||||
return (
|
||||
get_iree_cpu_args()
|
||||
+ data_tiling_flag
|
||||
+ u_kernel_flag
|
||||
+ stack_size_flag
|
||||
)
|
||||
if device_uri[0] == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
if device == "cuda":
|
||||
from amdshark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
return get_iree_gpu_args()
|
||||
if device_uri[0] == "vulkan":
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
if device == "vulkan":
|
||||
from amdshark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device_uri[0] == "metal":
|
||||
from shark.iree_utils.metal_utils import get_iree_metal_args
|
||||
if device == "metal":
|
||||
from amdshark.iree_utils.metal_utils import get_iree_metal_args
|
||||
|
||||
return get_iree_metal_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
return get_iree_metal_args(extra_args=extra_args)
|
||||
if device == "rocm":
|
||||
from amdshark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args()
|
||||
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
|
||||
if device == "hip":
|
||||
from amdshark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args, hip_driver=True)
|
||||
return []
|
||||
|
||||
def get_iree_target_triple(device):
|
||||
args = get_iree_device_args(device)
|
||||
for flag in args:
|
||||
if "triple" in flag:
|
||||
triple = flag.split("=")[-1]
|
||||
return triple
|
||||
return ""
|
||||
|
||||
|
||||
def clean_device_info(raw_device):
|
||||
# return appropriate device and device_id for consumption by Studio pipeline
|
||||
# Multiple devices only supported for vulkan and rocm (as of now).
|
||||
# default device must be selected for all others
|
||||
|
||||
device_id = None
|
||||
device = (
|
||||
raw_device
|
||||
if "=>" not in raw_device
|
||||
else raw_device.split("=>")[1].strip()
|
||||
)
|
||||
if "://" in device:
|
||||
device, device_id = device.split("://")
|
||||
if len(device_id) <= 2:
|
||||
device_id = int(device_id)
|
||||
|
||||
if device not in ["hip", "rocm", "vulkan"]:
|
||||
device_id = None
|
||||
if device in ["hip", "rocm", "vulkan"] and device_id == None:
|
||||
device_id = 0
|
||||
return device, device_id
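A few worked examples of `clean_device_info()` as defined above (the raw device strings are illustrative):

```python
clean_device_info("vulkan://1")                # -> ("vulkan", 1)
clean_device_info("cuda")                      # -> ("cuda", None)
clean_device_info("rocm")                      # -> ("rocm", 0)
clean_device_info("AMD Radeon => vulkan://0")  # -> ("vulkan", 0)
```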
|
||||
|
||||
# Get the iree-compiler arguments given frontend.
|
||||
def get_iree_frontend_args(frontend):
|
||||
@@ -81,7 +106,7 @@ def get_iree_frontend_args(frontend):
|
||||
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
return [
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
"--iree-input-demote-i64-to-i32",
|
||||
]
|
||||
else:
|
||||
# Frontend not found.
|
||||
@@ -89,29 +114,42 @@ def get_iree_frontend_args(frontend):
|
||||
|
||||
|
||||
# Common args to be used given any frontend or device.
|
||||
def get_iree_common_args():
|
||||
return [
|
||||
"--iree-stream-resource-index-bits=64",
|
||||
"--iree-vm-target-index-bits=64",
|
||||
"--iree-vm-bytecode-module-strip-source-map=true",
|
||||
def get_iree_common_args(debug=False):
|
||||
common_args = [
|
||||
"--iree-util-zero-fill-elided-attrs",
|
||||
"--mlir-elide-elementsattrs-if-larger=10",
|
||||
]
|
||||
if debug == True:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=false",
|
||||
"--verify=true",
|
||||
]
|
||||
)
|
||||
else:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=true",
|
||||
"--verify=false",
|
||||
]
|
||||
)
|
||||
return common_args
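The debug toggle above swaps assertion stripping and verification; illustrated with the flag lists the function returns:

```python
get_iree_common_args(debug=False)
# ['--iree-util-zero-fill-elided-attrs',
#  '--mlir-elide-elementsattrs-if-larger=10',
#  '--iree-opt-strip-assertions=true',
#  '--verify=false']

get_iree_common_args(debug=True)
# same first two flags, then:
# '--iree-opt-strip-assertions=false', '--verify=true'
```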
|
||||
|
||||
# Args that are suitable only for certain models or groups of models.
|
||||
# shark_args are passed down from pytests to control which models compile with these flags,
|
||||
# but they can also be set in shark/parser.py
|
||||
# amdshark_args are passed down from pytests to control which models compile with these flags,
|
||||
# but they can also be set in amdshark/parser.py
|
||||
def get_model_specific_args():
|
||||
ms_args = []
|
||||
if shark_args.enable_conv_transform == True:
|
||||
if amdshark_args.enable_conv_transform == True:
|
||||
ms_args += [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))"
|
||||
]
|
||||
if shark_args.enable_img2col_transform == True:
|
||||
if amdshark_args.enable_img2col_transform == True:
|
||||
ms_args += [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col))"
|
||||
]
|
||||
if shark_args.use_winograd == True:
|
||||
if amdshark_args.use_winograd == True:
|
||||
ms_args += [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
@@ -224,7 +262,7 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
|
||||
benchmark_file.write(f"DISPATCH: {d_}\n")
|
||||
benchmark_file.write(str(iter_per_second) + "\n")
|
||||
benchmark_file.write(
|
||||
"SHARK BENCHMARK RESULT: "
|
||||
"AMDSHARK BENCHMARK RESULT: "
|
||||
+ str(1 / (iter_per_second * 0.001))
|
||||
+ "\n"
|
||||
)
|
||||
@@ -274,14 +312,18 @@ def compile_module_to_flatbuffer(
|
||||
model_config_path,
|
||||
extra_args,
|
||||
model_name="None",
|
||||
debug=False,
|
||||
compile_str=False,
|
||||
write_to=None,
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
input_type = "auto"
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args()
|
||||
args += get_iree_common_args(debug=debug)
|
||||
args += get_model_specific_args()
|
||||
args += extra_args
|
||||
args += amdshark_args.additional_compile_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "auto"
|
||||
@@ -291,11 +333,10 @@ def compile_module_to_flatbuffer(
|
||||
input_type = "tosa"
|
||||
elif frontend in ["tm_tensor"]:
|
||||
input_type = ireec.InputType.TM_TENSOR
|
||||
elif frontend in ["torch", "pytorch"]:
|
||||
input_type = "torch"
|
||||
|
||||
# TODO: make it simpler.
|
||||
# Compile according to the input type, else just try compiling.
|
||||
if input_type != "":
|
||||
# Currently for MHLO/TOSA.
|
||||
if compile_str:
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[iree_target_map(device)],
|
||||
@@ -303,94 +344,161 @@ def compile_module_to_flatbuffer(
|
||||
input_type=input_type,
|
||||
)
|
||||
else:
|
||||
# Currently for Torch.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
assert os.path.isfile(module)
|
||||
flatbuffer_blob = ireec.compile_file(
|
||||
str(module),
|
||||
input_type=input_type,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=args,
|
||||
)
|
||||
|
||||
if write_to is not None:
|
||||
with open(write_to, "wb") as f:
|
||||
f.write(flatbuffer_blob)
|
||||
return None
|
||||
|
||||
return flatbuffer_blob
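A hedged sketch of invoking the updated compile path; the `.mlir` input and the output name are placeholders, not files from this repository:

```python
compile_module_to_flatbuffer(
    module="unet.mlir",          # compile_str=False -> treated as a file path
    device="rocm",
    frontend="torch",
    model_config_path=None,
    extra_args=[],
    debug=False,
    compile_str=False,
    write_to="unet_rocm.vmfb",   # writes the flatbuffer to disk, returns None
)
```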
|
||||
|
||||
def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
def get_iree_module(
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=None,
|
||||
rt_flags: list = [],
|
||||
external_weight_file=None,
|
||||
):
|
||||
if external_weight_file is not None:
|
||||
index = ireert.ParameterIndex()
|
||||
index.load(external_weight_file)
|
||||
# Returns the compiled module and the configs.
|
||||
for flag in rt_flags:
|
||||
ireert.flags.parse_flag(flag)
|
||||
if device_idx is not None:
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
haldriver = ireert.get_driver(device)
|
||||
|
||||
hal_device_id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
allocators=shark_args.device_allocator,
|
||||
hal_device_id,
|
||||
allocators=amdshark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = hal_device_id
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_buffer(
|
||||
config.vm_instance, flatbuffer_blob, warn_if_copy=False
|
||||
)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
modules = []
|
||||
if external_weight_file is not None:
|
||||
modules.append(index.create_provider(scope="model"))
|
||||
ctx = ireert.SystemContext(vm_modules=modules, config=config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
ModuleCompiled = getattr(ctx.modules, vm_module.name)
|
||||
return ModuleCompiled, config
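And a sketch of loading a compiled flatbuffer with externalized weights; here `flatbuffer_blob` is the bytes returned by `compile_module_to_flatbuffer()` when `write_to` is not set, and the parameter file name is a placeholder for whatever `ireert.ParameterIndex.load()` accepts:

```python
vm_module, config = get_iree_module(
    flatbuffer_blob,
    device="vulkan",
    device_idx=0,                               # first enumerated Vulkan device
    rt_flags=[],
    external_weight_file="model_params.irpa",   # placeholder parameter archive
)
```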
|
||||
|
||||
def load_vmfb_using_mmap(
|
||||
flatbuffer_blob_or_path, device: str, device_idx: int = None
|
||||
flatbuffer_blob_or_path,
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
rt_flags: list = [],
|
||||
external_weight_file: str = None,
|
||||
):
|
||||
instance = ireert.VmInstance()
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
allocators=[],
|
||||
)
|
||||
# First get configs.
|
||||
if device_idx is not None:
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
haldriver = ireert.get_driver(device)
|
||||
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
print(f"Loading module {flatbuffer_blob_or_path}...")
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
# (This would arise if we're invoking `load_module` from a SharkInference obj)
|
||||
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
|
||||
# (This would arise if we're invoking `compile` from a SharkInference obj)
|
||||
temp_file_to_unlink = None
|
||||
if isinstance(flatbuffer_blob_or_path, Path):
|
||||
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
|
||||
if (
|
||||
isinstance(flatbuffer_blob_or_path, str)
|
||||
and ".vmfb" in flatbuffer_blob_or_path
|
||||
):
|
||||
vmfb_file_path = flatbuffer_blob_or_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(mmaped_vmfb)
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||
tf.write(flatbuffer_blob_or_path)
|
||||
tf.flush()
|
||||
vmfb_file_path = tf.name
|
||||
temp_file_to_unlink = vmfb_file_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
|
||||
return mmaped_vmfb, config, temp_file_to_unlink
|
||||
rt_flags.append(flag)
|
||||
for flag in rt_flags:
|
||||
print(flag)
|
||||
ireert.flags.parse_flags(flag)
|
||||
|
||||
if "rocm" in device:
|
||||
device = "rocm"
|
||||
with DetailLogger(timeout=2.5) as dl:
|
||||
# First get configs.
|
||||
if device_idx is not None:
|
||||
dl.log(f"Mapping device id: {device_idx}")
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
dl.log(f"ireert.get_driver()")
|
||||
|
||||
hal_device_id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
haldevice = haldriver.create_device(
|
||||
hal_device_id,
|
||||
allocators=amdshark_args.device_allocator,
|
||||
)
|
||||
dl.log(f"ireert.create_device()")
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = hal_device_id
|
||||
dl.log(f"ireert.Config()")
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
dl.log("get_iree_runtime_config")
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
# (This would arise if we're invoking `load_module` from a AMDSharkInference obj)
|
||||
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
|
||||
# (This would arise if we're invoking `compile` from a AMDSharkInference obj)
|
||||
temp_file_to_unlink = None
|
||||
if isinstance(flatbuffer_blob_or_path, Path):
|
||||
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
|
||||
if (
|
||||
isinstance(flatbuffer_blob_or_path, str)
|
||||
and ".vmfb" in flatbuffer_blob_or_path
|
||||
):
|
||||
vmfb_file_path = flatbuffer_blob_or_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(
|
||||
config.vm_instance, flatbuffer_blob_or_path
|
||||
)
|
||||
vm_modules = []
|
||||
if external_weight_file is not None:
|
||||
index = ireert.ParameterIndex()
|
||||
index.load(external_weight_file)
|
||||
param_module = ireert.create_io_parameters_module(
|
||||
config.vm_instance, index.create_provider(scope="model")
|
||||
)
|
||||
vm_modules.append(param_module)
|
||||
vm_modules.append(mmaped_vmfb)
|
||||
vm_modules.append(
|
||||
ireert.create_hal_module(config.vm_instance, config.device)
|
||||
)
|
||||
dl.log(f"mmap {flatbuffer_blob_or_path}")
|
||||
if "vulkan" in device:
|
||||
# Vulkan pipeline creation consumes a significant amount of time.
|
||||
print(
|
||||
"\tCompiling Vulkan shaders. This may take a few minutes."
|
||||
)
|
||||
ctx = ireert.SystemContext(config=config, vm_modules=vm_modules)
|
||||
dl.log(f"ireert.SystemContext created")
|
||||
for flag in amdshark_args.additional_runtime_args:
|
||||
ireert.flags.parse_flags(flag)
|
||||
dl.log(f"module initialized")
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||
tf.write(flatbuffer_blob_or_path)
|
||||
tf.flush()
|
||||
vmfb_file_path = tf.name
|
||||
temp_file_to_unlink = vmfb_file_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
|
||||
dl.log(f"mmap temp {vmfb_file_path}")
|
||||
return mmaped_vmfb, config, temp_file_to_unlink
|
||||
|
||||
|
||||
def get_iree_compiled_module(
|
||||
@@ -399,12 +507,24 @@ def get_iree_compiled_module(
|
||||
frontend: str = "torch",
|
||||
model_config_path: str = None,
|
||||
extra_args: list = [],
|
||||
rt_flags: list = [],
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
debug: bool = False,
|
||||
compile_str: bool = False,
|
||||
external_weight_file: str = None,
|
||||
write_to: bool = None,
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, model_config_path, extra_args
|
||||
module=module,
|
||||
device=device,
|
||||
frontend=frontend,
|
||||
model_config_path=model_config_path,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=compile_str,
|
||||
write_to=write_to,
|
||||
)
|
||||
temp_file_to_unlink = None
|
||||
# TODO: Currently mmap=True control flow path has been switched off for mmap.
|
||||
@@ -412,13 +532,22 @@ def get_iree_compiled_module(
|
||||
# we're setting delete=False when creating NamedTemporaryFile. That's why
|
||||
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
|
||||
if mmap:
|
||||
print(f"Will load the compiled module as a mmapped temporary file")
|
||||
if write_to is not None:
|
||||
flatbuffer_blob = write_to
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_blob, device, device_idx
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx,
|
||||
rt_flags,
|
||||
external_weight_file=external_weight_file,
|
||||
)
|
||||
else:
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=device_idx,
|
||||
rt_flags=rt_flags,
|
||||
external_weight_file=external_weight_file,
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
@@ -433,18 +562,21 @@ def load_flatbuffer(
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
temp_file_to_unlink = None
|
||||
if mmap:
|
||||
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_path, device, device_idx
|
||||
flatbuffer_path, device, device_idx, rt_flags
|
||||
)
|
||||
else:
|
||||
with open(os.path.join(flatbuffer_path), "rb") as f:
|
||||
flatbuffer_blob = f.read()
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=device_idx,
|
||||
rt_flags=rt_flags,
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
@@ -462,10 +594,18 @@ def export_iree_module_to_vmfb(
|
||||
model_config_path: str = None,
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
debug: bool = False,
|
||||
compile_str: bool = False,
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, mlir_dialect, model_config_path, extra_args
|
||||
module=module,
|
||||
device=device,
|
||||
frontend=mlir_dialect,
|
||||
model_config_path=model_config_path,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=compile_str,
|
||||
)
|
||||
if module_name is None:
|
||||
device_name = (
|
||||
@@ -473,9 +613,9 @@ def export_iree_module_to_vmfb(
|
||||
)
|
||||
module_name = f"{mlir_dialect}_{device_name}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(flatbuffer_blob)
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
return filename
|
||||
|
||||
|
||||
@@ -500,37 +640,65 @@ def get_results(
|
||||
config,
|
||||
frontend="torch",
|
||||
send_to_host=True,
|
||||
debug_timeout: float = 5.0,
|
||||
device: str = None,
|
||||
):
|
||||
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
||||
result = compiled_vm[function_name](*device_inputs)
|
||||
result_tensors = []
|
||||
if isinstance(result, tuple):
|
||||
if send_to_host:
|
||||
for val in result:
|
||||
result_tensors.append(np.asarray(val, val.dtype))
|
||||
with DetailLogger(debug_timeout) as dl:
|
||||
device_inputs = []
|
||||
if device == "rocm" and hasattr(config, "id"):
|
||||
haldriver = ireert.get_driver("rocm")
|
||||
haldevice = haldriver.create_device(
|
||||
config.id,
|
||||
allocators=amdshark_args.device_allocator,
|
||||
)
|
||||
for input_array in input:
|
||||
dl.log(f"Load to device: {input_array.shape}")
|
||||
device_inputs.append(
|
||||
ireert.asdevicearray(config.device, input_array)
|
||||
)
|
||||
dl.log(f"Invoke function: {function_name}")
|
||||
result = compiled_vm[function_name](*device_inputs)
|
||||
dl.log(f"Invoke complete")
|
||||
result_tensors = []
|
||||
if isinstance(result, tuple):
|
||||
if send_to_host:
|
||||
for val in result:
|
||||
dl.log(f"Result to host: {val.shape}")
|
||||
result_tensors.append(np.asarray(val, val.dtype))
|
||||
else:
|
||||
for val in result:
|
||||
result_tensors.append(val)
|
||||
return result_tensors
|
||||
elif isinstance(result, dict):
|
||||
data = list(result.items())
|
||||
if send_to_host:
|
||||
res = np.array(data, dtype=object)
|
||||
return np.copy(res)
|
||||
return data
|
||||
else:
|
||||
for val in result:
|
||||
result_tensors.append(val)
|
||||
return result_tensors
|
||||
elif isinstance(result, dict):
|
||||
data = list(result.items())
|
||||
if send_to_host:
|
||||
res = np.array(data, dtype=object)
|
||||
return np.copy(res)
|
||||
return data
|
||||
else:
|
||||
if send_to_host and result is not None:
|
||||
return result.to_host()
|
||||
return result
|
||||
if send_to_host and result is not None:
|
||||
dl.log("Result to host")
|
||||
return result.to_host()
|
||||
return result
|
||||
dl.log("Execution complete")
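Putting it together, a hedged sketch of running one inference through `get_results()`; the input shape is arbitrary, and `vm_module`/`config` are the pair returned by `get_iree_module()` above:

```python
import numpy as np

outputs = get_results(
    compiled_vm=vm_module,
    function_name="forward",
    input=[np.ones((1, 3, 224, 224), dtype=np.float32)],
    config=config,
    frontend="torch",
    send_to_host=True,
)
```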
|
||||
|
||||
@functools.cache
|
||||
def get_iree_runtime_config(device):
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
if "metal" in device and amdshark_args.device_allocator == "caching":
|
||||
print(
|
||||
"[WARNING] metal devices can not have a `caching` allocator."
|
||||
"\nUsing default allocator `None`"
|
||||
)
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
allocators=shark_args.device_allocator,
|
||||
# Metal devices currently fail with caching allocators; blocking this until it is fixed upstream.
|
||||
allocators=amdshark_args.device_allocator
|
||||
if "metal" not in device
|
||||
else None,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
return config
|
||||
@@ -14,9 +14,10 @@
|
||||
|
||||
# All the iree_cpu related functionalities go here.
|
||||
|
||||
import functools
|
||||
import subprocess
|
||||
import platform
|
||||
from shark.parser import shark_args
|
||||
from amdshark.parser import amdshark_args
|
||||
|
||||
|
||||
def get_cpu_count():
|
||||
@@ -30,6 +31,7 @@ def get_cpu_count():
|
||||
|
||||
|
||||
# Get the default cpu args.
|
||||
@functools.cache
|
||||
def get_iree_cpu_args():
|
||||
uname = platform.uname()
|
||||
os_name, proc_name = uname.system, uname.machine
|
||||
@@ -42,7 +44,7 @@ def get_iree_cpu_args():
|
||||
elif os_name == "Windows":
|
||||
target_triple = "x86_64-pc-windows-msvc"
|
||||
else:
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dAMDSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
print(f"Target triple found:{target_triple}")
|
||||
return [
|
||||
@@ -51,12 +53,13 @@ def get_iree_cpu_args():
|
||||
|
||||
|
||||
# Get iree runtime flags for cpu
|
||||
@functools.cache
|
||||
def get_iree_cpu_rt_args():
|
||||
default = get_cpu_count()
|
||||
default = default if default <= 8 else default - 2
|
||||
cpu_count = (
|
||||
default
|
||||
if shark_args.task_topology_max_group_count is None
|
||||
else shark_args.task_topology_max_group_count
|
||||
if amdshark_args.task_topology_max_group_count is None
|
||||
else amdshark_args.task_topology_max_group_count
|
||||
)
|
||||
return [f"--task_topology_max_group_count={cpu_count}"]
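A worked example of the worker-count rule above, assuming 16 logical cores and no `task_topology_max_group_count` override:

```python
default = 16
default = default if default <= 8 else default - 2  # -> 14
# resulting runtime flag: --task_topology_max_group_count=14
```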
@@ -14,12 +14,19 @@
|
||||
|
||||
# All the iree_gpu related functionalities go here.
|
||||
|
||||
import functools
|
||||
import iree.runtime as ireert
|
||||
import ctypes
|
||||
from shark.parser import shark_args
|
||||
import sys
|
||||
from subprocess import CalledProcessError
|
||||
from amdshark.parser import amdshark_args
|
||||
from amdshark.iree_utils._common import run_cmd
|
||||
|
||||
# TODO: refactor to rocm and cuda utils
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
@functools.cache
|
||||
def get_iree_gpu_args():
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
ireert.flags.parse_flags("--cuda_allow_inline_execution")
|
||||
@@ -28,7 +35,7 @@ def get_iree_gpu_args():
|
||||
if (
|
||||
sm_arch
|
||||
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
|
||||
) and (shark_args.enable_tf32 == True):
|
||||
) and (amdshark_args.enable_tf32 == True):
|
||||
return [
|
||||
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
|
||||
]
|
||||
@@ -36,26 +43,94 @@ def get_iree_gpu_args():
|
||||
return []
|
||||
|
||||
|
||||
def check_rocm_device_arch_in_args(extra_args):
|
||||
# Check if the target arch flag for rocm device present in extra_args
|
||||
for flag in extra_args:
|
||||
if "iree-rocm-target-chip" in flag:
|
||||
flag_arch = flag.split("=")[1]
|
||||
return flag_arch
|
||||
return None
|
||||
|
||||
|
||||
def get_rocm_device_arch(device_num=0, extra_args=[], hip_driver=False):
|
||||
# ROCM Device Arch selection:
|
||||
# 1 : User given device arch using `--iree-rocm-target-chip` flag
|
||||
# 2 : Device arch from `iree-run-module --dump_devices=rocm` for the device at index <device_num>
|
||||
# 3 : default arch : gfx1100
|
||||
|
||||
arch_in_flag = check_rocm_device_arch_in_args(extra_args)
|
||||
if arch_in_flag is not None:
|
||||
print(
|
||||
f"User Specified rocm target device arch from flag : {arch_in_flag} will be used"
|
||||
)
|
||||
return arch_in_flag
|
||||
|
||||
arch_in_device_dump = None
|
||||
|
||||
# get rocm arch from iree dump devices
|
||||
def get_devices_info_from_dump(dump, driver):
|
||||
from os import linesep
|
||||
|
||||
if driver == "hip":
|
||||
dump_clean = list(
|
||||
filter(
|
||||
lambda s: "AMD" in s,
|
||||
dump.split(linesep),
|
||||
)
|
||||
)
|
||||
else:
|
||||
dump_clean = list(
|
||||
filter(
|
||||
lambda s: f"--device={driver}" in s or "gpu-arch-name:" in s,
|
||||
dump.split(linesep),
|
||||
)
|
||||
)
|
||||
arch_pairs = [
|
||||
(
|
||||
dump_clean[i].split("=")[1].strip(),
|
||||
dump_clean[i + 1].split(":")[1].strip(),
|
||||
)
|
||||
for i in range(0, len(dump_clean), 2)
|
||||
]
|
||||
return arch_pairs
|
||||
|
||||
dump_device_info = None
|
||||
driver = "hip" if hip_driver else "rocm"
|
||||
try:
|
||||
dump_device_info = run_cmd(
|
||||
"iree-run-module --dump_devices=" + driver, raise_err=True
|
||||
)
|
||||
except Exception as e:
|
||||
print("could not execute `iree-run-module --dump_devices=" + driver + "`")
|
||||
|
||||
if dump_device_info is not None:
|
||||
device_num = 0 if device_num is None else device_num
|
||||
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0], driver)
|
||||
if len(device_arch_pairs) > device_num: # can find arch in the list
|
||||
arch_in_device_dump = device_arch_pairs[device_num][1]
|
||||
|
||||
if arch_in_device_dump is not None:
|
||||
print(f"Found ROCm device arch : {arch_in_device_dump}")
|
||||
return arch_in_device_dump
|
||||
|
||||
default_rocm_arch = "gfx1100"
|
||||
print(
|
||||
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
|
||||
"\n or from `iree-run-module --dump_devices` command."
|
||||
f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
|
||||
)
|
||||
return default_rocm_arch
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
def get_iree_rocm_args():
|
||||
def get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=False):
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
# get arch from rocminfo.
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
rocm_arch = re.match(
|
||||
r".*(gfx\w+)",
|
||||
subprocess.check_output(
|
||||
"rocminfo | grep -i 'gfx'", shell=True, text=True
|
||||
),
|
||||
).group(1)
|
||||
print(f"Found rocm arch {rocm_arch}...")
|
||||
return [
|
||||
f"--iree-rocm-target-chip={rocm_arch}",
|
||||
"--iree-rocm-link-bc=true",
|
||||
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
|
||||
]
|
||||
rocm_flags = []
|
||||
if check_rocm_device_arch_in_args(extra_args) is None:
|
||||
rocm_arch = get_rocm_device_arch(device_num, extra_args, hip_driver=hip_driver)
|
||||
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
|
||||
|
||||
return rocm_flags
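A hedged example of the new ROCm flag resolution: ask for device 0 through the HIP driver, falling back to the documented default when neither an explicit flag nor the device dump yields an architecture:

```python
flags = get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=True)
# e.g. ['--iree-rocm-target-chip=gfx1100'] when no --iree-rocm-target-chip
# flag was supplied and the device dump could not be parsed.
```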
|
||||
# Some constants taken from cuda.h
|
||||
CUDA_SUCCESS = 0
|
||||
@@ -65,6 +140,7 @@ CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13
|
||||
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_cuda_sm_cc():
|
||||
libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
|
||||
for libname in libnames:
|
||||
@@ -14,12 +14,15 @@
|
||||
|
||||
# All the iree_vulkan related functionalities go here.
|
||||
|
||||
from shark.iree_utils._common import run_cmd
|
||||
import functools
|
||||
|
||||
from amdshark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
from amdshark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_metal_device_name(device_num=0):
|
||||
iree_device_dump = run_cmd("iree-run-module --dump_devices")
|
||||
iree_device_dump = iree_device_dump[0].split("\n\n")
|
||||
@@ -57,15 +60,7 @@ def get_metal_target_triple(device_name):
|
||||
Returns:
|
||||
str or None: target triple or None if no match found for given name
|
||||
"""
|
||||
# Apple Targets
|
||||
if all(x in device_name for x in ("Apple", "M1")):
|
||||
triple = "m1-moltenvk-macos"
|
||||
elif all(x in device_name for x in ("Apple", "M2")):
|
||||
triple = "m1-moltenvk-macos"
|
||||
|
||||
else:
|
||||
triple = None
|
||||
return triple
|
||||
return "macos"
|
||||
|
||||
|
||||
def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
@@ -81,12 +76,12 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
triple = get_metal_target_triple(metal_device)
|
||||
if triple is not None:
|
||||
print(
|
||||
f"Found metal device {metal_device}. Using metal target triple {triple}"
|
||||
f"Found metal device {metal_device}. Using metal target platform {triple}"
|
||||
)
|
||||
return f"-iree-metal-target-platform={triple}"
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
|
||||
Contact the AMDSHARK admins on Discord [https://discord.com/invite/RUqY2h2s9u]
|
||||
or file an issue."""
|
||||
)
|
||||
print(f"Target : {metal_device}")
|
||||
@@ -94,24 +89,10 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
|
||||
|
||||
def get_iree_metal_args(device_num=0, extra_args=[]):
|
||||
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
# Add any Metal-specific compilation flags here
|
||||
res_metal_flag = []
|
||||
metal_triple_flag = None
|
||||
for arg in extra_args:
|
||||
if "-iree-metal-target-platform=" in arg:
|
||||
print(f"Using target triple {arg} from command line args")
|
||||
metal_triple_flag = arg
|
||||
break
|
||||
|
||||
if metal_triple_flag is None:
|
||||
metal_triple_flag = get_metal_triple_flag(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
|
||||
if metal_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
|
||||
res_metal_flag.append(vulkan_target_env)
|
||||
if len(extra_args) > 0:
|
||||
res_metal_flag.extend(extra_args)
|
||||
return res_metal_flag
|
||||
|
||||
|
||||
amdshark/iree_utils/trace.py (new file, 76 lines)
@@ -0,0 +1,76 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
def _enable_detail_trace() -> bool:
|
||||
return os.getenv("AMDSHARK_DETAIL_TRACE", "0") == "1"
|
||||
|
||||
|
||||
class DetailLogger:
|
||||
"""Context manager which can accumulate detailed log messages.
|
||||
|
||||
Detailed log is only emitted if the operation takes a long time
|
||||
or errors.
|
||||
"""
|
||||
|
||||
def __init__(self, timeout: float):
|
||||
self._timeout = timeout
|
||||
self._messages: List[Tuple[float, str]] = []
|
||||
self._start_time = time.time()
|
||||
self._active = not _enable_detail_trace()
|
||||
self._lock = threading.RLock()
|
||||
self._cond = threading.Condition(self._lock)
|
||||
self._thread = None
|
||||
|
||||
def __enter__(self):
|
||||
self._thread = threading.Thread(target=self._run)
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
with self._lock:
|
||||
self._active = False
|
||||
self._cond.notify()
|
||||
if traceback:
|
||||
self.dump_on_error(f"exception")
|
||||
|
||||
def _run(self):
|
||||
with self._lock:
|
||||
timed_out = not self._cond.wait(self._timeout)
|
||||
if timed_out:
|
||||
self.dump_on_error(f"took longer than {self._timeout}s")
|
||||
|
||||
def log(self, msg):
|
||||
with self._lock:
|
||||
timestamp = time.time()
|
||||
if self._active:
|
||||
self._messages.append((timestamp, msg))
|
||||
else:
|
||||
print(f" +{(timestamp - self._start_time) * 1000}ms: {msg}")
|
||||
|
||||
def dump_on_error(self, summary: str):
|
||||
with self._lock:
|
||||
if self._active:
|
||||
print(f"::: Detailed report ({summary}):")
|
||||
for timestamp, msg in self._messages:
|
||||
print(
|
||||
f" +{(timestamp - self._start_time) * 1000}ms: {msg}"
|
||||
)
|
||||
self._active = False
|
@@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict
|
||||
import functools
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_target_env(vulkan_target_triple):
|
||||
arch, product, os = vulkan_target_triple.split("=")[1].split("-")
|
||||
triple = (arch, product, os)
|
||||
@@ -31,7 +33,7 @@ def get_vulkan_target_env(vulkan_target_triple):
|
||||
device_type = get_device_type(triple)
|
||||
# get capabilities
|
||||
capabilities = get_vulkan_target_capabilities(triple)
|
||||
target_env = f"#vk.target_env<{version}, r({revision}), {extensions}, {vendor}:{device_type}, #vk.caps< {capabilities} >>"
|
||||
target_env = f"<#spirv.vce<{version}, r({revision}), {extensions}>, {vendor}:{device_type}, #spirv.resource_limits< {capabilities} >>"
|
||||
return target_env
|
||||
|
||||
|
||||
@@ -52,76 +54,75 @@ def get_version(triple):
|
||||
return "v1.3"
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_extensions(triple):
|
||||
def make_ext_list(ext_list):
|
||||
res = ""
|
||||
for e in ext_list:
|
||||
res += e + ", "
|
||||
res = f"[{res[:-2]}]"
|
||||
return res
|
||||
res = ", ".join(ext_list)
|
||||
return f"[{res}]"
|
||||
|
||||
arch, product, os = triple
|
||||
if arch == "m1":
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_8bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_8bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "valhall":
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_8bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_spirv_1_4",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_8bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_spirv_1_4",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "adreno":
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_spirv_1_4",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_spirv_1_4",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
if os == "android31":
|
||||
ext.append("VK_KHR_8bit_storage")
|
||||
ext.append("SPV_KHR_8bit_storage")
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if get_vendor(triple) == "SwiftShader":
|
||||
ext = ["VK_KHR_storage_buffer_storage_class"]
|
||||
ext = ["SPV_KHR_storage_buffer_storage_class"]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "unknown":
|
||||
ext = [
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_8bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_spirv_1_4",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_8bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_spirv_1_4",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
"VK_EXT_subgroup_size_control",
|
||||
]
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("VK_NV_cooperative_matrix")
|
||||
ext.append("SPV_KHR_cooperative_matrix")
|
||||
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
|
||||
ext.append("VK_KHR_shader_integer_dot_product")
|
||||
ext.append("SPV_KHR_shader_integer_dot_product")
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vendor(triple):
|
||||
arch, product, os = triple
|
||||
if arch == "unknown":
|
||||
@@ -146,6 +147,7 @@ def get_vendor(triple):
|
||||
return "Unknown"
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_device_type(triple):
|
||||
arch, product, _ = triple
|
||||
if arch == "unknown":
|
||||
@@ -166,6 +168,7 @@ def get_device_type(triple):
|
||||
|
||||
# get all the capabilities for the device
|
||||
# TODO: make a dataclass for capabilities and initialize it using vulkaninfo
|
||||
@functools.cache
|
||||
def get_vulkan_target_capabilities(triple):
|
||||
def get_subgroup_val(l):
|
||||
return int(sum([subgroup_feature[sgf] for sgf in l]))
|
||||
@@ -183,13 +186,13 @@ def get_vulkan_target_capabilities(triple):
|
||||
"Quad": 128,
|
||||
"PartitionedNV": 256,
|
||||
}
|
||||
cap["maxComputeSharedMemorySize"] = 16384
|
||||
cap["maxComputeWorkGroupInvocations"] = 128
|
||||
cap["maxComputeWorkGroupSize"] = [128, 128, 64]
|
||||
cap["subgroupSize"] = 32
|
||||
cap["max_compute_shared_memory_size"] = 16384
|
||||
cap["max_compute_workgroup_invocations"] = 128
|
||||
cap["max_compute_workgroup_size"] = [128, 128, 64]
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = ["Basic"]
|
||||
cap["minSubgroupSize"] = None
|
||||
cap["maxSubgroupSize"] = None
|
||||
cap["min_subgroup_size"] = None
|
||||
cap["max_subgroup_size"] = None
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = False
|
||||
cap["shaderInt8"] = False
|
||||
@@ -206,13 +209,13 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["coopmatCases"] = None
|
||||
|
||||
if arch in ["rdna1", "rdna2", "rdna3"]:
|
||||
cap["maxComputeSharedMemorySize"] = 65536
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
cap["max_compute_shared_memory_size"] = 65536
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 64
|
||||
cap["subgroup_size"] = 64
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -241,7 +244,8 @@ def get_vulkan_target_capabilities(triple):
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
|
||||
"m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>",
|
||||
"m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
@@ -249,11 +253,11 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["storagePushConstant8"] = False
|
||||
|
||||
elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
|
||||
cap["maxComputeSharedMemorySize"] = 65536
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
cap["max_compute_shared_memory_size"] = 65536
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -264,8 +268,8 @@ def get_vulkan_target_capabilities(triple):
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
cap["minSubgroupSize"] = 64
|
||||
cap["maxSubgroupSize"] = 64
|
||||
cap["min_subgroup_size"] = 64
|
||||
cap["max_subgroup_size"] = 64
|
||||
|
||||
if arch == "rgcn5":
|
||||
cap["shaderFloat16"] = True
|
||||
@@ -287,11 +291,11 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "m1":
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -318,11 +322,11 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "valhall":
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 512
|
||||
cap["maxComputeWorkGroupSize"] = [512, 512, 512]
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 512
|
||||
cap["max_compute_workgroup_size"] = [512, 512, 512]
|
||||
|
||||
cap["subgroupSize"] = 16
|
||||
cap["subgroup_size"] = 16
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -349,11 +353,11 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "arc":
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 64]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -382,8 +386,8 @@ def get_vulkan_target_capabilities(triple):
|
||||
|
||||
elif arch == "cpu":
|
||||
if product == "swiftshader":
|
||||
cap["maxComputeSharedMemorySize"] = 16384
|
||||
cap["subgroupSize"] = 4
|
||||
cap["max_compute_shared_memory_size"] = 16384
|
||||
cap["subgroup_size"] = 4
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -394,13 +398,13 @@ def get_vulkan_target_capabilities(triple):
|
||||
]
|
||||
|
||||
elif arch in ["pascal"]:
|
||||
cap["maxComputeSharedMemorySize"] = 49152
|
||||
cap["maxComputeWorkGroupInvocations"] = 1536
|
||||
cap["maxComputeWorkGroupSize"] = [1536, 1024, 64]
|
||||
cap["max_compute_shared_memory_size"] = 49152
|
||||
cap["max_compute_workgroup_invocations"] = 1536
|
||||
cap["max_compute_workgroup_size"] = [1536, 1024, 64]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 32
|
||||
cap["subgroup_size"] = 32
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -428,13 +432,13 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch in ["ampere", "turing"]:
|
||||
cap["maxComputeSharedMemorySize"] = 49152
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
cap["max_compute_shared_memory_size"] = 49152
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 32
|
||||
cap["subgroup_size"] = 32
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -462,17 +466,17 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
]
|
||||
|
||||
elif arch == "adreno":
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 64]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -488,14 +492,14 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt16"] = True
|
||||
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
if os == "andorid31":
|
||||
if os == "android31":
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "unknown":
|
||||
cap["subgroupSize"] = 64
|
||||
cap["subgroup_size"] = 64
|
||||
cap["variablePointers"] = False
|
||||
cap["variablePointersStorageBuffer"] = False
|
||||
else:
|
||||
@@ -518,14 +522,14 @@ def get_vulkan_target_capabilities(triple):
|
||||
res += f"{k} = {'unit' if v == True else None}, "
|
||||
elif isinstance(v, list):
|
||||
if k == "subgroupFeatures":
|
||||
res += f"subgroupFeatures = {get_subgroup_val(v)}: i32, "
|
||||
elif k == "maxComputeWorkGroupSize":
|
||||
res += f"maxComputeWorkGroupSize = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
|
||||
res += f"subgroup_features = {get_subgroup_val(v)}: i32, "
|
||||
elif k == "max_compute_workgroup_size":
|
||||
res += f"max_compute_workgroup_size = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
|
||||
elif k == "coopmatCases":
|
||||
cmc = ""
|
||||
for case in v:
|
||||
cmc += f"#vk.coop_matrix_props<{case}>, "
|
||||
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
|
||||
cmc += f"#spirv.coop_matrix_props_khr<{case}>, "
|
||||
res += f"cooperative_matrix_properties_khr = [{cmc[:-2]}], "
|
||||
else:
|
||||
res += f"{k} = {get_comma_sep_str(v)}, "
|
||||
else:
|
||||
@@ -14,25 +14,48 @@
|
||||
|
||||
# All the iree_vulkan related functionalities go here.
|
||||
|
||||
import functools
|
||||
from os import linesep
|
||||
from shark.iree_utils._common import run_cmd
|
||||
from amdshark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
from amdshark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
from amdshark.parser import amdshark_args
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_all_vulkan_devices():
|
||||
from iree.runtime import get_driver
|
||||
|
||||
try:
|
||||
driver = get_driver("vulkan")
|
||||
device_list_src = driver.query_available_devices()
|
||||
except:
|
||||
device_list_src = {}
|
||||
|
||||
return [d["name"] for d in device_list_src]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(vulkaninfo_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing device: {vulkaninfo_list[device_num]}")
|
||||
return vulkaninfo_list[device_num]
|
||||
if isinstance(device_num, int):
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(vulkaninfo_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing device: vulkan://{device_num}")
|
||||
vulkan_device_name = vulkaninfo_list[device_num]
|
||||
else:
|
||||
from iree.runtime import get_driver
|
||||
|
||||
vulkan_device_driver = get_driver(device_num)
|
||||
vulkan_device_name = vulkan_device_driver.query_available_devices()[0]
|
||||
print(vulkan_device_name)
|
||||
return vulkan_device_name
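A minimal usage sketch for the device helpers above; the module path amdshark.iree_utils.vulkan_utils is an assumption, since this hunk does not name the file:
from amdshark.iree_utils.vulkan_utils import (  # assumed module path
    get_all_vulkan_devices,
    get_vulkan_device_name,
)
# Enumerate the Vulkan devices reported by the IREE driver, then resolve one by index.
for i, name in enumerate(get_all_vulkan_devices()):
    print(f"vulkan://{i} -> {name}")
print(get_vulkan_device_name(device_num=0))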
|
||||
|
||||
|
||||
def get_os_name():
|
||||
@@ -47,6 +70,7 @@ def get_os_name():
|
||||
return "linux"
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_target_triple(device_name):
|
||||
"""This method provides a target triple str for specified vulkan device.
|
||||
|
||||
@@ -56,6 +80,8 @@ def get_vulkan_target_triple(device_name):
|
||||
Returns:
|
||||
str or None: target triple or None if no match found for given name
|
||||
"""
|
||||
|
||||
# TODO: Replace this with a dict or something smarter.
|
||||
system_os = get_os_name()
|
||||
# Apple Targets
|
||||
if all(x in device_name for x in ("Apple", "M1")):
|
||||
@@ -105,8 +131,12 @@ def get_vulkan_target_triple(device_name):
|
||||
# Amd Targets
|
||||
# Linux: Radeon RX 7900 XTX
|
||||
# Windows: AMD Radeon RX 7900 XTX
|
||||
elif all(x in device_name for x in ("RX", "7800")):
|
||||
triple = f"rdna3-7800-{system_os}"
|
||||
elif all(x in device_name for x in ("RX", "7900")):
|
||||
triple = f"rdna3-7900-{system_os}"
|
||||
elif all(x in device_name for x in ("Radeon", "780M")):
|
||||
triple = f"rdna3-780m-{system_os}"
|
||||
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
|
||||
triple = f"rdna3-w7900-{system_os}"
|
||||
elif any(x in device_name for x in ("AMD", "Radeon")):
|
||||
@@ -114,6 +144,8 @@ def get_vulkan_target_triple(device_name):
|
||||
# Intel Targets
|
||||
elif any(x in device_name for x in ("A770", "A750")):
|
||||
triple = f"arc-770-{system_os}"
|
||||
elif "v620" in device_name:
|
||||
triple = f"rdna2-v620-{system_os}"
|
||||
|
||||
# Adreno Targets
|
||||
elif all(x in device_name for x in ("Adreno", "740")):
|
||||
@@ -139,10 +171,10 @@ def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
print(
|
||||
f"Found vulkan device {vulkan_device}. Using target triple {triple}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple={triple}"
|
||||
return f"--iree-vulkan-target-triple={triple}"
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
|
||||
Contact AMDSHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
|
||||
or pull up an issue."""
|
||||
)
|
||||
print(f"Target : {vulkan_device}")
|
||||
@@ -153,6 +185,10 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
res_vulkan_flag = []
|
||||
res_vulkan_flag += [
|
||||
"--iree-stream-resource-max-allocation-size=3221225472",
|
||||
"--iree-flow-inline-constants-max-byte-length=0"
|
||||
]
|
||||
vulkan_triple_flag = None
|
||||
for arg in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in arg:
|
||||
@@ -164,13 +200,21 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
res_vulkan_flag += [vulkan_triple_flag]
|
||||
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
|
||||
res_vulkan_flag.append(vulkan_target_env)
|
||||
return res_vulkan_flag
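Roughly, the list returned above combines the fixed flags with the detected target triple and target env; the concrete values below are only an illustration and depend on which GPU is found:
flags = get_iree_vulkan_args(device_num=0)
# e.g. ["--iree-stream-resource-max-allocation-size=3221225472",
#       "--iree-flow-inline-constants-max-byte-length=0",
#       "--iree-vulkan-target-triple=rdna3-7900-linux",
#       "<target-env flag built by get_vulkan_target_env_flag>"]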
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_iree_vulkan_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_validation_layers={'true' if amdshark_args.vulkan_debug_utils else 'false'}",
|
||||
f"--vulkan_debug_verbosity={'4' if amdshark_args.vulkan_debug_utils else '0'}"
|
||||
f"--vulkan-robust-buffer-access={'true' if amdshark_args.vulkan_debug_utils else 'false'}",
|
||||
]
|
||||
return vulkan_runtime_flags
|
||||
|
||||
|
||||
def set_iree_vulkan_runtime_flags(flags):
|
||||
for flag in flags:
|
||||
ireert.flags.parse_flags(flag)
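The two runtime-flag helpers are meant to compose; IREE flags generally have to be parsed before runtime devices are created, so a sketch of the intended call order is:
# Parse the Vulkan runtime flags early, before any IREE runtime objects exist.
set_iree_vulkan_runtime_flags(get_iree_vulkan_runtime_flags())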
|
||||
@@ -18,7 +18,7 @@ This function takes the model mlir file and the tuned config file as input,
|
||||
and output a new mlir file with lowering configs annotated on certain ops.
|
||||
There are two ways to utilize the function:
|
||||
1. Call model_annotation function within another python script
|
||||
from shark.model_annotation import model_annotation
|
||||
from amdshark.model_annotation import model_annotation
|
||||
with create_context() as ctx:
|
||||
module = model_annotation(ctx, input_contents=..., config_path=..., search_op=...)
|
||||
2. Run model_annotation.py directly
|
||||
@@ -14,15 +14,42 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
|
||||
parser = argparse.ArgumentParser(description="SHARK runner.")
|
||||
|
||||
class SplitStrToListAction(argparse.Action):
|
||||
def __init__(self, option_strings, dest, *args, **kwargs):
|
||||
super(SplitStrToListAction, self).__init__(
|
||||
option_strings=option_strings, dest=dest, *args, **kwargs
|
||||
)
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
del parser, option_string
|
||||
setattr(namespace, self.dest, shlex.split(values[0]))  # split the single string value into a list of flags
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="AMDSHARK runner.")
|
||||
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
|
||||
help="Device on which amdshark_runner runs. options are cpu, cuda, and vulkan",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--additional_compile_args",
|
||||
default=list(),
|
||||
nargs=1,
|
||||
action=SplitStrToListAction,
|
||||
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--additional_runtime_args",
|
||||
default=list(),
|
||||
nargs=1,
|
||||
action=SplitStrToListAction,
|
||||
help="Additional arguments to pass to the IREE runtime. These are appended as the last arguments.",
|
||||
)
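Assuming SplitStrToListAction shell-splits the single nargs=1 value it receives (see the action definition above), the extra-args flags behave roughly as follows; the flag values here are purely illustrative:
args = parser.parse_args(
    ["--additional_compile_args", "--iree-hal-target-backends=vulkan-spirv --mlir-print-ir-after-all"]
)
print(args.additional_compile_args)
# ["--iree-hal-target-backends=vulkan-spirv", "--mlir-print-ir-after-all"]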
|
||||
parser.add_argument(
|
||||
"--enable_tf32",
|
||||
@@ -55,26 +82,26 @@ parser.add_argument(
|
||||
help="When enabled, pytest bench results will include ONNX benchmark results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shark_prefix",
|
||||
"--amdshark_prefix",
|
||||
default=None,
|
||||
help="gs://shark_tank/<this_flag>/model_directories",
|
||||
help="gs://amdshark_tank/<this_flag>/model_directories",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update_tank",
|
||||
default=True,
|
||||
action="store_true",
|
||||
help="When enabled, SHARK downloader will update local shark_tank if local hash is different from latest upstream hash.",
|
||||
help="When enabled, AMDSHARK downloader will update local amdshark_tank if local hash is different from latest upstream hash.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_update_tank",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled, SHARK downloader will force an update of local shark_tank artifacts for each request.",
|
||||
help="When enabled, AMDSHARK downloader will force an update of local amdshark_tank artifacts for each request.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local_tank_cache",
|
||||
default=None,
|
||||
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
|
||||
help="Specify where to save downloaded amdshark_tank artifacts. If this is not set, the default is ~/.local/amdshark_tank/.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -114,7 +141,7 @@ parser.add_argument(
|
||||
"--device_allocator",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
default=["caching"],
|
||||
help="Specifies one or more HAL device allocator specs "
|
||||
"to augment the base device allocator",
|
||||
choices=["debug", "caching"],
|
||||
@@ -126,4 +153,18 @@ parser.add_argument(
|
||||
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
parser.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Profiles vulkan device and collects the .rdc info.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for disabling vulkan validation layers when benchmarking.",
|
||||
)
|
||||
|
||||
amdshark_args, unknown = parser.parse_known_args()
|
||||
@@ -13,25 +13,25 @@
|
||||
# limitations under the License.
|
||||
|
||||
from iree.runtime import query_available_drivers, get_driver
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.shark_inference import SharkInference
|
||||
from amdshark.amdshark_downloader import download_model
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
import argparse
|
||||
from shark.iree_utils._common import _IREE_DEVICE_MAP
|
||||
from amdshark.iree_utils._common import _IREE_DEVICE_MAP
|
||||
import multiprocessing
|
||||
from shark.shark_runner import supported_dialects
|
||||
from amdshark.amdshark_runner import supported_dialects
|
||||
import logging
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
IREE_TO_SHARK_DRIVER_MAP = {v: k for k, v in _IREE_DEVICE_MAP.items()}
|
||||
IREE_TO_AMDSHARK_DRIVER_MAP = {v: k for k, v in _IREE_DEVICE_MAP.items()}
|
||||
|
||||
|
||||
def stress_test_compiled_model(
|
||||
shark_module_path: str,
|
||||
amdshark_module_path: str,
|
||||
function_name: str,
|
||||
device: str,
|
||||
inputs: List[np.ndarray],
|
||||
@@ -50,14 +50,14 @@ def stress_test_compiled_model(
|
||||
# We are using execution in a separate thread in order to be able
|
||||
# to wait with a timeout on the inference operation.
|
||||
module_executor = ThreadPoolExecutor(1)
|
||||
shark_module = module_executor.submit(
|
||||
SharkInference,
|
||||
amdshark_module = module_executor.submit(
|
||||
AMDSharkInference,
|
||||
mlir_module=bytes(),
|
||||
function_name=function_name,
|
||||
device=device,
|
||||
).result()
|
||||
module_executor.submit(
|
||||
shark_module.load_module, shark_module_path
|
||||
amdshark_module.load_module, amdshark_module_path
|
||||
).result()
|
||||
input_batches = [np.repeat(arr, batch_size, axis=0) for arr in inputs]
|
||||
golden_output_batches = np.repeat(golden_out, batch_size, axis=0)
|
||||
@@ -67,7 +67,7 @@ def stress_test_compiled_model(
|
||||
first_iteration_output = None
|
||||
for i in range(max_iterations):
|
||||
output = module_executor.submit(
|
||||
shark_module.forward, input_batches
|
||||
amdshark_module.forward, input_batches
|
||||
).result(inference_timeout_seconds)
|
||||
if first_iteration_output is None:
|
||||
np.testing.assert_array_almost_equal_nulp(
|
||||
@@ -100,9 +100,9 @@ def query_devices(device_types: Optional[List[str]] = None) -> List[str]:
|
||||
devices = []
|
||||
if device_types is None:
|
||||
device_types = [
|
||||
IREE_TO_SHARK_DRIVER_MAP[name]
|
||||
IREE_TO_AMDSHARK_DRIVER_MAP[name]
|
||||
for name in query_available_drivers()
|
||||
if name in IREE_TO_SHARK_DRIVER_MAP
|
||||
if name in IREE_TO_AMDSHARK_DRIVER_MAP
|
||||
]
|
||||
for device_type in device_types:
|
||||
driver = get_driver(_IREE_DEVICE_MAP[device_type])
|
||||
@@ -121,19 +121,19 @@ def query_devices(device_types: Optional[List[str]] = None) -> List[str]:
|
||||
def compile_stress_test_module(
|
||||
device_types: List[str], mlir_model: str, func_name: str, mlir_dialect: str
|
||||
) -> List[str]:
|
||||
shark_module_paths = []
|
||||
amdshark_module_paths = []
|
||||
for device_type in device_types:
|
||||
logging.info(
|
||||
f"Compiling stress test model for device type {device_type}."
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_model,
|
||||
func_name,
|
||||
mlir_dialect=mlir_dialect,
|
||||
device=device_type,
|
||||
)
|
||||
shark_module_paths.append(shark_module.save_module())
|
||||
return shark_module_paths
|
||||
amdshark_module_paths.append(amdshark_module.save_module())
|
||||
return amdshark_module_paths
|
||||
|
||||
|
||||
def stress_test(
|
||||
@@ -169,21 +169,21 @@ def stress_test(
|
||||
# This needs to run in a subprocess because when compiling for CUDA,
|
||||
# some stuff gets initialized and cuInit will fail in a forked process
|
||||
# later. It should be just compiling, but alas.
|
||||
shark_module_paths_set = executor.submit(
|
||||
amdshark_module_paths_set = executor.submit(
|
||||
compile_stress_test_module,
|
||||
device_types_set,
|
||||
mlir_model,
|
||||
func_name,
|
||||
mlir_dialect,
|
||||
).result()
|
||||
device_type_shark_module_path_map = {
|
||||
device_type_amdshark_module_path_map = {
|
||||
device_type: module_path
|
||||
for device_type, module_path in zip(
|
||||
device_types_set, shark_module_paths_set
|
||||
device_types_set, amdshark_module_paths_set
|
||||
)
|
||||
}
|
||||
device_name_shark_module_path_map = {
|
||||
device_name: device_type_shark_module_path_map[
|
||||
device_name_amdshark_module_path_map = {
|
||||
device_name: device_type_amdshark_module_path_map[
|
||||
get_device_type(device_name)
|
||||
]
|
||||
for device_name in device_names
|
||||
@@ -193,7 +193,7 @@ def stress_test(
|
||||
# in IREE and a subsequent call to `iree.runtime.SystemContext.add_vm_module`
|
||||
# in a forked process will hang.
|
||||
with multiprocessing.Pool(
|
||||
len(device_name_shark_module_path_map) * oversubscription_factor
|
||||
len(device_name_amdshark_module_path_map) * oversubscription_factor
|
||||
) as process_pool:
|
||||
process_pool.starmap(
|
||||
stress_test_compiled_model,
|
||||
@@ -212,7 +212,7 @@ def stress_test(
|
||||
stress_test_index,
|
||||
)
|
||||
for stress_test_index, (device_name, module_path) in enumerate(
|
||||
list(device_name_shark_module_path_map.items())
|
||||
list(device_name_amdshark_module_path_map.items())
|
||||
* oversubscription_factor
|
||||
)
|
||||
],
|
||||
@@ -1,10 +1,10 @@
|
||||
# RUN: %PYTHON %s
|
||||
import numpy as np
|
||||
from shark.shark_importer import SharkImporter
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
import pytest
|
||||
from shark.parser import shark_args
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
from amdshark.parser import amdshark_args
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.tflite_utils import TFLitePreprocessor
|
||||
import sys
|
||||
|
||||
# model_path = "https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"
|
||||
@@ -66,32 +66,32 @@ class AlbertTfliteModuleTester:
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
amdshark_args.save_mlir = self.save_mlir
|
||||
amdshark_args.save_vmfb = self.save_vmfb
|
||||
tflite_preprocessor = TFLitePreprocessor(model_name="albert_lite_base")
|
||||
|
||||
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
|
||||
inputs = tflite_preprocessor.get_inputs()
|
||||
tflite_interpreter = tflite_preprocessor.get_interpreter()
|
||||
|
||||
my_shark_importer = SharkImporter(
|
||||
my_amdshark_importer = AMDSharkImporter(
|
||||
module=tflite_interpreter,
|
||||
inputs=inputs,
|
||||
frontend="tflite",
|
||||
raw_model_file=raw_model_file_path,
|
||||
)
|
||||
mlir_model, func_name = my_shark_importer.import_mlir()
|
||||
mlir_model, func_name = my_amdshark_importer.import_mlir()
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_module=mlir_model,
|
||||
function_name=func_name,
|
||||
device=self.device,
|
||||
mlir_dialect="tflite",
|
||||
)
|
||||
|
||||
# Case1: Use shark_importer default generate inputs
|
||||
shark_module.compile()
|
||||
mlir_results = shark_module.forward(inputs)
|
||||
# Case1: Use amdshark_importer default generate inputs
|
||||
amdshark_module.compile()
|
||||
mlir_results = amdshark_module.forward(inputs)
|
||||
## post process results for compare
|
||||
input_details, output_details = tflite_preprocessor.get_model_details()
|
||||
mlir_results = list(mlir_results)
|
||||
@@ -105,14 +105,14 @@ class AlbertTfliteModuleTester:
|
||||
input_details, output_details = tflite_preprocessor.get_model_details()
|
||||
inputs = generate_inputs(input_details) # new inputs
|
||||
|
||||
shark_module = SharkInference(
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_module=mlir_model,
|
||||
function_name=func_name,
|
||||
device=self.device,
|
||||
mlir_dialect="tflite",
|
||||
)
|
||||
shark_module.compile()
|
||||
mlir_results = shark_module.forward(inputs)
|
||||
amdshark_module.compile()
|
||||
mlir_results = amdshark_module.forward(inputs)
|
||||
## post process results for compare
|
||||
tflite_results = tflite_preprocessor.get_golden_output()
|
||||
compare_results(mlir_results, tflite_results, output_details)
|
||||
@@ -22,7 +22,7 @@ def test_stress_test():
|
||||
subprocess.check_call(
|
||||
[
|
||||
sys.executable,
|
||||
importlib.util.find_spec("shark.stress_test").origin,
|
||||
importlib.util.find_spec("amdshark.stress_test").origin,
|
||||
"--model=squeezenet1_0",
|
||||
"--devices",
|
||||
"cpu",
|
||||
amdshark/tests/test_txt2img_ui.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import unittest
from unittest.mock import mock_open, patch

from apps.stable_diffusion.web.ui.txt2img_ui import (
    export_settings,
    load_settings,
    all_gradio_labels,
)


class TestExportSettings(unittest.TestCase):
    @patch("builtins.open", new_callable=mock_open)
    @patch("json.dump")
    def test_export_settings(self, mock_json_dump, mock_file):
        test_values = ["value1", "value2", "value3"]
        expected_output = {
            "txt2img": {
                label: value
                for label, value in zip(all_gradio_labels, test_values)
            }
        }

        export_settings(*test_values)
        mock_file.assert_called_once_with("./ui/settings.json", "w")
        mock_json_dump.assert_called_once_with(
            expected_output, mock_file(), indent=4
        )

    @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
    @patch(
        "builtins.open",
        new_callable=mock_open,
        read_data='{"txt2img": {"some_setting": "some_value"}}',
    )
    def test_load_settings_file_exists(self, mock_file, mock_json_load):
        mock_json_load.return_value = {
            "txt2img": {
                "txt2img_custom_model": "custom_model_value",
                "custom_vae": "custom_vae_value",
            }
        }

        settings = load_settings()
        self.assertEqual(settings[0], "custom_model_value")
        self.assertEqual(settings[1], "custom_vae_value")

    @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
    @patch("builtins.open", side_effect=FileNotFoundError)
    def test_load_settings_file_not_found(self, mock_file, mock_json_load):
        settings = load_settings()

        default_lora_weights = "None"
        self.assertEqual(settings[4], default_lora_weights)

    @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
    @patch("builtins.open", new_callable=mock_open, read_data="{}")
    def test_load_settings_key_error(self, mock_file, mock_json_load):
        mock_json_load.return_value = {}

        settings = load_settings()
        default_lora_weights = "None"
        self.assertEqual(settings[4], default_lora_weights)
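For reference, these tests imply that export_settings writes a ./ui/settings.json shaped roughly like the following; the keys come from all_gradio_labels and the values shown are placeholders:
{
    "txt2img": {
        "txt2img_custom_model": "custom_model_value",
        "custom_vae": "custom_vae_value"
        # ...one entry per label in all_gradio_labels
    }
}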
@@ -96,7 +96,7 @@ class TFLitePreprocessor:
|
||||
|
||||
print("Setting up for TMP_WORK_DIR")
|
||||
self.workdir = os.path.join(
|
||||
os.path.dirname(__file__), "./../gen_shark_tank"
|
||||
os.path.dirname(__file__), "./../gen_amdshark_tank"
|
||||
)
|
||||
os.makedirs(self.workdir, exist_ok=True)
|
||||
print(f"TMP_WORK_DIR = {self.workdir}")
|
||||
@@ -28,7 +28,7 @@ from torch_mlir.eager_mode.torch_mlir_tensor import (
|
||||
no_dispatch,
|
||||
)
|
||||
from torch_mlir.eager_mode import torch_mlir_tensor
|
||||
from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
|
||||
from amdshark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
|
||||
|
||||
|
||||
backend = EagerModeIREELinalgOnTensorsBackend("cpu")
|
||||
@@ -16,7 +16,7 @@ from torch_mlir.ir import StringAttr
|
||||
import torch_mlir
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
import tempfile
|
||||
from shark.parser import shark_args
|
||||
from amdshark.parser import amdshark_args
|
||||
import io
|
||||
|
||||
mlir_type_mapping_dict = {
|
||||
apps/amdshark_studio/amdshark_studio.spec (new file, 48 lines)
@@ -0,0 +1,48 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.amdshark_studio.studio_imports import pathex, datas, hiddenimports

binaries = []

block_cipher = None

a = Analysis(
    ['web/index.py'],
    pathex=pathex,
    binaries=binaries,
    datas=datas,
    hiddenimports=hiddenimports,
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
    module_collection_mode={
        'gradio': 'py',  # Collect gradio package as source .py files
    },
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.zipfiles,
    a.datas,
    [],
    name='nodai_amdshark_studio',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=False,
    upx_exclude=[],
    runtime_tmpdir=None,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)
apps/amdshark_studio/api/controlnet.py (new file, 107 lines)
@@ -0,0 +1,107 @@
|
||||
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
|
||||
import os
|
||||
import PIL
|
||||
import numpy as np
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
)
|
||||
from datetime import datetime
|
||||
from PIL import Image
|
||||
from gradio.components.image_editor import (
|
||||
EditorValue,
|
||||
)
|
||||
|
||||
|
||||
class control_adapter:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
):
|
||||
self.model = None
|
||||
|
||||
def export_control_adapter_model(model_keyword):
|
||||
return None
|
||||
|
||||
def export_xl_control_adapter_model(model_keyword):
|
||||
return None
|
||||
|
||||
|
||||
class preprocessors:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
):
|
||||
self.model = None
|
||||
|
||||
def export_controlnet_model(model_keyword):
|
||||
return None
|
||||
|
||||
|
||||
control_adapter_map = {
|
||||
"sd15": {
|
||||
"canny": {"initializer": control_adapter.export_control_adapter_model},
|
||||
"openpose": {"initializer": control_adapter.export_control_adapter_model},
|
||||
"scribble": {"initializer": control_adapter.export_control_adapter_model},
|
||||
"zoedepth": {"initializer": control_adapter.export_control_adapter_model},
|
||||
},
|
||||
"sdxl": {
|
||||
"canny": {"initializer": control_adapter.export_xl_control_adapter_model},
|
||||
},
|
||||
}
|
||||
preprocessor_model_map = {
|
||||
"canny": {"initializer": preprocessors.export_controlnet_model},
|
||||
"openpose": {"initializer": preprocessors.export_controlnet_model},
|
||||
"scribble": {"initializer": preprocessors.export_controlnet_model},
|
||||
"zoedepth": {"initializer": preprocessors.export_controlnet_model},
|
||||
}
|
||||
|
||||
|
||||
class PreprocessorModel:
|
||||
def __init__(
|
||||
self,
|
||||
hf_model_id,
|
||||
device="cpu",
|
||||
):
|
||||
self.model = hf_model_id
|
||||
self.device = device
|
||||
|
||||
def compile(self):
|
||||
print("compile not implemented for preprocessor.")
|
||||
return
|
||||
|
||||
def run(self, inputs):
|
||||
print("run not implemented for preprocessor.")
|
||||
return inputs
|
||||
|
||||
|
||||
def cnet_preview(model, input_image):
|
||||
curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
|
||||
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
|
||||
if not os.path.exists(control_imgs_path):
|
||||
os.mkdir(control_imgs_path)
|
||||
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
|
||||
match model:
|
||||
case "canny":
|
||||
canny = PreprocessorModel("canny")
|
||||
result = canny(
|
||||
np.array(input_image),
|
||||
100,
|
||||
200,
|
||||
)
|
||||
Image.fromarray(result).save(fp=img_dest)
|
||||
return result, img_dest
|
||||
case "openpose":
|
||||
openpose = PreprocessorModel("openpose")
|
||||
result = openpose(np.array(input_image))
|
||||
Image.fromarray(result[0]).save(fp=img_dest)
|
||||
return result, img_dest
|
||||
case "zoedepth":
|
||||
zoedepth = PreprocessorModel("ZoeDepth")
|
||||
result = zoedepth(np.array(input_image))
|
||||
Image.fromarray(result).save(fp=img_dest)
|
||||
return result, img_dest
|
||||
case "scribble":
|
||||
input_image.save(fp=img_dest)
|
||||
return input_image, img_dest
|
||||
case _:
|
||||
return None, None
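A minimal usage sketch for cnet_preview; the input path is illustrative, and the "scribble" case is used because it needs no preprocessor model:
from PIL import Image

hint_image, hint_path = cnet_preview("scribble", Image.open("sketch.png"))
print(hint_path)  # <generated images dir>/control_hints/scribble<timestamp>.png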
|
||||
apps/amdshark_studio/api/initializers.py (new file, 125 lines)
@@ -0,0 +1,125 @@
|
||||
import importlib
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import warnings
|
||||
import json
|
||||
from threading import Thread
|
||||
|
||||
from apps.amdshark_studio.modules.timer import startup_timer
|
||||
|
||||
from apps.amdshark_studio.web.utils.tmp_configs import (
|
||||
config_tmp,
|
||||
clear_tmp_mlir,
|
||||
clear_tmp_imgs,
|
||||
amdshark_tmp,
|
||||
)
|
||||
|
||||
|
||||
def imports():
|
||||
import torch # noqa: F401
|
||||
|
||||
startup_timer.record("import torch")
|
||||
warnings.filterwarnings(
|
||||
action="ignore", category=DeprecationWarning, module="torch"
|
||||
)
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
|
||||
|
||||
import gradio # noqa: F401
|
||||
|
||||
startup_timer.record("import gradio")
|
||||
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
global_obj._init()
|
||||
startup_timer.record("initialize globals")
|
||||
|
||||
from apps.amdshark_studio.modules import (
|
||||
img_processing,
|
||||
) # noqa: F401
|
||||
|
||||
startup_timer.record("other imports")
|
||||
|
||||
|
||||
def initialize():
|
||||
configure_sigint_handler()
|
||||
# Set up amdshark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
|
||||
config_tmp()
|
||||
# clear_tmp_mlir()
|
||||
clear_tmp_imgs()
|
||||
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
create_model_folders,
|
||||
)
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
create_model_folders()
|
||||
|
||||
import gradio as gr
|
||||
|
||||
# initialize_rest(reload_script_modules=False)
|
||||
|
||||
|
||||
def initialize_rest(*, reload_script_modules=False):
|
||||
"""
|
||||
Called both from initialize() and when reloading the webui.
|
||||
"""
|
||||
# Keep this for adding reload options to the webUI.
|
||||
|
||||
|
||||
def dumpstacks():
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
id2name = {th.ident: th.name for th in threading.enumerate()}
|
||||
code = []
|
||||
for threadId, stack in sys._current_frames().items():
|
||||
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
|
||||
for filename, lineno, name, line in traceback.extract_stack(stack):
|
||||
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
|
||||
if line:
|
||||
code.append(" " + line.strip())
|
||||
with open(os.path.join(amdshark_tmp, "stack_dump.log"), "w") as f:
|
||||
f.write("\n".join(code))
|
||||
|
||||
|
||||
def setup_middleware(app):
|
||||
from starlette.middleware.gzip import GZipMiddleware
|
||||
|
||||
app.middleware_stack = (
|
||||
None # reset current middleware to allow modifying user provided list
|
||||
)
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
configure_cors_middleware(app)
|
||||
app.build_middleware_stack() # rebuild middleware stack on-the-fly
|
||||
|
||||
|
||||
def configure_cors_middleware(app):
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
cors_options = {
|
||||
"allow_methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
"allow_credentials": True,
|
||||
}
|
||||
if cmd_opts.api_accept_origin:
|
||||
cors_options["allow_origins"] = cmd_opts.api_accept_origin.split(",")
|
||||
|
||||
app.add_middleware(CORSMiddleware, **cors_options)
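A hedged sketch of how setup_middleware would be wired into the web app; the FastAPI app here is illustrative:
from fastapi import FastAPI

app = FastAPI()
setup_middleware(app)  # adds GZip plus the CORS options configured from cmd_opts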
|
||||
|
||||
|
||||
def configure_sigint_handler():
|
||||
# make the program just exit at ctrl+c without waiting for anything
|
||||
def sigint_handler(sig, frame):
|
||||
print(f"Interrupted with signal {sig} in {frame}")
|
||||
|
||||
dumpstacks()
|
||||
|
||||
os._exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
apps/amdshark_studio/api/llm.py (new file, 475 lines)
@@ -0,0 +1,475 @@
|
||||
from turbine_models.custom_models import stateless_llama
|
||||
from turbine_models.model_runner import vmfbRunner
|
||||
from turbine_models.gen_external_params.gen_external_params import gen_external_params
|
||||
import time
|
||||
from amdshark.iree_utils.compile_utils import compile_module_to_flatbuffer
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_resource_path,
|
||||
get_checkpoints_path,
|
||||
)
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.api.utils import parse_device
|
||||
from urllib.request import urlopen
|
||||
import iree.runtime as ireert
|
||||
from itertools import chain
|
||||
import gc
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
llm_model_map = {
|
||||
"meta-llama/Llama-2-7b-chat-hf": {
|
||||
"initializer": stateless_llama.export_transformer_model,
|
||||
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
|
||||
"stop_token": 2,
|
||||
"max_tokens": 4096,
|
||||
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
|
||||
},
|
||||
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
|
||||
"initializer": stateless_llama.export_transformer_model,
|
||||
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
|
||||
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
|
||||
"stop_token": 2,
|
||||
"max_tokens": 4096,
|
||||
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
|
||||
},
|
||||
"TinyPixel/small-llama2": {
|
||||
"initializer": stateless_llama.export_transformer_model,
|
||||
"hf_model_name": "TinyPixel/small-llama2",
|
||||
"compile_flags": ["--iree-opt-const-expr-hoisting=True"],
|
||||
"stop_token": 2,
|
||||
"max_tokens": 1024,
|
||||
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
|
||||
},
|
||||
}
|
||||
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<s>", "</s>"
|
||||
|
||||
DEFAULT_CHAT_SYS_PROMPT = """<s>[INST] <<SYS>>
|
||||
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <</SYS>>\n\n
|
||||
"""
|
||||
|
||||
|
||||
def append_user_prompt(history, input_prompt):
|
||||
user_prompt = f"{B_INST} {input_prompt} {E_INST}"
|
||||
history += user_prompt
|
||||
return history
|
||||
|
||||
|
||||
class LanguageModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_auth_token=None,
|
||||
device=None,
|
||||
quantization="int4",
|
||||
precision="",
|
||||
external_weights=None,
|
||||
use_system_prompt=True,
|
||||
streaming_llm=False,
|
||||
):
|
||||
_, _, self.triple = parse_device(device)
|
||||
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
|
||||
self.device = device.split("=>")[-1].strip()
|
||||
self.backend = self.device.split("://")[0]
|
||||
self.driver = self.backend
|
||||
if "cpu" in device:
|
||||
self.device = "cpu"
|
||||
self.backend = "llvm-cpu"
|
||||
self.driver = "local-task"
|
||||
|
||||
print(f"Selected {self.backend} as IREE target backend.")
|
||||
self.precision = "f32" if "cpu" in device else "f16"
|
||||
self.quantization = quantization
|
||||
self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_")
|
||||
self.external_weight_file = None
|
||||
# TODO: find a programmatic solution for model arch spec instead of hardcoding llama2
|
||||
self.file_spec = "_".join(
|
||||
[
|
||||
self.safe_name,
|
||||
self.precision,
|
||||
]
|
||||
)
|
||||
if self.quantization != "None":
|
||||
self.file_spec += "_" + self.quantization
|
||||
|
||||
if external_weights in ["safetensors", "gguf"]:
|
||||
self.external_weight_file = get_resource_path(
|
||||
os.path.join("..", self.file_spec + "." + external_weights)
|
||||
)
|
||||
else:
|
||||
self.external_weights = None
|
||||
self.external_weight_file = None
|
||||
|
||||
if streaming_llm:
|
||||
# Add streaming suffix to file spec after setting external weights filename.
|
||||
self.file_spec += "_streaming"
|
||||
self.streaming_llm = streaming_llm
|
||||
|
||||
self.tempfile_name = get_resource_path(
|
||||
os.path.join("..", f"{self.file_spec}.tempfile")
|
||||
)
|
||||
# TODO: Tag vmfb with target triple of device instead of HAL backend
|
||||
self.vmfb_name = str(
|
||||
get_resource_path(
|
||||
os.path.join("..", f"{self.file_spec}_{self.backend}.vmfb.tempfile")
|
||||
)
|
||||
)
|
||||
|
||||
self.max_tokens = llm_model_map[model_name]["max_tokens"]
|
||||
self.iree_module_dict = None
|
||||
self.use_system_prompt = use_system_prompt
|
||||
self.global_iter = 0
|
||||
self.prev_token_len = 0
|
||||
self.first_input = True
|
||||
self.hf_auth_token = hf_auth_token
|
||||
if self.external_weight_file is not None:
|
||||
if not os.path.exists(self.external_weight_file):
|
||||
print(
|
||||
f"External weight file {self.external_weight_file} does not exist. Generating..."
|
||||
)
|
||||
gen_external_params(
|
||||
hf_model_name=self.hf_model_name,
|
||||
quantization=self.quantization,
|
||||
weight_path=self.external_weight_file,
|
||||
hf_auth_token=hf_auth_token,
|
||||
precision=self.precision,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
|
||||
)
|
||||
self.external_weight_file = str(self.external_weight_file)
|
||||
|
||||
if os.path.exists(self.vmfb_name) and (
|
||||
external_weights is None or os.path.exists(str(self.external_weight_file))
|
||||
):
|
||||
self.runner = vmfbRunner(
|
||||
device=self.driver,
|
||||
vmfb_path=self.vmfb_name,
|
||||
external_weight_path=self.external_weight_file,
|
||||
)
|
||||
if self.streaming_llm:
|
||||
self.model = self.runner.ctx.modules.streaming_state_update
|
||||
else:
|
||||
self.model = self.runner.ctx.modules.state_update
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_name,
|
||||
use_fast=False,
|
||||
use_auth_token=hf_auth_token,
|
||||
)
|
||||
elif not os.path.exists(self.tempfile_name):
|
||||
self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
|
||||
"initializer"
|
||||
](
|
||||
self.hf_model_name,
|
||||
hf_auth_token,
|
||||
compile_to="torch",
|
||||
external_weights=external_weights,
|
||||
precision=self.precision,
|
||||
quantization=self.quantization,
|
||||
streaming_llm=self.streaming_llm,
|
||||
decomp_attn=True,
|
||||
)
|
||||
with open(self.tempfile_name, "w+") as f:
|
||||
f.write(self.torch_ir)
|
||||
del self.torch_ir
|
||||
gc.collect()
|
||||
self.compile()
|
||||
else:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_name,
|
||||
use_fast=False,
|
||||
use_auth_token=hf_auth_token,
|
||||
)
|
||||
self.compile()
|
||||
# Reserved for running HF torch model as reference.
|
||||
self.hf_mod = None
|
||||
|
||||
def compile(self) -> None:
|
||||
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
|
||||
# ONLY architecture/api-specific compile-time flags for each backend, if needed.
|
||||
# hf_model_id-specific global flags currently in model map.
|
||||
flags = []
|
||||
if "cpu" in self.backend:
|
||||
flags.extend(
|
||||
[
|
||||
"--iree-global-opt-enable-quantized-matmul-reassociation",
|
||||
]
|
||||
)
|
||||
elif self.backend == "vulkan":
|
||||
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
|
||||
elif self.backend == "rocm":
|
||||
flags.extend(
|
||||
[
|
||||
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
|
||||
"--iree-llvmgpu-enable-prefetch=true",
|
||||
"--iree-opt-outer-dim-concat=true",
|
||||
"--iree-flow-enable-aggressive-fusion",
|
||||
]
|
||||
)
|
||||
if "gfx9" in self.triple:
|
||||
flags.extend(
|
||||
[
|
||||
f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
|
||||
"--iree-codegen-llvmgpu-use-vector-distribution=true",
|
||||
]
|
||||
)
|
||||
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
self.tempfile_name,
|
||||
device=self.device,
|
||||
frontend="auto",
|
||||
model_config_path=None,
|
||||
extra_args=flags,
|
||||
write_to=self.vmfb_name,
|
||||
)
|
||||
self.runner = vmfbRunner(
|
||||
device=self.driver,
|
||||
vmfb_path=self.vmfb_name,
|
||||
external_weight_path=self.external_weight_file,
|
||||
)
|
||||
if self.streaming_llm:
|
||||
self.model = self.runner.ctx.modules.streaming_state_update
|
||||
else:
|
||||
self.model = self.runner.ctx.modules.state_update
|
||||
|
||||
def sanitize_prompt(self, prompt):
|
||||
if isinstance(prompt, list):
|
||||
prompt = list(chain.from_iterable(prompt))
|
||||
prompt = " ".join([x for x in prompt if isinstance(x, str)])
|
||||
prompt = prompt.replace("\n", " ")
|
||||
prompt = prompt.replace("\t", " ")
|
||||
prompt = prompt.replace("\r", " ")
|
||||
if self.use_system_prompt and self.global_iter == 0:
|
||||
prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt)
|
||||
return prompt
|
||||
else:
|
||||
return f"{B_INST} {prompt} {E_INST}"
|
||||
|
||||
def chat(self, prompt):
|
||||
prompt = self.sanitize_prompt(prompt)
|
||||
|
||||
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
def format_out(results):
|
||||
return torch.tensor(results.to_host()[0][0])
|
||||
|
||||
history = []
|
||||
for iter in range(self.max_tokens):
|
||||
if self.streaming_llm:
|
||||
token_slice = max(self.prev_token_len - 1, 0)
|
||||
input_tensor = input_tensor[:, token_slice:]
|
||||
if self.streaming_llm and self.model["get_seq_step"]() > 600:
|
||||
print("Evicting cache space!")
|
||||
self.model["evict_kvcache_space"]()
|
||||
token_len = input_tensor.shape[-1]
|
||||
device_inputs = [
|
||||
ireert.asdevicearray(self.runner.config.device, input_tensor)
|
||||
]
|
||||
if self.first_input or not self.streaming_llm:
|
||||
st_time = time.time()
|
||||
token = self.model["run_initialize"](*device_inputs)
|
||||
total_time = time.time() - st_time
|
||||
token_len += 1
|
||||
self.first_input = False
|
||||
else:
|
||||
st_time = time.time()
|
||||
token = self.model["run_cached_initialize"](*device_inputs)
|
||||
total_time = time.time() - st_time
|
||||
token_len += 1
|
||||
|
||||
history.append(format_out(token))
|
||||
while (
|
||||
format_out(token) != llm_model_map[self.hf_model_name]["stop_token"]
|
||||
and len(history) < self.max_tokens
|
||||
):
|
||||
dec_time = time.time()
|
||||
if self.streaming_llm and self.model["get_seq_step"]() > 600:
|
||||
print("Evicting cache space!")
|
||||
self.model["evict_kvcache_space"]()
|
||||
token = self.model["run_forward"](token)
|
||||
history.append(format_out(token))
|
||||
total_time = time.time() - dec_time
|
||||
yield self.tokenizer.decode(history), total_time
|
||||
|
||||
self.prev_token_len = token_len + len(history)
|
||||
|
||||
if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]:
|
||||
break
|
||||
|
||||
for i in range(len(history)):
|
||||
if type(history[i]) != int:
|
||||
history[i] = int(history[i])
|
||||
result_output = self.tokenizer.decode(history)
|
||||
self.global_iter += 1
|
||||
return result_output, total_time
|
||||
|
||||
# Reference HF model function for sanity checks.
|
||||
def chat_hf(self, prompt):
|
||||
if self.hf_mod is None:
|
||||
self.hf_mod = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_name,
|
||||
torch_dtype=torch.float,
|
||||
token=self.hf_auth_token,
|
||||
)
|
||||
prompt = self.sanitize_prompt(prompt)
|
||||
|
||||
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
history = []
|
||||
for iter in range(self.max_tokens):
|
||||
token_len = input_tensor.shape[-1]
|
||||
if self.first_input:
|
||||
st_time = time.time()
|
||||
result = self.hf_mod(input_tensor)
|
||||
token = torch.argmax(result.logits[:, -1, :], dim=1)
|
||||
total_time = time.time() - st_time
|
||||
token_len += 1
|
||||
pkv = result.past_key_values
|
||||
self.first_input = False
|
||||
|
||||
history.append(int(token))
|
||||
while token != llm_model_map[self.hf_model_name]["stop_token"]:
|
||||
dec_time = time.time()
|
||||
result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
|
||||
history.append(int(token))
|
||||
total_time = time.time() - dec_time
|
||||
token = torch.argmax(result.logits[:, -1, :], dim=1)
|
||||
pkv = result.past_key_values
|
||||
yield self.tokenizer.decode(history), total_time
|
||||
|
||||
self.prev_token_len = token_len + len(history)
|
||||
|
||||
if token == llm_model_map[self.hf_model_name]["stop_token"]:
|
||||
break
|
||||
for i in range(len(history)):
|
||||
if type(history[i]) != int:
|
||||
history[i] = int(history[i])
|
||||
result_output = self.tokenizer.decode(history)
|
||||
self.global_iter += 1
|
||||
return result_output, total_time
|
||||
|
||||
|
||||
def get_mfma_spec_path(target_chip, save_dir):
|
||||
url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
|
||||
attn_spec = urlopen(url).read().decode("utf-8")
|
||||
spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
|
||||
if os.path.exists(spec_path):
|
||||
return spec_path
|
||||
with open(spec_path, "w") as f:
|
||||
f.write(attn_spec)
|
||||
return spec_path
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
|
||||
from datetime import datetime as dt
|
||||
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
print(f"Input keys : {InputData.keys()}")
|
||||
|
||||
# print(f"model : {InputData['model']}")
|
||||
|
||||
is_chat_completion_api = (
|
||||
"messages" in InputData.keys()
|
||||
) # else it is the legacy `completion` api
|
||||
|
||||
# For Debugging input data from API
|
||||
if is_chat_completion_api:
|
||||
print(f"message -> role : {InputData['messages'][0]['role']}")
|
||||
print(f"message -> content : {InputData['messages'][0]['content']}")
|
||||
else:
|
||||
print(f"prompt : {InputData['prompt']}")
|
||||
|
||||
model_name = (
|
||||
InputData["model"]
|
||||
if "model" in InputData.keys()
|
||||
else "meta-llama/Llama-2-7b-chat-hf"
|
||||
)
|
||||
model_path = llm_model_map[model_name]
|
||||
device = InputData["device"] if "device" in InputData.keys() else "cpu"
|
||||
precision = "fp16"
|
||||
max_tokens = InputData["max_tokens"] if "max_tokens" in InputData.keys() else 4096
|
||||
|
||||
device_id = None
|
||||
if not global_obj.get_llm_obj():
|
||||
print("\n[LOG] Initializing new pipeline...")
|
||||
global_obj.clear_cache()
|
||||
gc.collect()
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
elif "cpu" in device:
|
||||
device = "cpu"
|
||||
precision = "fp32"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
llm_model = LanguageModel(
|
||||
model_name=model_name,
|
||||
hf_auth_token=cmd_opts.hf_auth_token,
|
||||
device=device,
|
||||
quantization=cmd_opts.quantization,
|
||||
external_weights="safetensors",
|
||||
use_system_prompt=True,
|
||||
streaming_llm=False,
|
||||
)
|
||||
global_obj.set_llm_obj(llm_model)
|
||||
else:
|
||||
llm_model = global_obj.get_llm_obj()
|
||||
|
||||
llm_model.max_tokens = max_tokens
|
||||
# TODO: add role dict for different models
|
||||
if is_chat_completion_api:
|
||||
# TODO: add funtionality for multiple messages
|
||||
prompt = append_user_prompt(
|
||||
InputData["messages"][0]["role"], InputData["messages"][0]["content"]
|
||||
)
|
||||
else:
|
||||
prompt = InputData["prompt"]
|
||||
print("prompt = ", prompt)
|
||||
|
||||
for res_op, _ in llm_model.chat(prompt):
|
||||
if is_chat_completion_api:
|
||||
choices = [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": res_op, # since we are yeilding the result
|
||||
},
|
||||
"finish_reason": "stop", # or length
|
||||
}
|
||||
]
|
||||
else:
|
||||
choices = [
|
||||
{
|
||||
"text": res_op,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop", # or length
|
||||
}
|
||||
]
|
||||
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
|
||||
return {
|
||||
"id": end_time,
|
||||
"object": "chat.completion" if is_chat_completion_api else "text_completion",
|
||||
"created": int(end_time),
|
||||
"choices": choices,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
lm = LanguageModel(
|
||||
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
|
||||
hf_auth_token=None,
|
||||
device="cpu-task",
|
||||
external_weights="safetensors",
|
||||
)
|
||||
|
||||
print("model loaded")
|
||||
for i in lm.chat("hi, what are you?"):
|
||||
print(i)
|
||||
505
apps/amdshark_studio/api/sd.py
Normal file
505
apps/amdshark_studio/api/sd.py
Normal file
@@ -0,0 +1,505 @@
|
||||
import gc
|
||||
import torch
|
||||
import gradio as gr
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
import copy
|
||||
import importlib.util
|
||||
import sys
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from turbine_models.custom_models.sd_inference.sd_pipeline import AMDSharkSDPipeline
|
||||
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
|
||||
AMDSharkSDXLPipeline,
|
||||
)
|
||||
|
||||
|
||||
from apps.amdshark_studio.api.controlnet import control_adapter_map
|
||||
from apps.amdshark_studio.api.utils import parse_device
|
||||
from apps.amdshark_studio.web.utils.state import status_label
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
safe_name,
|
||||
get_resource_path,
|
||||
get_checkpoints_path,
|
||||
)
|
||||
|
||||
from apps.amdshark_studio.modules.img_processing import (
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
from apps.amdshark_studio.modules.ckpt_processing import (
|
||||
preprocessCKPT,
|
||||
save_irpa,
|
||||
)
|
||||
|
||||
EMPTY_SD_MAP = {
|
||||
"clip": None,
|
||||
"scheduler": None,
|
||||
"unet": None,
|
||||
"vae_decode": None,
|
||||
}
|
||||
|
||||
EMPTY_SDXL_MAP = {
|
||||
"prompt_encoder": None,
|
||||
"scheduled_unet": None,
|
||||
"vae_decode": None,
|
||||
"pipeline": None,
|
||||
"full_pipeline": None,
|
||||
}
|
||||
|
||||
EMPTY_FLAGS = {
|
||||
"clip": None,
|
||||
"unet": None,
|
||||
"vae": None,
|
||||
"pipeline": None,
|
||||
}
|
||||
|
||||
|
||||
def load_script(source, module_name):
|
||||
"""
|
||||
reads file source and loads it as a module
|
||||
|
||||
:param source: file to load
|
||||
:param module_name: name of module to register in sys.modules
|
||||
:return: loaded module
|
||||
"""
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, source)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
class StableDiffusion:
|
||||
# This class is responsible for executing image generation and creating
|
||||
# /managing a set of compiled modules to run Stable Diffusion. The init
|
||||
# aims to be as general as possible, and the class will infer and compile
|
||||
# a list of necessary modules or a combined "pipeline module" for a
|
||||
# specified job based on the inference task.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_model_id,
|
||||
height: int,
|
||||
width: int,
|
||||
batch_size: int,
|
||||
steps: int,
|
||||
scheduler: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
target_triple: str = None,
|
||||
custom_vae: str = None,
|
||||
num_loras: int = 0,
|
||||
import_ir: bool = True,
|
||||
is_controlled: bool = False,
|
||||
external_weights: str = "safetensors",
|
||||
):
|
||||
self.precision = precision
|
||||
self.compiled_pipeline = False
|
||||
self.base_model_id = base_model_id
|
||||
self.custom_vae = custom_vae
|
||||
self.is_sdxl = "xl" in self.base_model_id.lower()
|
||||
self.is_custom = ".py" in self.base_model_id.lower()
|
||||
if self.is_custom:
|
||||
custom_module = load_script(
|
||||
os.path.join(get_checkpoints_path("scripts"), self.base_model_id),
|
||||
"custom_pipeline",
|
||||
)
|
||||
self.turbine_pipe = custom_module.StudioPipeline
|
||||
self.model_map = custom_module.MODEL_MAP
|
||||
elif self.is_sdxl:
|
||||
self.turbine_pipe = AMDSharkSDXLPipeline
|
||||
self.model_map = EMPTY_SDXL_MAP
|
||||
else:
|
||||
self.turbine_pipe = AMDSharkSDPipeline
|
||||
self.model_map = EMPTY_SD_MAP
|
||||
max_length = 64
|
||||
target_backend, self.rt_device, triple = parse_device(device, target_triple)
|
||||
pipe_id_list = [
|
||||
safe_name(base_model_id),
|
||||
str(batch_size),
|
||||
str(max_length),
|
||||
f"{str(height)}x{str(width)}",
|
||||
precision,
|
||||
triple,
|
||||
]
|
||||
if num_loras > 0:
|
||||
pipe_id_list.append(str(num_loras) + "lora")
|
||||
if is_controlled:
|
||||
pipe_id_list.append("controlled")
|
||||
if custom_vae:
|
||||
pipe_id_list.append(custom_vae)
|
||||
self.pipe_id = "_".join(pipe_id_list)
|
||||
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
|
||||
self.weights_path = Path(
|
||||
os.path.join(
|
||||
get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)
|
||||
)
|
||||
)
|
||||
if not os.path.exists(self.weights_path):
|
||||
os.mkdir(self.weights_path)
|
||||
|
||||
decomp_attn = True
|
||||
attn_spec = None
|
||||
if triple in ["gfx940", "gfx942", "gfx90a"]:
|
||||
decomp_attn = False
|
||||
attn_spec = "mfma"
|
||||
elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
|
||||
decomp_attn = False
|
||||
attn_spec = "wmma"
|
||||
if triple in ["gfx1103", "gfx1150"]:
|
||||
# external weights have issues on igpu
|
||||
external_weights = None
|
||||
elif target_backend == "llvm-cpu":
|
||||
decomp_attn = False
|
||||
|
||||
self.sd_pipe = self.turbine_pipe(
|
||||
hf_model_name=base_model_id,
|
||||
scheduler_id=scheduler,
|
||||
height=height,
|
||||
width=width,
|
||||
precision=precision,
|
||||
max_length=max_length,
|
||||
batch_size=batch_size,
|
||||
num_inference_steps=steps,
|
||||
device=target_backend,
|
||||
iree_target_triple=triple,
|
||||
ireec_flags=EMPTY_FLAGS,
|
||||
attn_spec=attn_spec,
|
||||
decomp_attn=decomp_attn,
|
||||
pipeline_dir=self.pipeline_dir,
|
||||
external_weights_dir=self.weights_path,
|
||||
external_weights=external_weights,
|
||||
custom_vae=custom_vae,
|
||||
)
|
||||
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
|
||||
gc.collect()
|
||||
|
||||
def prepare_pipe(
|
||||
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
|
||||
):
|
||||
print(f"\n[LOG] Preparing pipeline...")
|
||||
self.is_img2img = False
|
||||
mlirs = copy.deepcopy(self.model_map)
|
||||
vmfbs = copy.deepcopy(self.model_map)
|
||||
weights = copy.deepcopy(self.model_map)
|
||||
if not self.is_sdxl:
|
||||
compiled_pipeline = False
|
||||
self.compiled_pipeline = compiled_pipeline
|
||||
|
||||
if custom_weights:
|
||||
custom_weights = os.path.join(
|
||||
get_checkpoints_path("checkpoints"),
|
||||
safe_name(self.base_model_id.split("/")[-1]),
|
||||
custom_weights,
|
||||
)
|
||||
diffusers_weights_path = preprocessCKPT(custom_weights, self.precision)
|
||||
for key in weights:
|
||||
if key in ["scheduled_unet", "unet"]:
|
||||
unet_weights_path = os.path.join(
|
||||
diffusers_weights_path,
|
||||
"unet",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
)
|
||||
weights[key] = save_irpa(unet_weights_path, "unet.")
|
||||
|
||||
elif key in ["clip", "prompt_encoder"]:
|
||||
if not self.is_sdxl:
|
||||
sd1_path = os.path.join(
|
||||
diffusers_weights_path, "text_encoder", "model.safetensors"
|
||||
)
|
||||
weights[key] = save_irpa(sd1_path, "text_encoder_model.")
|
||||
else:
|
||||
clip_1_path = os.path.join(
|
||||
diffusers_weights_path, "text_encoder", "model.safetensors"
|
||||
)
|
||||
clip_2_path = os.path.join(
|
||||
diffusers_weights_path,
|
||||
"text_encoder_2",
|
||||
"model.safetensors",
|
||||
)
|
||||
weights[key] = [
|
||||
save_irpa(clip_1_path, "text_encoder_model_1."),
|
||||
save_irpa(clip_2_path, "text_encoder_model_2."),
|
||||
]
|
||||
|
||||
elif key in ["vae_decode"] and weights[key] is None:
|
||||
vae_weights_path = os.path.join(
|
||||
diffusers_weights_path,
|
||||
"vae",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
)
|
||||
weights[key] = save_irpa(vae_weights_path, "vae.")
|
||||
|
||||
vmfbs, weights = self.sd_pipe.check_prepared(
|
||||
mlirs, vmfbs, weights, interactive=False
|
||||
)
|
||||
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
|
||||
self.sd_pipe.load_pipeline(
|
||||
vmfbs, weights, self.rt_device, self.compiled_pipeline
|
||||
)
|
||||
print(
|
||||
"\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
|
||||
)
|
||||
return
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
ondemand,
|
||||
resample_type,
|
||||
control_mode,
|
||||
hints,
|
||||
):
|
||||
img = self.sd_pipe.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
1,
|
||||
guidance_scale,
|
||||
seed,
|
||||
return_imgs=True,
|
||||
)
|
||||
return img
|
||||
|
||||
|
||||
def amdshark_sd_fn_dict_input(
|
||||
sd_kwargs: dict,
|
||||
):
|
||||
print("\n[LOG] Submitting Request...")
|
||||
|
||||
for key in sd_kwargs:
|
||||
if sd_kwargs[key] in [None, []]:
|
||||
sd_kwargs[key] = None
|
||||
if sd_kwargs[key] in ["None"]:
|
||||
sd_kwargs[key] = ""
|
||||
if key == "seed":
|
||||
sd_kwargs[key] = int(sd_kwargs[key])
|
||||
|
||||
# TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
|
||||
if not sd_kwargs["device"]:
|
||||
gr.Warning("No device specified. Please specify a device.")
|
||||
return None, ""
|
||||
if sd_kwargs["height"] not in [512, 1024]:
|
||||
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
|
||||
return None, ""
|
||||
if sd_kwargs["height"] != sd_kwargs["width"]:
|
||||
gr.Warning("Height and width must be the same. This is a temporary limitation.")
|
||||
return None, ""
|
||||
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
|
||||
if sd_kwargs["steps"] > 10:
|
||||
gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
|
||||
return None, ""
|
||||
if sd_kwargs["guidance_scale"] > 3:
|
||||
gr.Warning(
|
||||
"sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
|
||||
)
|
||||
return None, ""
|
||||
if sd_kwargs["target_triple"] == "":
|
||||
if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "":
|
||||
gr.Warning(
|
||||
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
|
||||
)
|
||||
return None, ""
|
||||
|
||||
generated_imgs = yield from amdshark_sd_fn(**sd_kwargs)
|
||||
return generated_imgs
|
||||
|
||||
|
||||
def amdshark_sd_fn(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
sd_init_image: list,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
seed: list,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
base_model_id: str,
|
||||
custom_weights: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
target_triple: str,
|
||||
ondemand: bool,
|
||||
compiled_pipeline: bool,
|
||||
resample_type: str,
|
||||
controlnets: dict,
|
||||
embeddings: dict,
|
||||
):
|
||||
sd_kwargs = locals()
|
||||
if not isinstance(sd_init_image, list):
|
||||
sd_init_image = [sd_init_image]
|
||||
is_img2img = True if sd_init_image[0] is not None else False
|
||||
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
adapters = {}
|
||||
is_controlled = False
|
||||
control_mode = None
|
||||
hints = []
|
||||
num_loras = 0
|
||||
import_ir = True
|
||||
for i in embeddings:
|
||||
num_loras += 1 if embeddings[i] else 0
|
||||
if "model" in controlnets:
|
||||
for i, model in enumerate(controlnets["model"]):
|
||||
if "xl" not in base_model_id.lower():
|
||||
adapters[f"control_adapter_{model}"] = {
|
||||
"hf_id": control_adapter_map["runwayml/stable-diffusion-v1-5"][
|
||||
model
|
||||
],
|
||||
"strength": controlnets["strength"][i],
|
||||
}
|
||||
else:
|
||||
adapters[f"control_adapter_{model}"] = {
|
||||
"hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][
|
||||
model
|
||||
],
|
||||
"strength": controlnets["strength"][i],
|
||||
}
|
||||
if model is not None:
|
||||
is_controlled = True
|
||||
control_mode = controlnets["control_mode"]
|
||||
for i in controlnets["hint"]:
|
||||
hints.append[i]
|
||||
|
||||
submit_pipe_kwargs = {
|
||||
"base_model_id": base_model_id,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"batch_size": batch_size,
|
||||
"precision": precision,
|
||||
"device": device,
|
||||
"target_triple": target_triple,
|
||||
"custom_vae": custom_vae,
|
||||
"num_loras": num_loras,
|
||||
"import_ir": import_ir,
|
||||
"is_controlled": is_controlled,
|
||||
"steps": steps,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
submit_prep_kwargs = {
|
||||
"custom_weights": custom_weights,
|
||||
"adapters": adapters,
|
||||
"embeddings": embeddings,
|
||||
"is_img2img": is_img2img,
|
||||
"compiled_pipeline": compiled_pipeline,
|
||||
}
|
||||
submit_run_kwargs = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"image": sd_init_image,
|
||||
"strength": strength,
|
||||
"guidance_scale": guidance_scale,
|
||||
"seed": seed,
|
||||
"ondemand": ondemand,
|
||||
"resample_type": resample_type,
|
||||
"control_mode": control_mode,
|
||||
"hints": hints,
|
||||
}
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
|
||||
):
|
||||
print("\n[LOG] Initializing new pipeline...")
|
||||
global_obj.clear_cache()
|
||||
gc.collect()
|
||||
|
||||
# Initializes the pipeline and retrieves IR based on all
|
||||
# parameters that are static in the turbine output format,
|
||||
# which is currently MLIR in the torch dialect.
|
||||
|
||||
sd_pipe = StableDiffusion(
|
||||
**submit_pipe_kwargs,
|
||||
)
|
||||
global_obj.set_sd_obj(sd_pipe)
|
||||
global_obj.set_pipe_kwargs(submit_pipe_kwargs)
|
||||
if (
|
||||
not global_obj.get_prep_kwargs()
|
||||
or global_obj.get_prep_kwargs() != submit_prep_kwargs
|
||||
):
|
||||
global_obj.set_prep_kwargs(submit_prep_kwargs)
|
||||
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
|
||||
|
||||
generated_imgs = []
|
||||
for current_batch in range(batch_count):
|
||||
start_time = time.time()
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
|
||||
if not isinstance(out_imgs, list):
|
||||
out_imgs = [out_imgs]
|
||||
# total_time = time.time() - start_time
|
||||
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
|
||||
# print(f"\n[LOG] {text_output}")
|
||||
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
# break
|
||||
# else:
|
||||
for batch in range(batch_size):
|
||||
save_output_img(
|
||||
out_imgs[batch],
|
||||
seed,
|
||||
sd_kwargs,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
# TODO: make seed changes over batch counts more configurable.
|
||||
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
|
||||
yield generated_imgs, status_label(
|
||||
"Stable Diffusion", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
return (generated_imgs, "")
|
||||
|
||||
|
||||
def unload_sd():
|
||||
print("Unloading models.")
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
global_obj.clear_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
print("Inject call to cancel longer API calls.")
|
||||
return
|
||||
|
||||
|
||||
def view_json_file(file_path):
|
||||
content = ""
|
||||
with open(file_path, "r") as fopen:
|
||||
content = fopen.read()
|
||||
return content
|
||||
|
||||
|
||||
def safe_name(name):
|
||||
return name.replace("/", "_").replace("\\", "_").replace(".", "_")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
global_obj._init()
|
||||
|
||||
sd_json = view_json_file(
|
||||
get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
|
||||
)
|
||||
sd_kwargs = json.loads(sd_json)
|
||||
for arg in vars(cmd_opts):
|
||||
if arg in sd_kwargs:
|
||||
sd_kwargs[arg] = getattr(cmd_opts, arg)
|
||||
for i in amdshark_sd_fn_dict_input(sd_kwargs):
|
||||
print(i)
|
||||
389
apps/amdshark_studio/api/utils.py
Normal file
389
apps/amdshark_studio/api/utils.py
Normal file
@@ -0,0 +1,389 @@
|
||||
import numpy as np
|
||||
import json
|
||||
from random import (
|
||||
randint,
|
||||
seed as seed_random,
|
||||
getstate as random_getstate,
|
||||
setstate as random_setstate,
|
||||
)
|
||||
|
||||
from pathlib import Path
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
# TODO: migrate these utils to studio
|
||||
from amdshark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
|
||||
|
||||
def get_available_devices():
|
||||
def get_devices_by_name(driver_name):
|
||||
from amdshark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
driver_name = iree_device_map(driver_name)
|
||||
device_list_dict = get_all_devices(driver_name)
|
||||
print(f"{driver_name} devices are available.")
|
||||
except:
|
||||
print(f"{driver_name} devices are not available.")
|
||||
else:
|
||||
cpu_name = get_cpu_info()["brand_raw"]
|
||||
for i, device in enumerate(device_list_dict):
|
||||
device_name = (
|
||||
cpu_name if device["name"] == "default" else device["name"]
|
||||
)
|
||||
if "local" in driver_name:
|
||||
device_list.append(
|
||||
f"{device_name} => {driver_name.replace('local', 'cpu')}"
|
||||
)
|
||||
else:
|
||||
# for drivers with single devices
|
||||
# let the default device be selected without any indexing
|
||||
if len(device_list_dict) == 1:
|
||||
device_list.append(f"{device_name} => {driver_name}")
|
||||
else:
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
|
||||
from amdshark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkan_devices = []
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
hip_devices = get_devices_by_name("hip")
|
||||
available_devices.extend(hip_devices)
|
||||
|
||||
for idx, device_str in enumerate(available_devices):
|
||||
if "AMD Radeon(TM) Graphics =>" in device_str:
|
||||
igpu_id_candidates = [
|
||||
x.split("w/")[-1].split("=>")[0]
|
||||
for x in available_devices
|
||||
if "M Graphics" in x
|
||||
]
|
||||
for igpu_name in igpu_id_candidates:
|
||||
if igpu_name:
|
||||
available_devices[idx] = device_str.replace(
|
||||
"AMD Radeon(TM) Graphics", igpu_name
|
||||
)
|
||||
break
|
||||
return available_devices
|
||||
|
||||
|
||||
def set_init_device_flags():
|
||||
if "vulkan" in cmd_opts.device:
|
||||
# set runtime flags for vulkan.
|
||||
set_iree_runtime_flags()
|
||||
|
||||
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
|
||||
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
|
||||
if not cmd_opts.iree_vulkan_target_triple:
|
||||
triple = get_vulkan_target_triple(device_name)
|
||||
if triple is not None:
|
||||
cmd_opts.iree_vulkan_target_triple = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{cmd_opts.iree_vulkan_target_triple}."
|
||||
)
|
||||
elif "cuda" in cmd_opts.device:
|
||||
cmd_opts.device = "cuda"
|
||||
elif "metal" in cmd_opts.device:
|
||||
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
|
||||
if not cmd_opts.iree_metal_target_platform:
|
||||
from amdshark.iree_utils.metal_utils import get_metal_target_triple
|
||||
|
||||
triple = get_metal_target_triple(device_name)
|
||||
if triple is not None:
|
||||
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{cmd_opts.iree_metal_target_platform}."
|
||||
)
|
||||
elif "cpu" in cmd_opts.device:
|
||||
cmd_opts.device = "cpu"
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if cmd_opts.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if cmd_opts.device_allocator_heap_key:
|
||||
vulkan_runtime_flags += [
|
||||
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
def parse_device(device_str, target_override=""):
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
clean_device_info,
|
||||
get_iree_target_triple,
|
||||
iree_target_map,
|
||||
)
|
||||
|
||||
rt_driver, device_id = clean_device_info(device_str)
|
||||
target_backend = iree_target_map(rt_driver)
|
||||
if device_id:
|
||||
rt_device = f"{rt_driver}://{device_id}"
|
||||
else:
|
||||
rt_device = rt_driver
|
||||
|
||||
if target_override:
|
||||
return target_backend, rt_device, target_override
|
||||
match target_backend:
|
||||
case "vulkan-spirv":
|
||||
triple = get_iree_target_triple(device_str)
|
||||
return target_backend, rt_device, triple
|
||||
case "rocm":
|
||||
triple = get_rocm_target_chip(device_str)
|
||||
return target_backend, rt_device, triple
|
||||
case "llvm-cpu":
|
||||
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
|
||||
|
||||
|
||||
def get_rocm_target_chip(device_str):
|
||||
# TODO: Use a data file to map device_str to target chip.
|
||||
rocm_chip_map = {
|
||||
"6700": "gfx1031",
|
||||
"6800": "gfx1030",
|
||||
"6900": "gfx1030",
|
||||
"7900": "gfx1100",
|
||||
"MI300X": "gfx942",
|
||||
"MI300A": "gfx940",
|
||||
"MI210": "gfx90a",
|
||||
"MI250": "gfx90a",
|
||||
"MI100": "gfx908",
|
||||
"MI50": "gfx906",
|
||||
"MI60": "gfx906",
|
||||
"780M": "gfx1103",
|
||||
}
|
||||
for key in rocm_chip_map:
|
||||
if key in device_str:
|
||||
return rocm_chip_map[key]
|
||||
raise AssertionError(
|
||||
f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/AMD-SHARK-Studio/issues."
|
||||
)
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
"""
|
||||
Inputs: driver_name
|
||||
Returns a list of all the available devices for a given driver sorted by
|
||||
the iree path names of the device as in --list_devices option in iree.
|
||||
"""
|
||||
from iree.runtime import get_driver
|
||||
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return device_list_src
|
||||
|
||||
|
||||
def get_device_mapping(driver, key_combination=3):
|
||||
"""This method ensures consistent device ordering when choosing
|
||||
specific devices for execution
|
||||
Args:
|
||||
driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
key_combination (int, optional): choice for mapping value for
|
||||
device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Returns:
|
||||
dict: map to possible device names user can input mapped to desired
|
||||
combination of name/path.
|
||||
"""
|
||||
from amdshark.iree_utils._common import iree_device_map
|
||||
|
||||
driver = iree_device_map(driver)
|
||||
device_list = get_all_devices(driver)
|
||||
device_map = dict()
|
||||
|
||||
def get_output_value(dev_dict):
|
||||
if key_combination == 1:
|
||||
return f"{driver}://{dev_dict['path']}"
|
||||
if key_combination == 2:
|
||||
return dev_dict["name"]
|
||||
if key_combination == 3:
|
||||
return dev_dict["name"], f"{driver}://{dev_dict['path']}"
|
||||
|
||||
# mapping driver name to default device (driver://0)
|
||||
device_map[f"{driver}"] = get_output_value(device_list[0])
|
||||
for i, device in enumerate(device_list):
|
||||
# mapping with index
|
||||
device_map[f"{driver}://{i}"] = get_output_value(device)
|
||||
# mapping with full path
|
||||
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
|
||||
return device_map
|
||||
|
||||
|
||||
def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags = []
|
||||
if len(cmd_opts.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
|
||||
)
|
||||
if "rocm" in cmd_opts.device:
|
||||
from amdshark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
if cmd_opts.iree_constant_folding == False:
|
||||
iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
iree_flags.append(
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
)
|
||||
if cmd_opts.data_tiling == False:
|
||||
iree_flags.append("--iree-opt-data-tiling=False")
|
||||
|
||||
if "vae" not in model:
|
||||
# Due to lack of support for multi-reduce, we always collapse reduction
|
||||
# dims before dispatch formation right now.
|
||||
iree_flags += ["--iree-flow-collapse-reduction-dims"]
|
||||
return iree_flags
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user
|
||||
selected execution device
|
||||
Args:
|
||||
device (str): user
|
||||
key_combination (int, optional): choice for mapping value for
|
||||
device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Raises:
|
||||
ValueError:
|
||||
Returns:
|
||||
str / tuple: returns the mapping str or tuple of mapping str for
|
||||
the device depending on key_combination value
|
||||
"""
|
||||
driver = device.split("://")[0]
|
||||
device_map = get_device_mapping(driver, key_combination)
|
||||
try:
|
||||
device_mapping = device_map[device]
|
||||
except KeyError:
|
||||
raise ValueError(f"Device '{device}' is not a valid device.")
|
||||
return device_mapping
|
||||
|
||||
def get_devices_by_name(driver_name):
|
||||
from amdshark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
driver_name = iree_device_map(driver_name)
|
||||
device_list_dict = get_all_devices(driver_name)
|
||||
print(f"{driver_name} devices are available.")
|
||||
except:
|
||||
print(f"{driver_name} devices are not available.")
|
||||
else:
|
||||
cpu_name = get_cpu_info()["brand_raw"]
|
||||
for i, device in enumerate(device_list_dict):
|
||||
device_name = (
|
||||
cpu_name if device["name"] == "default" else device["name"]
|
||||
)
|
||||
if "local" in driver_name:
|
||||
device_list.append(
|
||||
f"{device_name} => {driver_name.replace('local', 'cpu')}"
|
||||
)
|
||||
else:
|
||||
# for drivers with single devices
|
||||
# let the default device be selected without any indexing
|
||||
if len(device_list_dict) == 1:
|
||||
device_list.append(f"{device_name} => {driver_name}")
|
||||
else:
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
from amdshark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkan_devices = []
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
return available_devices
|
||||
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the
|
||||
# supported range (including -1)
|
||||
def sanitize_seed(seed: int | str):
|
||||
seed = int(seed)
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
return seed
|
||||
|
||||
|
||||
# take a seed expression in an input format and convert it to
|
||||
# a list of integers, where possible
|
||||
def parse_seed_input(seed_input: str | list | int):
|
||||
if isinstance(seed_input, str):
|
||||
try:
|
||||
seed_input = json.loads(seed_input)
|
||||
except (ValueError, TypeError):
|
||||
seed_input = None
|
||||
|
||||
if isinstance(seed_input, int):
|
||||
return [seed_input]
|
||||
|
||||
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
|
||||
return seed_input
|
||||
|
||||
raise TypeError(
|
||||
"Seed input must be an integer or an array of integers in JSON format"
|
||||
)
|
||||
145
apps/amdshark_studio/modules/ckpt_processing.py
Normal file
145
apps/amdshark_studio/modules/ckpt_processing.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import requests
|
||||
import torch
|
||||
import safetensors
|
||||
from iree.turbine.aot.params import (
|
||||
ParameterArchiveBuilder,
|
||||
)
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
create_vae_diffusers_config,
|
||||
convert_ldm_vae_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"):
|
||||
path = Path(custom_weights)
|
||||
diffusers_path = path.parent.absolute()
|
||||
diffusers_directory_name = os.path.join("diffusers", path.stem + f"_{precision}")
|
||||
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
|
||||
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
|
||||
path_to_diffusers = complete_path_to_diffusers.as_posix()
|
||||
return path_to_diffusers
|
||||
|
||||
|
||||
def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False):
|
||||
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision)
|
||||
if next(Path(path_to_diffusers).iterdir(), None):
|
||||
print("Checkpoint already loaded at : ", path_to_diffusers)
|
||||
return path_to_diffusers
|
||||
else:
|
||||
print(
|
||||
"Diffusers' checkpoint will be identified here : ",
|
||||
path_to_diffusers,
|
||||
)
|
||||
from_safetensors = (
|
||||
True if custom_weights.lower().endswith(".safetensors") else False
|
||||
)
|
||||
# EMA weights usually yield higher quality images for inference but
|
||||
# non-EMA weights have been yielding better results in our case.
|
||||
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
|
||||
# they want to go for EMA weight extraction or not.
|
||||
extract_ema = False
|
||||
print("Loading diffusers' pipeline from original stable diffusion checkpoint")
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path_or_dict=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
num_in_channels=num_in_channels,
|
||||
)
|
||||
if precision == "fp16":
|
||||
pipe.to(dtype=torch.float16)
|
||||
pipe.save_pretrained(path_to_diffusers)
|
||||
del pipe
|
||||
print("Loading complete")
|
||||
return path_to_diffusers
|
||||
|
||||
|
||||
def save_irpa(weights_path, prepend_str):
|
||||
weights = safetensors.torch.load_file(weights_path)
|
||||
archive = ParameterArchiveBuilder()
|
||||
for key in weights.keys():
|
||||
new_key = prepend_str + key
|
||||
archive.add_tensor(new_key, weights[key])
|
||||
|
||||
irpa_file = weights_path.replace(".safetensors", ".irpa")
|
||||
archive.save(irpa_file)
|
||||
return irpa_file
|
||||
|
||||
|
||||
def convert_original_vae(vae_checkpoint):
|
||||
vae_state_dict = {}
|
||||
for key in list(vae_checkpoint.keys()):
|
||||
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
|
||||
|
||||
config_url = (
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
|
||||
"main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=512)
|
||||
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
|
||||
return converted_vae_checkpoint
|
||||
|
||||
|
||||
def process_custom_pipe_weights(custom_weights):
|
||||
if custom_weights != "":
|
||||
if custom_weights.startswith("https://civitai.com/api/"):
|
||||
# download the checkpoint from civitai if we don't already have it
|
||||
weights_path = get_civitai_checkpoint(custom_weights)
|
||||
|
||||
# act as if we were given the local file as custom_weights originally
|
||||
custom_weights_tgt = get_path_to_diffusers_checkpoint(weights_path)
|
||||
custom_weights_params = weights_path
|
||||
|
||||
else:
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
custom_weights_params = custom_weights
|
||||
|
||||
return custom_weights_params, custom_weights_tgt
|
||||
|
||||
|
||||
def get_civitai_checkpoint(url: str):
|
||||
with requests.get(url, allow_redirects=True, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
# civitai api returns the filename in the content disposition
|
||||
base_filename = re.findall(
|
||||
'"([^"]*)"', response.headers["Content-Disposition"]
|
||||
)[0]
|
||||
destination_path = Path.cwd() / (cmd_opts.model_dir or "models") / base_filename
|
||||
|
||||
# we don't have this model downloaded yet
|
||||
if not destination_path.is_file():
|
||||
print(f"downloading civitai model from {url} to {destination_path}")
|
||||
|
||||
size = int(response.headers["content-length"], 0)
|
||||
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
|
||||
|
||||
with open(destination_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=65536):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
# we already have this model downloaded
|
||||
else:
|
||||
print(f"civitai model already downloaded to {destination_path}")
|
||||
|
||||
response.close()
|
||||
return destination_path.as_posix()
|
||||
185
apps/amdshark_studio/modules/embeddings.py
Normal file
185
apps/amdshark_studio/modules/embeddings.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import json
|
||||
import safetensors
|
||||
from dataclasses import dataclass
|
||||
from safetensors.torch import load_file
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_checkpoint_pathfile,
|
||||
get_path_stem,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAweight:
|
||||
up: torch.tensor
|
||||
down: torch.tensor
|
||||
mid: torch.tensor
|
||||
alpha: torch.float32 = 1.0
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
|
||||
state_dict = ""
|
||||
if ".safetensors" in use_lora:
|
||||
state_dict = load_file(use_lora)
|
||||
else:
|
||||
state_dict = torch.load(use_lora)
|
||||
|
||||
# gather the weights from the LoRA in a more convenient form, assumes
|
||||
# everything will have an up.weight.
|
||||
weight_dict: dict[str, LoRAweight] = {}
|
||||
for key in state_dict:
|
||||
if key.startswith(splitting_prefix) and key.endswith("up.weight"):
|
||||
stem = key.split("up.weight")[0]
|
||||
weight_key = stem.removesuffix(".lora_")
|
||||
weight_key = weight_key.removesuffix("_lora_")
|
||||
weight_key = weight_key.removesuffix(".lora_linear_layer.")
|
||||
|
||||
if weight_key not in weight_dict:
|
||||
weight_dict[weight_key] = LoRAweight(
|
||||
state_dict[f"{stem}up.weight"],
|
||||
state_dict[f"{stem}down.weight"],
|
||||
state_dict.get(f"{stem}mid.weight", None),
|
||||
(
|
||||
state_dict[f"{weight_key}.alpha"]
|
||||
/ state_dict[f"{stem}up.weight"].shape[1]
|
||||
if f"{weight_key}.alpha" in state_dict
|
||||
else 1.0
|
||||
),
|
||||
)
|
||||
|
||||
# Directly update weight in model
|
||||
|
||||
# Mostly adaptions of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
|
||||
# and similar code in https://github.com/huggingface/diffusers/issues/3064
|
||||
|
||||
# TODO: handle mid weights (how do they even work?)
|
||||
for key, lora_weight in weight_dict.items():
|
||||
curr_layer = model
|
||||
layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
|
||||
|
||||
# find the target layer
|
||||
temp_name = layer_infos.pop(0)
|
||||
while len(layer_infos) > -1:
|
||||
try:
|
||||
curr_layer = curr_layer.__getattr__(temp_name)
|
||||
if len(layer_infos) > 0:
|
||||
temp_name = layer_infos.pop(0)
|
||||
elif len(layer_infos) == 0:
|
||||
break
|
||||
except Exception:
|
||||
if len(temp_name) > 0:
|
||||
temp_name += "_" + layer_infos.pop(0)
|
||||
else:
|
||||
temp_name = layer_infos.pop(0)
|
||||
|
||||
weight = curr_layer.weight.data
|
||||
scale = lora_weight.alpha * lora_strength
|
||||
if len(weight.size()) == 2:
|
||||
if len(lora_weight.up.shape) == 4:
|
||||
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
change = torch.mm(lora_weight.up, lora_weight.down)
|
||||
elif lora_weight.down.size()[2:4] == (1, 1):
|
||||
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
change = torch.nn.functional.conv2d(
|
||||
lora_weight.down.permute(1, 0, 2, 3),
|
||||
lora_weight.up,
|
||||
).permute(1, 0, 2, 3)
|
||||
|
||||
curr_layer.weight.data += change * scale
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
|
||||
extensions = [".bin", ".safetensors", ".pt"]
|
||||
if not any([extension in use_lora for extension in extensions]):
|
||||
# We assume if it is a HF ID with standalone LoRA weights.
|
||||
unet.load_attn_procs(use_lora)
|
||||
return unet
|
||||
|
||||
main_file_name = get_path_stem(use_lora)
|
||||
if ".bin" in use_lora:
|
||||
main_file_name += ".bin"
|
||||
elif ".safetensors" in use_lora:
|
||||
main_file_name += ".safetensors"
|
||||
elif ".pt" in use_lora:
|
||||
main_file_name += ".pt"
|
||||
else:
|
||||
sys.exit("Only .bin and .safetensors format for LoRA is supported")
|
||||
|
||||
try:
|
||||
dir_name = os.path.dirname(use_lora)
|
||||
unet.load_attn_procs(dir_name, weight_name=main_file_name)
|
||||
return unet
|
||||
except:
|
||||
return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
|
||||
|
||||
|
||||
def update_lora_weight(model, use_lora, model_name, lora_strength=1.0):
|
||||
if "unet" in model_name:
|
||||
return update_lora_weight_for_unet(model, use_lora, lora_strength)
|
||||
try:
|
||||
return processLoRA(model, use_lora, "lora_te_", lora_strength)
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def get_lora_metadata(lora_filename):
|
||||
# get the metadata from the file
|
||||
filename = get_checkpoint_pathfile(lora_filename, "lora")
|
||||
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
|
||||
# guard clause for if there isn't any metadata
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
# metadata is a dictionary of strings, the values of the keys we're
|
||||
# interested in are actually json, and need to be loaded as such
|
||||
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
|
||||
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
|
||||
tag_dirs = [dir for dir in tag_frequencies.keys()]
|
||||
|
||||
# gather the tag frequency information for all the datasets trained
|
||||
all_frequencies = {}
|
||||
for dataset in tag_dirs:
|
||||
frequencies = sorted(
|
||||
[entry for entry in tag_frequencies[dataset].items()],
|
||||
reverse=True,
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
|
||||
# get a figure for the total number of images processed for this dataset
|
||||
# either then number actually listed or in its dataset_dir entry or
|
||||
# the highest frequency's number if that doesn't exist
|
||||
img_count = dataset_dirs.get(dir, {}).get("img_count", frequencies[0][1])
|
||||
|
||||
# add the dataset frequencies to the overall frequencies replacing the
|
||||
# frequency counts on the tags with a percentage/ratio
|
||||
all_frequencies.update(
|
||||
[(entry[0], entry[1] / img_count) for entry in frequencies]
|
||||
)
|
||||
|
||||
trained_model_id = " ".join(
|
||||
[
|
||||
metadata.get("ss_sd_model_hash", ""),
|
||||
metadata.get("ss_sd_model_name", ""),
|
||||
metadata.get("ss_base_model_version", ""),
|
||||
]
|
||||
).strip()
|
||||
|
||||
# return the topmost <count> of all frequencies in all datasets
|
||||
return {
|
||||
"model": trained_model_id,
|
||||
"frequencies": sorted(
|
||||
all_frequencies.items(), reverse=True, key=lambda x: x[1]
|
||||
),
|
||||
}
|
||||
202
apps/amdshark_studio/modules/img_processing.py
Normal file
202
apps/amdshark_studio/modules/img_processing.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from csv import DictWriter
|
||||
from PIL import Image, PngImagePlugin
|
||||
from pathlib import Path
|
||||
from datetime import datetime as dt
|
||||
from base64 import decode
|
||||
|
||||
|
||||
resamplers = {
|
||||
"Lanczos": Image.Resampling.LANCZOS,
|
||||
"Nearest Neighbor": Image.Resampling.NEAREST,
|
||||
"Bilinear": Image.Resampling.BILINEAR,
|
||||
"Bicubic": Image.Resampling.BICUBIC,
|
||||
"Hamming": Image.Resampling.HAMMING,
|
||||
"Box": Image.Resampling.BOX,
|
||||
}
|
||||
|
||||
resampler_list = resamplers.keys()
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed, extra_info=None):
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
if extra_info is None:
|
||||
extra_info = {}
|
||||
generated_imgs_path = Path(
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15])
|
||||
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
|
||||
|
||||
img_model = extra_info["base_model_id"]
|
||||
if extra_info["custom_weights"] not in [None, "None"]:
|
||||
img_model = Path(os.path.basename(extra_info["custom_weights"])).stem
|
||||
|
||||
img_vae = None
|
||||
if extra_info["custom_vae"]:
|
||||
img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem
|
||||
|
||||
img_loras = None
|
||||
if extra_info["embeddings"]:
|
||||
img_lora = []
|
||||
for i in extra_info["embeddings"]:
|
||||
img_lora += Path(os.path.basename(cmd_opts.use_lora)).stem
|
||||
img_loras = ", ".join(img_lora)
|
||||
|
||||
if cmd_opts.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
else:
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if cmd_opts.write_metadata_to_png:
|
||||
# Using a conditional expression caused problems, so setting a new
|
||||
# variable for now.
|
||||
# if cmd_opts.use_hiresfix:
|
||||
# png_size_text = (
|
||||
# f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
|
||||
# )
|
||||
# else:
|
||||
png_size_text = f"{extra_info['width']}x{extra_info['height']}"
|
||||
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{extra_info['prompt'][0]}"
|
||||
f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
|
||||
f"\nSteps: {extra_info['steps']},"
|
||||
f"Sampler: {extra_info['scheduler']}, "
|
||||
f"CFG scale: {extra_info['guidance_scale']}, "
|
||||
f"Seed: {img_seed},"
|
||||
f"Size: {png_size_text}, "
|
||||
f"Model: {img_model}, "
|
||||
f"VAE: {img_vae}, "
|
||||
f"LoRA: {img_loras}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
if cmd_opts.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {cmd_opts.output_img_format} is not "
|
||||
f"supported yet. Image saved as png instead."
|
||||
f"Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
# To be as low-impact as possible to the existing CSV format, we append
|
||||
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
|
||||
# importance for each data point. Something to consider.
|
||||
new_entry = {}
|
||||
|
||||
new_entry.update(extra_info)
|
||||
|
||||
csv_mode = "a" if os.path.isfile(csv_path) else "w"
|
||||
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
if csv_mode == "w":
|
||||
dictwriter_obj.writeheader()
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
||||
|
||||
# For stencil, the input image can be of any size, but we need to ensure that
|
||||
# it conforms with our model constraints :-
|
||||
# Both width and height should be in the range of [128, 768] and multiple of 8.
|
||||
# This utility function performs the transformation on the input image while
|
||||
# also maintaining the aspect ratio before sending it to the stencil pipeline.
|
||||
def resize_stencil(image: Image.Image, width, height, resampler_type=None):
|
||||
aspect_ratio = width / height
|
||||
min_size = min(width, height)
|
||||
if min_size < 128:
|
||||
n_size = 128
|
||||
if width == min_size:
|
||||
width = n_size
|
||||
height = n_size / aspect_ratio
|
||||
else:
|
||||
height = n_size
|
||||
width = n_size * aspect_ratio
|
||||
width = int(width)
|
||||
height = int(height)
|
||||
n_width = width // 8
|
||||
n_height = height // 8
|
||||
n_width *= 8
|
||||
n_height *= 8
|
||||
|
||||
min_size = min(width, height)
|
||||
if min_size > 768:
|
||||
n_size = 768
|
||||
if width == min_size:
|
||||
height = n_size
|
||||
width = n_size * aspect_ratio
|
||||
else:
|
||||
width = n_size
|
||||
height = n_size / aspect_ratio
|
||||
width = int(width)
|
||||
height = int(height)
|
||||
n_width = width // 8
|
||||
n_height = height // 8
|
||||
n_width *= 8
|
||||
n_height *= 8
|
||||
if resampler_type in resamplers:
|
||||
resampler = resamplers[resampler_type]
|
||||
else:
|
||||
resampler = resamplers["Nearest Neighbor"]
|
||||
new_image = image.resize((n_width, n_height), resampler=resampler)
|
||||
return new_image, n_width, n_height
|
||||
|
||||
|
||||
def process_sd_init_image(self, sd_init_image, resample_type):
|
||||
if isinstance(sd_init_image, list):
|
||||
images = []
|
||||
for img in sd_init_image:
|
||||
img, _ = self.process_sd_init_image(img, resample_type)
|
||||
images.append(img)
|
||||
is_img2img = True
|
||||
return images, is_img2img
|
||||
if isinstance(sd_init_image, str):
|
||||
if os.path.isfile(sd_init_image):
|
||||
sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
|
||||
image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type)
|
||||
else:
|
||||
image = None
|
||||
is_img2img = False
|
||||
elif isinstance(sd_init_image, Image.Image):
|
||||
image = sd_init_image.convert("RGB")
|
||||
elif sd_init_image:
|
||||
image = sd_init_image["image"].convert("RGB")
|
||||
else:
|
||||
image = None
|
||||
is_img2img = False
|
||||
if image:
|
||||
resample_type = (
|
||||
resamplers[resample_type]
|
||||
if resample_type in resampler_list
|
||||
# Fallback to Lanczos
|
||||
else Image.Resampling.LANCZOS
|
||||
)
|
||||
image = image.resize((self.width, self.height), resample=resample_type)
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
|
||||
image_arr = 2 * (image_arr - 0.5)
|
||||
is_img2img = True
|
||||
image = image_arr
|
||||
return image, is_img2img
|
||||
37
apps/amdshark_studio/modules/logger.py
Normal file
37
apps/amdshark_studio/modules/logger.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import sys
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self, filename, filter=None):
|
||||
self.terminal = sys.stdout
|
||||
self.log = open(filename, "w")
|
||||
self.filter = filter
|
||||
|
||||
def write(self, message):
|
||||
for x in message.split("\n"):
|
||||
if self.filter in x:
|
||||
self.log.write(message)
|
||||
else:
|
||||
self.terminal.write(message)
|
||||
|
||||
def flush(self):
|
||||
self.terminal.flush()
|
||||
self.log.flush()
|
||||
|
||||
def isatty(self):
|
||||
return False
|
||||
|
||||
|
||||
def logger_test(x):
|
||||
print("[LOG] This is a test")
|
||||
print(f"This is another test, without the filter")
|
||||
return x
|
||||
|
||||
|
||||
def read_sd_logs():
|
||||
sys.stdout.flush()
|
||||
with open("amdshark_tmp/sd.log", "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
sys.stdout = Logger("amdshark_tmp/sd.log", filter="[LOG]")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user