mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-11 14:58:11 -05:00
Compare commits
1 Commits
main
...
minilmLoad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
14a56ca9b0 |
5
.flake8
5
.flake8
@@ -1,5 +0,0 @@
|
||||
[flake8]
|
||||
count = 1
|
||||
show-source = 1
|
||||
select = E9,F63,F7,F82
|
||||
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
|
||||
37
.github/workflows/gh-pages-releases.yml
vendored
37
.github/workflows/gh-pages-releases.yml
vendored
@@ -1,37 +0,0 @@
|
||||
# See: https://github.com/llvm/torch-mlir/issues/1374
|
||||
name: Publish releases page
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
scrape_and_publish_releases:
|
||||
name: "Scrape and publish releases"
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
# Don't run this in everyone's forks.
|
||||
if: github.repository == 'nod-ai/AMD-SHARK-Studio'
|
||||
|
||||
steps:
|
||||
- name: Checking out repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
token: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
- name: Run scrape releases script
|
||||
run: python ./build_tools/scrape_releases.py nod-ai AMD-SHARK-Studio > /tmp/index.html
|
||||
shell: bash
|
||||
- run: git fetch --all
|
||||
- run: git switch github-pages
|
||||
- run: git config --global user.email "none@none.com"
|
||||
- run: git config --global user.name "nod-ai"
|
||||
- run: mv /tmp/index.html package-index/index.html
|
||||
- run: git add package-index/index.html
|
||||
|
||||
# Only try to make a commit if the file has changed.
|
||||
- run: git diff --cached --exit-code || git commit -m "Update releases."
|
||||
|
||||
- name: GitHub Push
|
||||
uses: ad-m/github-push-action@d91a481090679876dfc4178fef17f286781251df # v0.8.0
|
||||
with:
|
||||
github_token: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
branch: github-pages
|
||||
87
.github/workflows/nightly.yml
vendored
87
.github/workflows/nightly.yml
vendored
@@ -9,68 +9,85 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
windows-build:
|
||||
runs-on: 7950X
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
python-version: ["3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
|
||||
- name: Setup pip cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Compute version
|
||||
shell: powershell
|
||||
run: |
|
||||
$package_version = $(Get-Date -UFormat "%Y%m%d")+"."+${{ github.run_number }}
|
||||
$package_version_ = $(Get-Date -UFormat "%Y%m%d")+"_"+${{ github.run_number }}
|
||||
$tag_name=$package_version
|
||||
echo "package_version=$package_version" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
echo "package_version_=$package_version_" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
echo "tag_name=$tag_name" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
tag_name="${package_version}"
|
||||
echo "package_version=${package_version}" >> $GITHUB_ENV
|
||||
echo "tag_name=${tag_name}" >> $GITHUB_ENV
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: ncipollo/release-action@440c8c1cb0ed28b9f43e4d1d670870f059653174 # v1.16.0
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
tag: ${{ env.tag_name }}
|
||||
name: nod.ai AMDSHARK ${{ env.tag_name }}
|
||||
tag_name: ${{ env.tag_name }}
|
||||
release_name: nod.ai SHARK ${{ env.tag_name }}
|
||||
body: |
|
||||
Automatic snapshot release of nod.ai AMDSHARK.
|
||||
Automatic snapshot release of nod.ai SHARK.
|
||||
draft: true
|
||||
prerelease: true
|
||||
|
||||
- name: Build Package
|
||||
shell: powershell
|
||||
prerelease: false
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
python process_skipfiles.py
|
||||
$env:AMDSHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
pip install -e .
|
||||
pip freeze -l
|
||||
pyinstaller .\apps\amdshark_studio\amdshark_studio.spec
|
||||
mv ./dist/nodai_amdshark_studio.exe ./dist/nodai_amdshark_studio_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\amdshark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_amdshark_studio_${{ env.package_version_ }}.exe
|
||||
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest yapf toml
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/SHARK-Runtime/releases; fi
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
|
||||
yapf -i --style .style.yapf shark/*.py
|
||||
|
||||
- name: Build and validate the package
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/SHARK-Runtime/releases
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
pytest -k 'not benchmark' --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py --ignore=shark/tests/test_shark_importer.py --ignore=tank/tf/
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@fe47e06814723c7b1bea3a7e46cf93d5f020d0c3 # v3
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ./dist/nodai*
|
||||
#asset_content_type: application/vnd.microsoft.portable-executable
|
||||
assets_path: ./wheelhouse/nodai_*.whl
|
||||
|
||||
- name: Publish Release
|
||||
id: publish_release
|
||||
uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6
|
||||
uses: eregon/publish-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
|
||||
102
.github/workflows/test-models.yml
vendored
Normal file
102
.github/workflows/test-models.yml
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Validate torch-models on Shark Runtime
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-linux:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Setup pip cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest yapf toml
|
||||
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude lit.cfg.py
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude lit.cfg.py
|
||||
yapf -i --style .style.yapf shark/*.py
|
||||
|
||||
- name: Validate Models
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest -k 'not benchmark' --ignore=tank/tf/ --ignore=shark/tests/test_shark_importer.py
|
||||
|
||||
perf-macOS:
|
||||
runs-on: MacStudio
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Validate Models dependencies
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python3.10 IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest -k 'not benchmark' --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py --ignore=tank/tf/ --ignore=shark/tests/test_shark_importer.py
|
||||
|
||||
perf-linux:
|
||||
runs-on: a100
|
||||
timeout-minutes: 45
|
||||
continue-on-error: true
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Setup pip cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Validate Models
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --ignore=shark/tests/test_shark_importer.py --ignore=tank/tf/
|
||||
85
.github/workflows/test-studio.yml
vendored
85
.github/workflows/test-studio.yml
vendored
@@ -1,85 +0,0 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Validate AMDShark Studio
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'amdshark/examples/**'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'amdshark/examples/**'
|
||||
workflow_dispatch:
|
||||
|
||||
# Ensure that only a single job or workflow using the same
|
||||
# concurrency group will run at a time. This would cancel
|
||||
# any in-progress jobs in the same github workflow and github
|
||||
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-validate:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
os: [nodai-ubuntu-builder-large]
|
||||
suite: [cpu] #,cuda,vulkan]
|
||||
python-version: ["3.11"]
|
||||
include:
|
||||
- os: nodai-ubuntu-builder-large
|
||||
suite: lint
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Set Environment Variables
|
||||
run: |
|
||||
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Python Version File ${{ matrix.python-version }}
|
||||
run: |
|
||||
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||
with:
|
||||
python-version: '${{ matrix.python-version }}'
|
||||
|
||||
- name: Install dependencies
|
||||
if: matrix.suite == 'lint'
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml black
|
||||
|
||||
- name: Lint with flake8
|
||||
if: matrix.suite == 'lint'
|
||||
run: |
|
||||
# black format check
|
||||
black --version
|
||||
black --check apps/amdshark_studio
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --statistics
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
|
||||
--statistics --exclude lit.cfg.py
|
||||
|
||||
- name: Validate Models on CPU
|
||||
if: matrix.suite == 'cpu'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
python${{ matrix.python-version }} -m venv amdshark.venv
|
||||
source amdshark.venv/bin/activate
|
||||
pip install -r requirements.txt --no-cache-dir
|
||||
pip install -e .
|
||||
# Disabled due to hang when exporting test llama2
|
||||
# python apps/amdshark_studio/tests/api_test.py
|
||||
49
.gitignore
vendored
49
.gitignore
vendored
@@ -2,8 +2,6 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.mlir
|
||||
*.vmfb
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
@@ -33,6 +31,7 @@ MANIFEST
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
@@ -159,53 +158,11 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
#.idea/
|
||||
|
||||
# vscode related
|
||||
.vscode
|
||||
|
||||
# AMDShark related artifacts
|
||||
# Shark related artefacts
|
||||
*venv/
|
||||
amdshark_tmp/
|
||||
*.vmfb
|
||||
.use-iree
|
||||
tank/dict_configs.py
|
||||
*.csv
|
||||
reproducers/
|
||||
apps/amdshark_studio/web/configs
|
||||
|
||||
# ORT related artefacts
|
||||
cache_models/
|
||||
onnx_models/
|
||||
|
||||
# Generated images
|
||||
generated_imgs/
|
||||
|
||||
# Custom model related artefacts
|
||||
variants.json
|
||||
/models/
|
||||
*.safetensors
|
||||
|
||||
# models folder
|
||||
apps/stable_diffusion/web/models/
|
||||
|
||||
# model artifacts (AMDSHARK)
|
||||
*.tempfile
|
||||
*.mlir
|
||||
*.vmfb
|
||||
|
||||
# Stencil annotators.
|
||||
stencil_annotator/
|
||||
|
||||
# For DocuChat
|
||||
apps/language_models/langchain/user_path/
|
||||
db_dir_UserData
|
||||
|
||||
# Embedded browser cache and other artefacts
|
||||
apps/stable_diffusion/web/EBWebView/
|
||||
|
||||
# Llama2 tokenizer configs
|
||||
llama2_tokenizer_configs/
|
||||
|
||||
# Webview2 runtime artefacts
|
||||
EBWebView/
|
||||
|
||||
6
.gitmodules
vendored
6
.gitmodules
vendored
@@ -1,4 +1,4 @@
|
||||
[submodule "inference/thirdparty/amdshark-runtime"]
|
||||
path = inference/thirdparty/amdshark-runtime
|
||||
url =https://github.com/nod-ai/SRT.git
|
||||
[submodule "inference/thirdparty/shark-runtime"]
|
||||
path = inference/thirdparty/shark-runtime
|
||||
url =https://github.com/nod-ai/SHARK-Runtime.git
|
||||
branch = shark-06032022
|
||||
|
||||
3
.style.yapf
Normal file
3
.style.yapf
Normal file
@@ -0,0 +1,3 @@
|
||||
[style]
|
||||
based_on_style = google
|
||||
column_limit = 80
|
||||
218
LICENSE
218
LICENSE
@@ -1,218 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
||||
---- LLVM Exceptions to the Apache 2.0 License ----
|
||||
|
||||
As an exception, if, as a result of your compiling your source code, portions
|
||||
of this Software are embedded into an Object form of such source code, you
|
||||
may redistribute such embedded portions in such Object form without complying
|
||||
with the conditions of Sections 4(a), 4(b) and 4(d) of the License.
|
||||
|
||||
In addition, if you combine or link compiled forms of this Software with
|
||||
software that is licensed under the GPLv2 ("Combined Software") and if a
|
||||
court of competent jurisdiction determines that the patent provision (Section
|
||||
3), the indemnity provision (Section 9) or other Section of the License
|
||||
conflicts with the conditions of the GPLv2, you may retroactively and
|
||||
prospectively choose to deem waived or otherwise exclude such Section(s) of
|
||||
the License, but only in their entirety and only with respect to the Combined
|
||||
Software.
|
||||
444
README.md
444
README.md
@@ -1,175 +1,29 @@
|
||||
# AMDSHARK
|
||||
# SHARK
|
||||
|
||||
High Performance Machine Learning Distribution
|
||||
High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerators and Heterogeneous Clusters
|
||||
|
||||
<h2>NOTE: This project is not currently maintained.</h2>
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
|
||||
|
||||
*The latest versions of this project are developments towards a refactor on top of IREE-Turbine. Until further notice, make sure you use an .exe release or a checkout of the `AMDSHARK-1.0` branch, for a working AMDSHARK-Studio*
|
||||
## Communication Channels
|
||||
|
||||
[](https://github.com/nod-ai/AMD-SHARK-Studio/actions/workflows/nightly.yml)
|
||||
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary>Prerequisites - Drivers </summary>
|
||||
|
||||
#### Install your Windows hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
|
||||
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
#### Linux Drivers
|
||||
* MESA / RADV drivers won't work with FP16. Please use the latest AMDGPU-PRO drivers (non-pro OSS drivers also won't work) or the latest NVIDIA Linux drivers.
|
||||
|
||||
Other users please ensure you have your latest vendor drivers and Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home) and if you are using vulkan check `vulkaninfo` works in a terminal window
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
### Quick Start for AMDSHARK Stable Diffusion for Windows 10/11 Users
|
||||
|
||||
Install the Driver from [Prerequisites](https://github.com/nod-ai/AMD-SHARK-Studio#install-your-hardware-drivers) above
|
||||
|
||||
Download the [stable release](https://github.com/nod-ai/AMD-SHARK-Studio/releases/latest) or the most recent [AMDSHARK 1.0 pre-release](https://github.com/nod-ai/AMD-SHARK-Studio/releases).
|
||||
|
||||
Double click the .exe, or [run from the command line](#running) (recommended), and you should have the [UI](http://localhost:8080/) in the browser.
|
||||
|
||||
If you have custom models put them in a `models/` directory where the .exe is.
|
||||
|
||||
Enjoy.
|
||||
|
||||
<details>
|
||||
<summary>More installation notes</summary>
|
||||
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
|
||||
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
|
||||
|
||||
## Running
|
||||
|
||||
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE)
|
||||
* The first run may take a few minutes while the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
|
||||
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
|
||||
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.
|
||||
* If you prefer to always run in the browser, use the `--ui=web` command argument when running the EXE.
|
||||
|
||||
## Stopping
|
||||
|
||||
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Advanced Installation (Only for developers)</summary>
|
||||
|
||||
## Advanced Installation (Windows, Linux and macOS) for developers
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install Git for Windows from [here](https://git-scm.com/download/win) if you don't already have it.
|
||||
|
||||
## Check out the code
|
||||
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/AMD-SHARK-Studio.git
|
||||
cd AMD-SHARK-Studio
|
||||
```
|
||||
|
||||
## Switch to the Correct Branch (IMPORTANT!)
|
||||
|
||||
Currently AMDSHARK is being rebuilt for [Turbine](https://github.com/iree-org/iree-turbine) on the `main` branch. For now you are strongly discouraged from using `main` unless you are working on the rebuild effort, and you should not expect the code there to produce a working application for image generation. Until that changes, you'll need to switch over to the `AMDSHARK-1.0` branch and use the stable code.
|
||||
|
||||
```shell
|
||||
git checkout AMDSHARK-1.0
|
||||
```
|
||||
|
||||
The following setup instructions assume you are on this branch.
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)
|
||||
|
||||
#### Allow the install script to run in Powershell
|
||||
```powershell
|
||||
set-executionpolicy remotesigned
|
||||
```
|
||||
|
||||
#### Setup venv and install necessary packages (torch-mlir, nodLabs/AMDShark, ...)
|
||||
```powershell
|
||||
./setup_venv.ps1 #You can re-run this script to get the latest version
|
||||
```
|
||||
|
||||
### Linux / macOS Users
|
||||
|
||||
```shell
|
||||
./setup_venv.sh
|
||||
source amdshark1.venv/bin/activate
|
||||
```
|
||||
|
||||
### Run Stable Diffusion on your device - WebUI
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(amdshark1.venv) PS C:\g\amdshark> cd .\apps\stable_diffusion\web\
|
||||
(amdshark1.venv) PS C:\g\amdshark\apps\stable_diffusion\web> python .\index.py
|
||||
```
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
(amdshark1.venv) > cd apps/stable_diffusion/web
|
||||
(amdshark1.venv) > python index.py
|
||||
```
|
||||
|
||||
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
|
||||
|
||||
|
||||
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
|
||||
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - Commandline
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(amdshark1.venv) PS C:\g\amdshark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
```
|
||||
|
||||
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
|
||||
</details>
|
||||
|
||||
The output on an AMD 7900XTX would look something like:
|
||||
|
||||
```shell
|
||||
Average step time: 47.19188690185547ms/it
|
||||
Clip Inference time (ms) = 109.531
|
||||
VAE Inference time (ms): 78.590
|
||||
|
||||
Total image generation time: 2.5788655281066895sec
|
||||
```
|
||||
|
||||
Here are some samples generated:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
|
||||
Find us on [AMDSHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Binary Installation</summary>
|
||||
|
||||
<summary>Installation (Linux and macOS)</summary>
|
||||
|
||||
### Setup a new pip Virtual Environment
|
||||
|
||||
This step sets up a new VirtualEnv for Python
|
||||
|
||||
|
||||
```shell
|
||||
python --version #Check you have 3.11 on Linux, macOS or Windows Powershell
|
||||
python -m venv amdshark_venv
|
||||
source amdshark_venv/bin/activate # Use amdshark_venv/Scripts/activate on Windows
|
||||
python --version #Check you have 3.7->3.10 on Linux or 3.10 on macOS
|
||||
python -m venv shark_venv
|
||||
source shark_venv/bin/activate
|
||||
|
||||
# If you are using conda create and activate a new conda env
|
||||
|
||||
@@ -177,150 +31,98 @@ source amdshark_venv/bin/activate # Use amdshark_venv/Scripts/activate on Wind
|
||||
python -m pip install --upgrade pip
|
||||
```
|
||||
|
||||
*macOS Metal* users please install https://sdk.lunarg.com/sdk/download/latest/mac/vulkan-sdk.dmg and enable "System wide install"
|
||||
*macOS Metal* users please install https://sdk.lunarg.com/sdk/download/latest/mac/vulkan-sdk.dmg
|
||||
|
||||
### Install AMD-SHARK
|
||||
|
||||
This step pip installs AMD-SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
|
||||
### Install SHARK
|
||||
|
||||
This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10
|
||||
|
||||
```shell
|
||||
pip install nodai-amdshark -f https://nod-ai.github.io/AMD-SHARK-Studio/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip install nodai-shark -f https://github.com/nod-ai/SHARK/releases -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/shark-runtime/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
```
|
||||
|
||||
### Run amdshark tank model tests.
|
||||
```shell
|
||||
pytest tank/test_models.py
|
||||
```
|
||||
See tank/README.md for a more detailed walkthrough of our pytest suite and CLI.
|
||||
If you are on an Intel macOS machine you need this [workaround](https://github.com/nod-ai/SHARK/issues/102) for an upstream issue.
|
||||
|
||||
### Download and run Resnet50 sample
|
||||
|
||||
|
||||
```shell
|
||||
curl -O https://raw.githubusercontent.com/nod-ai/AMD-SHARK-Studio/main/amdshark/examples/amdshark_inference/resnet50_script.py
|
||||
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/resnet50_script.py
|
||||
#Install deps for test script
|
||||
pip install --pre torch torchvision torchaudio tqdm pillow gsutil --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python ./resnet50_script.py --device="cpu" #use cuda or vulkan or metal
|
||||
pip install --pre torch torchvision torchaudio tqdm pillow --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python ./resnet50_script.py --device="cpu" #use cuda or vulkan or metal
|
||||
```
|
||||
|
||||
|
||||
### Download and run BERT (MiniLM) sample
|
||||
```shell
|
||||
curl -O https://raw.githubusercontent.com/nod-ai/AMD-SHARK-Studio/main/amdshark/examples/amdshark_inference/minilm_jit.py
|
||||
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/minilm_jit.py
|
||||
#Install deps for test script
|
||||
pip install transformers torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
|
||||
python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Development, Testing and Benchmarks</summary>
|
||||
<summary>Source Installation</summary>
|
||||
|
||||
If you want to use Python3.11 and with TF Import tools you can use the environment variables like:
|
||||
Set `USE_IREE=1` to use upstream IREE
|
||||
```
|
||||
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
|
||||
```
|
||||
## Check out the code
|
||||
|
||||
### Run any of the hundreds of AMDSHARK tank models via the test framework
|
||||
```shell
|
||||
python -m amdshark.examples.amdshark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
# Or a pytest
|
||||
pytest tank/test_models.py -k "MiniLM"
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
```
|
||||
|
||||
### How to use your locally built IREE / Torch-MLIR with AMDSHARK
|
||||
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
|
||||
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
|
||||
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
|
||||
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
|
||||
for Torch-MLIR.
|
||||
|
||||
How to use your locally built Torch-MLIR with AMDSHARK:
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
```shell
|
||||
1.) Run `./setup_venv.sh` in AMDSHARK and activate the `amdshark.venv` virtual env.
|
||||
2.) Run `pip uninstall torch-mlir`.
|
||||
3.) Go to your local Torch-MLIR directory.
|
||||
4.) Activate mlir_venv virtual environment.
|
||||
5.) Run `pip uninstall -r requirements.txt`.
|
||||
6.) Run `pip install -r requirements.txt`.
|
||||
7.) Build Torch-MLIR.
|
||||
8.) Activate amdshark.venv virtual environment from the Torch-MLIR directory.
|
||||
9.) Run `export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples` in the Torch-MLIR directory.
|
||||
10.) Go to the AMDSHARK directory.
|
||||
```
|
||||
Now AMDSHARK will use your locally built Torch-MLIR repo.
|
||||
|
||||
|
||||
## Benchmarking Dispatches
|
||||
|
||||
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
|
||||
If you only want to compile specific dispatches, you can specify them with a space-separated string instead of `"All"`, e.g. `--dispatch_benchmarks="0 1 2 10"`.
|
||||
|
||||
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
|
||||
```
|
||||
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
|
||||
```
|
||||
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
|
||||
|
||||
If you instead want to incorporate this into a Python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` arguments when initializing `AMDSharkInference`, and the benchmarks will be generated when the module is compiled. For example:
|
||||
|
||||
```
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_model,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
dispatch_benchmarks="all",
|
||||
dispatch_benchmarks_dir="results"
|
||||
)
|
||||
# Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
|
||||
./setup_venv.sh
|
||||
# Please activate the venv after installation.
|
||||
```
|
||||
|
||||
Output will include:
|
||||
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
|
||||
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
|
||||
- An .mlir file containing the dispatch benchmark
|
||||
- A compiled .vmfb file containing the dispatch benchmark
|
||||
- An .mlir file containing just the hal executable
|
||||
- A compiled .vmfb file of the hal executable
|
||||
- A .txt file containing benchmark output
|
||||
### Run a demo script
|
||||
```shell
|
||||
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
```
|
||||
|
||||
|
||||
See tank/README.md for further instructions on how to run model tests and benchmarks from the AMDSHARK tank.
|
||||
### Run all model tests on CPU/GPU/VULKAN/Metal
|
||||
```shell
|
||||
pytest shark/tests/models
|
||||
|
||||
# If on Linux for quicker results:
|
||||
pytest shark/tests/models -n auto
|
||||
```
|
||||
|
||||
### Run all model benchmark tests on CPU/GPU/VULKAN/Metal
|
||||
```shell
|
||||
pytest shark/tests/benchmarks
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>API Reference</summary>
|
||||
|
||||
### AMDShark Inference API
|
||||
### Shark Inference API
|
||||
|
||||
```
|
||||
from shark_runner import SharkInference
|
||||
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
|
||||
# AMDSharkImporter imports mlir file from the torch, tensorflow or tf-lite module.
|
||||
|
||||
mlir_importer = AMDSharkImporter(
|
||||
torch_module,
|
||||
(input),
|
||||
frontend="torch", #tf, #tf-lite
|
||||
)
|
||||
torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
|
||||
|
||||
# AMDSharkInference accepts mlir in linalg, mhlo, and tosa dialect.
|
||||
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
amdshark_module = AMDSharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward((input))
|
||||
shark_module = SharkInference(
|
||||
module = model class.
|
||||
(input,) = inputs to model (must be a torch-tensor)
|
||||
dynamic (boolean) = Pass the input shapes as static or dynamic.
|
||||
device = `cpu`, `gpu` or `vulkan` is supported.
|
||||
tracing_required = (boolean) = Jit trace the module with the given input, useful in the case where jit.script doesn't work. )
|
||||
shark_module.set_frontend("pytorch") # Use tensorflow, mhlo, linalg, tosa
|
||||
shark_module.compile()
|
||||
|
||||
result = shark_module.forward(inputs)
|
||||
```
|
||||
|
||||
|
||||
### Example demonstrating running MHLO IR.
|
||||
|
||||
```
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from shark.shark_inference import SharkInference
|
||||
import numpy as np
|
||||
|
||||
mhlo_ir = r"""builtin.module {
|
||||
@@ -333,51 +135,115 @@ mhlo_ir = r"""builtin.module {
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
amdshark_module = AMDSharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward((arg0, arg1))
|
||||
|
||||
shark_module = SharkInference(mhlo_ir, (arg0, arg1))
|
||||
shark_module.set_frontend("mhlo")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward((arg0, arg1)))
|
||||
```
|
||||
</details>
|
||||
|
||||
## Examples Using the REST API
|
||||
|
||||
* [Setting up AMDSHARK for use with Blender](./docs/amdshark_sd_blender.md)
|
||||
* [Setting up AMDSHARK for use with Koboldcpp](./docs/amdshark_sd_koboldcpp.md)
|
||||
|
||||
## Supported and Validated Models
|
||||
|
||||
AMDSHARK is maintained to support the latest innovations in ML Models:
|
||||
<details>
|
||||
<summary>PyTorch Models</summary>
|
||||
|
||||
| TF HuggingFace Models | AMDSHARK-CPU | AMDSHARK-CUDA | AMDSHARK-METAL |
|
||||
|---------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| GPT2 | :green_heart: | :green_heart: | :green_heart: |
|
||||
| BLOOM | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Stable Diffusion | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Vision Transformer | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ResNet50 | :green_heart: | :green_heart: | :green_heart: |
|
||||
### Huggingface PyTorch Models
|
||||
|
||||
For a complete list of the models supported in AMDSHARK, please refer to [tank/README.md](https://github.com/nod-ai/AMD-SHARK-Studio/blob/main/tank/README.md).
|
||||
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :heavy_check_mark: (JIT) | :heavy_check_mark: | | |
|
||||
| Albert | :heavy_check_mark: (JIT) | :heavy_check_mark: | | |
|
||||
| BigBird | :heavy_check_mark: (AOT) | | | |
|
||||
| DistilBERT | :heavy_check_mark: (JIT) | :heavy_check_mark: | | |
|
||||
| GPT2 | :x: (AOT) | | | |
|
||||
|
||||
## Communication Channels
|
||||
### Torchvision Models
|
||||
|
||||
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|--------------------|----------------------|----------|----------|-------------|
|
||||
| AlexNet | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
|
||||
| DenseNet121 | :heavy_check_mark: (Script) | | | |
|
||||
| MNasNet1_0 | :heavy_check_mark: (Script) | | | |
|
||||
| MobileNetV2 | :heavy_check_mark: (Script) | | | |
|
||||
| MobileNetV3 | :heavy_check_mark: (Script) | | | |
|
||||
| Unet | :x: (Script) | | | |
|
||||
| Resnet18 | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
|
||||
| Resnet50 | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
|
||||
| Resnet101 | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
|
||||
| Resnext50_32x4d | :heavy_check_mark: (Script) | | | |
|
||||
| ShuffleNet_v2 | :x: (Script) | | | |
|
||||
| SqueezeNet | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
|
||||
| EfficientNet | :heavy_check_mark: (Script) | | | |
|
||||
| Regnet | :heavy_check_mark: (Script) | | | |
|
||||
| Resnest | :x: (Script) | | | |
|
||||
| Vision Transformer | :heavy_check_mark: (Script) | | | |
|
||||
| VGG 16 | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
|
||||
| Wide Resnet | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
|
||||
| RAFT | :x: (JIT) | | | |
|
||||
|
||||
* [AMDSHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the AMDSHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/AMD-SHARK-Studio/issues): Feature requests, bugs etc
|
||||
For more information refer to [MODEL TRACKING SHEET](https://docs.google.com/spreadsheets/d/15PcjKeHZIrB5LfDyuw7DGEEE8XnQEX2aX8lm8qbxV8A/edit#gid=0)
|
||||
|
||||
### PyTorch Training Models
|
||||
|
||||
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :x: | :x: | | |
|
||||
| FullyConnected | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>JAX Models</summary>
|
||||
|
||||
|
||||
### JAX Models
|
||||
|
||||
| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| DALL-E | :x: | :x: | | |
|
||||
| FullyConnected | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>TFLite Models</summary>
|
||||
|
||||
### TFLite Models
|
||||
|
||||
| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :x: | :x: | | |
|
||||
| FullyConnected | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>TF Models</summary>
|
||||
|
||||
### Tensorflow Models
|
||||
|
||||
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :x: | :x: | | |
|
||||
| FullyConnected | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
|
||||
</details>
|
||||
|
||||
## Related Projects
|
||||
|
||||
|
||||
<details>
|
||||
<summary>IREE Project Channels</summary>
|
||||
|
||||
* [Upstream IREE issues](https://github.com/google/iree/issues): Feature requests,
|
||||
bugs, and other work tracking
|
||||
* [Upstream IREE Discord server](https://discord.gg/wEWh6Z9nMU): Daily development
|
||||
* [Upstream IREE Discord server](https://discord.gg/26P4xW4): Daily development
|
||||
discussions with the core team and collaborators
|
||||
* [iree-discuss email list](https://groups.google.com/forum/#!forum/iree-discuss):
|
||||
Announcements, general and low-priority discussion
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>MLIR and Torch-MLIR Project Channels</summary>
|
||||
|
||||
@@ -385,10 +251,10 @@ For a complete list of the models supported in AMDSHARK, please refer to [tank/R
|
||||
* Torch-MLIR Github issues [here](https://github.com/llvm/torch-mlir/issues)
|
||||
* [`torch-mlir` section](https://llvm.discourse.group/c/projects-that-want-to-become-official-llvm-projects/torch-mlir/41) of LLVM Discourse
|
||||
* Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information.
|
||||
* [MLIR topic within LLVM Discourse](https://llvm.discourse.group/c/llvm-project/mlir/31) AMDSHARK and IREE are enabled by and heavily rely on [MLIR](https://mlir.llvm.org).
|
||||
* [MLIR topic within LLVM Discourse](https://llvm.discourse.group/c/llvm-project/mlir/31) SHARK and IREE are enabled by and heavily rely on [MLIR](https://mlir.llvm.org).
|
||||
</details>
|
||||
|
||||
|
||||
## License
|
||||
|
||||
nod.ai AMDSHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions.
|
||||
nod.ai SHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions.
|
||||
See [LICENSE](LICENSE) for more information.
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
from torch._dynamo import register_backend
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_backend
def amdshark(model, inputs, *, options):
    """TorchDynamo backend entry point that delegates to ``AMDSharkBackend``.

    The AMDSHARK backend class is imported lazily, inside the call, so that
    registering this backend does not by itself require AMDSHARK to be
    importable.

    Args:
        model: The model object passed to the backend.
        inputs: The inputs passed alongside ``model``.
        options: Backend options (keyword-only); forwarded unchanged.

    Returns:
        The result of ``AMDSharkBackend(model, inputs, options)``.

    Raises:
        ImportError: If ``amdshark.dynamo_backend.utils`` cannot be imported.
            The failure is logged with its traceback and then re-raised.
    """
    try:
        from amdshark.dynamo_backend.utils import AMDSharkBackend
    except ImportError:
        # Adjacent string literals are concatenated verbatim, so every
        # fragment except the last must end with its own separator.
        log.exception(
            "Unable to import AMDSHARK - High Performance Machine Learning Distribution. "
            "Please install the right version of AMDSHARK that matches the PyTorch version being used. "
            "Refer to https://github.com/nod-ai/AMD-SHARK-Studio/ for details."
        )
        raise
    return AMDSharkBackend(model, inputs, options)
|
||||
|
||||
|
||||
def has_amdshark():
    """Return True when the ``amdshark`` package imports successfully.

    The check performs a real import, so a package that is present but fails
    with ``ImportError`` while loading is reported as unavailable.
    """
    try:
        importlib.import_module("amdshark")
    except ImportError:
        return False
    else:
        return True
|
||||
@@ -1,501 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from amdshark.amdshark_runner import AMDSharkRunner
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
export_iree_module_to_vmfb,
|
||||
load_flatbuffer,
|
||||
get_iree_runtime_config,
|
||||
)
|
||||
from amdshark.iree_utils.benchmark_utils import (
|
||||
build_benchmark_args,
|
||||
run_benchmark_module,
|
||||
)
|
||||
from amdshark.parser import amdshark_args
|
||||
from datetime import datetime
|
||||
import time
|
||||
from typing import Optional
|
||||
import csv
|
||||
import os
|
||||
|
||||
TF_CPU_DEVICE = "/CPU:0"
|
||||
TF_GPU_DEVICE = "/GPU:0"
|
||||
|
||||
|
||||
def _bytes_to_mb_str(bytes_: Optional[int]) -> str:
    """Format a byte count as megabytes (1 MB = 1e6 bytes) with six decimals.

    Returns an empty string when ``bytes_`` is ``None``.
    """
    if bytes_ is None:
        return ""
    megabytes = bytes_ / 1e6
    return f"{megabytes:.6f}"
|
||||
|
||||
|
||||
class OnnxFusionOptions(object):
    """Plain container of boolean ONNX fusion switches, all initially False."""

    def __init__(self):
        # Attributes are created in this exact order, each set to False.
        flag_names = (
            "disable_gelu",
            "disable_layer_norm",
            "disable_attention",
            "disable_skip_layer_norm",
            "disable_embed_layer_norm",
            "disable_bias_skip_layer_norm",
            "disable_bias_gelu",
            "enable_gelu_approximation",
            "use_mask_index",
            "no_attention_mask",
        )
        for flag_name in flag_names:
            setattr(self, flag_name, False)
|
||||
|
||||
|
||||
def check_requirements(frontend):
    """Report whether the packages needed for ``frontend`` can be located.

    Only looks the packages up with ``importlib.util.find_spec``; the
    packages themselves are not imported by this function.

    Args:
        frontend: Frontend name. ``"torch"`` requires ``torchvision``;
            ``"tensorflow"`` or ``"tf"`` requires both ``keras`` and
            ``tensorflow``.

    Returns:
        True if every required package is found, otherwise False. Any other
        frontend name yields False.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is bound as an attribute; import it explicitly.
    import importlib.util

    def _is_available(package):
        return importlib.util.find_spec(package) is not None

    if frontend == "torch":
        return _is_available("torchvision")
    if frontend in ("tensorflow", "tf"):
        return _is_available("keras") and _is_available("tensorflow")
    return False
|
||||
|
||||
|
||||
class AMDSharkBenchmarkRunner(AMDSharkRunner):
|
||||
# AMDSharkRunner derived class with Benchmarking capabilities.
|
||||
def __init__(
    self,
    mlir_module: bytes,
    device: str = "none",
    mlir_dialect: str = "linalg",
    extra_args: Optional[list] = None,
):
    """Compile ``mlir_module`` to a .vmfb file and load it for benchmarking.

    Args:
        mlir_module: Either a path to an MLIR module on disk or the module
            contents themselves. Anything that is not an existing file path
            is treated as in-memory contents.
        device: Target device. The sentinel ``"none"`` makes ``self.device``
            fall back to ``amdshark_args.device``.
        mlir_dialect: Dialect of ``mlir_module``.
        extra_args: Extra arguments forwarded to the base runner and to
            ``export_iree_module_to_vmfb``. Defaults to a fresh empty list
            per instance.
    """
    self.device = amdshark_args.device if device == "none" else device
    self.enable_tf32 = amdshark_args.enable_tf32
    self.frontend_model = None
    self.vmfb_file = None
    self.mlir_dialect = mlir_dialect
    # A literal ``[]`` default would be one list object shared by every
    # instance created without ``extra_args``; build a new one instead.
    self.extra_args = [] if extra_args is None else extra_args
    self.import_args = {}
    self.temp_file_to_unlink = None
    if not os.path.isfile(mlir_module):
        print(
            "Warning: Initializing AMDSharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize AMDSharkInference with a path to a MLIR module on your hard disk instead."
        )
        self.compile_str = True
    else:
        self.compile_str = False
    # NOTE(review): the raw ``device`` argument (possibly "none") is passed
    # below rather than the resolved ``self.device`` -- confirm that the
    # callees resolve the "none" sentinel themselves.
    AMDSharkRunner.__init__(
        self,
        mlir_module,
        device,
        self.mlir_dialect,
        self.extra_args,
        compile_vmfb=False,
    )
    self.vmfb_file = export_iree_module_to_vmfb(
        mlir_module,
        device,
        ".",
        self.mlir_dialect,
        extra_args=self.extra_args,
        compile_str=self.compile_str,
    )
    params = load_flatbuffer(
        self.vmfb_file,
        device,
        mmap=True,
    )
    self.iree_compilation_module = params["vmfb"]
    self.iree_config = params["config"]
    self.temp_file_to_unlink = params["temp_file_to_unlink"]
    # Drop the local reference to the loader's result dict once the needed
    # entries have been copied onto the instance.
    del params
|
||||
|
||||
def setup_cl(self, input_tensors):
    """Build the benchmark arguments and store them on ``self.benchmark_cl``.

    The arguments come from ``build_benchmark_args``, given the compiled
    ``self.vmfb_file``, ``self.device``, ``input_tensors`` and
    ``self.mlir_dialect``.
    """
    benchmark_args = build_benchmark_args(
        self.vmfb_file, self.device, input_tensors, mlir_dialect=self.mlir_dialect
    )
    self.benchmark_cl = benchmark_args
|
||||
|
||||
def benchmark_frontend(self, modelname):
    """Dispatch to the frontend benchmark that matches ``self.mlir_dialect``.

    ``"linalg"`` and ``"torch"`` go to ``benchmark_torch``; ``"mhlo"`` and
    ``"tf"`` go to ``benchmark_tf``. Any other dialect returns ``None``.
    """
    dialect = self.mlir_dialect
    if dialect in ("linalg", "torch"):
        return self.benchmark_torch(modelname)
    if dialect in ("mhlo", "tf"):
        return self.benchmark_tf(modelname)
    return None
|
||||
|
||||
def benchmark_torch(self, modelname, device="cpu"):
    """Benchmark ``modelname`` with eager PyTorch.

    Args:
        modelname: Name passed to ``tank.model_utils.get_torch_model``.
        device: Currently ignored; the device is chosen from
            ``torch.cuda.is_available()`` (see the TODO below).

    Returns:
        A four-element list: iterations per second, milliseconds per
        iteration, an empty string (host memory is not reported by PyTorch),
        and the peak device memory formatted by ``_bytes_to_mb_str``.
    """
    import torch
    from tank.model_utils import get_torch_model

    # TODO: Pass this as an arg. currently the best way is to setup with BENCHMARK=1 if we want to use torch+cuda, else use cpu.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        torch.set_default_device("cuda:0")
        # if self.enable_tf32:
        #    torch.backends.cuda.matmul.allow_tf32 = True
    else:
        torch.set_default_dtype(torch.float32)
        torch.set_default_device("cpu")
    torch_device = torch.device("cuda:0" if device == "cuda" else "cpu")

    HFmodel, model_input = get_torch_model(modelname, self.import_args)[:2]
    frontend_model = HFmodel.model
    # Module.to() moves parameters in place, but Tensor.to() returns a new
    # tensor, so the result has to be rebound for the move to take effect.
    frontend_model.to(torch_device)
    model_input = model_input.to(torch_device)
    if device == "cuda":
        print(model_input)

    for _ in range(amdshark_args.num_warmup_iterations):
        frontend_model.forward(model_input)

    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
    # NOTE(review): CUDA kernels are launched asynchronously and no
    # synchronization happens before the clock is read, so the GPU timing
    # may be optimistic - confirm whether a synchronize is wanted here.
    begin = time.perf_counter()
    for _ in range(amdshark_args.num_iterations):
        frontend_model.forward(model_input)
    end = time.perf_counter()

    if device == "cuda":
        stats = torch.cuda.memory_stats()
        device_peak_b = stats["allocated_bytes.all.peak"]
        frontend_model.to(torch.device("cpu"))
        # Rebind so the CUDA copy is released before the cache is emptied.
        model_input = model_input.to(torch.device("cpu"))
        torch.cuda.empty_cache()
    else:
        device_peak_b = None

    elapsed = end - begin
    print(
        f"Torch benchmark:{amdshark_args.num_iterations/elapsed} iter/second, Total Iterations:{amdshark_args.num_iterations}"
    )
    if device == "cuda":
        # Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
        # NOTE(review): this only rebinds a local name; the process-wide
        # default device set above is left untouched - confirm the intent.
        torch_device = torch.device("cpu")
    return [
        f"{amdshark_args.num_iterations/elapsed}",
        f"{(elapsed/amdshark_args.num_iterations)*1000}",
        "",  # host_peak_b (CPU usage) is not reported by PyTorch.
        _bytes_to_mb_str(device_peak_b),
    ]
|
||||
|
||||
def benchmark_tf(self, modelname):
    """Benchmark ``modelname`` with TensorFlow on ``TF_CPU_DEVICE``.

    Args:
        modelname: Name passed to ``tank.model_utils_tf.get_tf_model``.

    Returns:
        A four-element list: iterations per second, milliseconds per
        iteration, an empty string (host memory is not reported by
        TensorFlow), and the peak device memory formatted by
        ``_bytes_to_mb_str``.
    """
    import os

    # Must be set before TensorFlow is imported to take effect.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
    import tensorflow as tf

    # NOTE(review): the result is never read; the call is kept as-is.
    visible_default = tf.config.list_physical_devices("GPU")
    try:
        tf.config.set_visible_devices([], "GPU")
        visible_devices = tf.config.get_visible_devices()
        for device in visible_devices:
            assert device.device_type != "GPU"
    except Exception:
        # Invalid device or cannot modify virtual devices once initialized.
        # Best effort only; KeyboardInterrupt/SystemExit still propagate.
        pass

    from tank.model_utils_tf import get_tf_model

    # tf_device = TF_GPU_DEVICE if self.device == "cuda" else TF_CPU_DEVICE
    tf_device = TF_CPU_DEVICE
    with tf.device(tf_device):
        frontend_model, model_input = get_tf_model(
            modelname, self.import_args
        )[:2]

        for _ in range(amdshark_args.num_warmup_iterations):
            frontend_model.forward(*model_input)

        if tf_device == TF_GPU_DEVICE:
            tf.config.experimental.reset_memory_stats(tf_device)
        begin = time.time()
        for _ in range(amdshark_args.num_iterations):
            frontend_model.forward(*model_input)
        end = time.time()
        if tf_device == TF_GPU_DEVICE:
            memory_info = tf.config.experimental.get_memory_info(tf_device)
            device_peak_b = memory_info["peak"]
        else:
            # tf.config.experimental does not currently support measuring
            # CPU memory usage.
            device_peak_b = None

        print(
            f"TF benchmark:{amdshark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{amdshark_args.num_iterations}"
        )
        return [
            f"{amdshark_args.num_iterations/(end-begin)}",
            f"{((end-begin)/amdshark_args.num_iterations)*1000}",
            "",  # host_peak_b (CPU usage) is not reported by TensorFlow.
            _bytes_to_mb_str(device_peak_b),
        ]
|
||||
|
||||
def benchmark_c(self):
    """Run the prepared benchmark command line and format its results.

    Requires ``setup_cl`` to have populated ``self.benchmark_cl``.

    Returns:
        ``[iter/sec, ms/iter, host memory, device memory]`` where the two
        memory entries are formatted by ``_bytes_to_mb_str``.
    """
    measurements = run_benchmark_module(self.benchmark_cl)
    iter_per_second, host_peak_b, device_peak_b = measurements
    print(f"AMDShark-IREE-C benchmark:{iter_per_second} iter/second")
    throughput_str = f"{iter_per_second}"
    latency_str = f"{1000/iter_per_second}"
    return [
        throughput_str,
        latency_str,
        _bytes_to_mb_str(host_peak_b),
        _bytes_to_mb_str(device_peak_b),
    ]
|
||||
|
||||
def benchmark_python(self, inputs):
    """Time ``self.run("forward", ...)`` over the configured iteration count.

    Warm-up iterations (``amdshark_args.num_warmup_iterations``) run first
    and are not timed.

    Returns:
        ``[iter/sec, ms/iter]`` as strings.
    """
    input_list = list(inputs)

    warmups_left = amdshark_args.num_warmup_iterations
    while warmups_left > 0:
        self.run("forward", input_list)
        warmups_left -= 1

    begin = time.time()
    for _ in range(amdshark_args.num_iterations):
        self.run("forward", input_list)
    end = time.time()

    elapsed = end - begin
    print(
        f"AMDShark-IREE Python benchmark:{amdshark_args.num_iterations/elapsed} iter/second, Total Iterations:{amdshark_args.num_iterations}"
    )
    throughput = f"{amdshark_args.num_iterations/elapsed}"
    latency_ms = f"{(elapsed/amdshark_args.num_iterations)*1000}"
    return [throughput, latency_ms]
|
||||
|
||||
def benchmark_onnx(self, modelname, inputs):
    """Benchmark ``modelname`` through ONNX Runtime's transformers benchmark.

    Returns ``["N/A", "N/A"]`` when ``self.device`` is "cuda" or when the
    model is not listed in ONNX Runtime's ``MODELS`` table; otherwise returns
    the ``QPS`` and ``average_latency_ms`` fields of the first result.
    ``inputs`` is not used.
    """
    unavailable = ["N/A", "N/A"]
    if self.device == "cuda":
        print(
            "Currently GPU benchmarking on ONNX is not supported in AMDSHARK."
        )
        return unavailable

    from onnxruntime.transformers.benchmark import run_onnxruntime
    from onnxruntime.transformers.huggingface_models import MODELS
    from onnxruntime.transformers.benchmark_helper import (
        ConfigModifier,
        Precision,
    )
    import psutil

    # MiniLM is benchmarked under the bert-base-uncased entry.
    if modelname == "microsoft/MiniLM-L12-H384-uncased":
        modelname = "bert-base-uncased"
    if modelname not in MODELS:
        print(
            f"{modelname} is currently not supported in ORT's HF. Check "
            "https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/huggingface_models.py "
            "for currently supported models. Exiting benchmark ONNX."
        )
        return unavailable

    # Positional settings for run_onnxruntime, in call order.
    use_gpu = self.device == "cuda"
    provider = None
    num_threads = psutil.cpu_count(logical=False)
    batch_sizes = [1]
    sequence_lengths = [128]
    input_counts = [1]
    optimize_onnx = True
    validate_onnx = False
    cache_dir = os.path.join(".", "cache_models")
    onnx_dir = os.path.join(".", "onnx_models")
    verbose = False
    overwrite = False
    disable_ort_io_binding = False
    use_raw_attention_mask = True
    model_fusion_statistics = {}
    model_source = "pt"  # Either "pt" or "tf"
    config_modifier = ConfigModifier(None)
    onnx_args = OnnxFusionOptions()

    result = run_onnxruntime(
        use_gpu,
        provider,
        (modelname,),
        None,
        config_modifier,
        Precision.FLOAT32,
        num_threads,
        batch_sizes,
        sequence_lengths,
        amdshark_args.num_iterations,
        input_counts,
        optimize_onnx,
        validate_onnx,
        cache_dir,
        onnx_dir,
        verbose,
        overwrite,
        disable_ort_io_binding,
        use_raw_attention_mask,
        model_fusion_statistics,
        model_source,
        onnx_args,
    )
    first = result[0]
    print(
        f"ONNX ORT-benchmark:{first['QPS']} iter/second, Total Iterations:{amdshark_args.num_iterations}"
    )
    return [first["QPS"], first["average_latency_ms"]]
|
||||
|
||||
def get_metadata(self, modelname):
    """Look up ``modelname`` in ``./tank/model_metadata.csv``.

    The first CSV row is treated as a header and skipped. For the first row
    whose first column equals ``modelname``, columns 3, 4 and 5 are returned
    as ``[param_count, model_tags, model_notes]``. ``None`` is returned when
    no row matches.
    """
    metadata_path = os.path.join(".", "tank", "model_metadata.csv")
    with open(metadata_path, mode="r") as csvfile:
        rows = csv.reader(csvfile, delimiter=",")
        next(rows)  # header row
        for row in rows:
            if row[0] != modelname:
                continue
            return [row[3], row[4], row[5]]
    return None
|
||||
|
||||
def compare_bench_results(self, baseline: str, result: str):
    """Express ``baseline / result`` as a string such as "1.04x baseline".

    Returns "N/A" when ``baseline`` is ``None``. Both arguments are converted
    with ``float`` and the ratio is rounded to two decimals.
    """
    if baseline is None:
        return "N/A"
    ratio = float(baseline) / float(result)
    return f"{round(ratio, 2)}x baseline"
|
||||
|
||||
def benchmark_all_csv(
    self,
    inputs: tuple,
    modelname,
    dynamic,
    device_str,
    frontend,
    import_args,
    mode="native",
):
    """Run the selected benchmark engines and append rows to bench_results.csv.

    Args:
        inputs: Model inputs; forwarded to ``setup_cl`` and
            ``benchmark_python``. ``inputs[0].dtype`` is recorded as the data
            type unless "fp16" appears in ``modelname``.
        modelname: Model name written to the "model" column.
        dynamic: When equal to ``True`` the shape type is "dynamic",
            otherwise "static".
        device_str: Value written to the "device" column.
        frontend: Frontend name; used as the engine label for the baseline
            run and passed to ``check_requirements``.
        import_args: Mapping that must contain "batch_size"; also stored on
            ``self.import_args`` for the frontend benchmarks.
        mode: "native" (AMDShark engines), "baseline" (frontend only) or
            "all" (frontend followed by the AMDShark engines).

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    self.setup_cl(inputs)
    self.import_args = import_args
    self.mode = mode
    field_names = [
        "model",
        "batch_size",
        "engine",
        "dialect",
        "device",
        "shape_type",
        "data_type",
        "iter/sec",
        "ms/iter",
        "vs. PyTorch/TF",
        "iterations",
        "param_count",
        "tags",
        "notes",
        "datetime",
        "host_memory_mb",
        "device_memory_mb",
        "measured_host_memory_mb",
        "measured_device_memory_mb",
    ]
    # "frontend" must be the first element so that its result is available
    # as the baseline for the engines that follow it.
    engines_by_mode = {
        "native": ["amdshark_python", "amdshark_iree_c"],
        "baseline": ["frontend"],
        "all": ["frontend", "amdshark_python", "amdshark_iree_c"],
    }
    if self.mode not in engines_by_mode:
        raise ValueError(
            f"Unsupported benchmark mode {self.mode!r}; "
            f"expected one of {sorted(engines_by_mode)}."
        )
    engines = list(engines_by_mode[self.mode])

    if amdshark_args.onnx_bench == True:
        engines.append("onnxruntime")

    # Write the header once, the first time the results file is created.
    if not os.path.exists("bench_results.csv"):
        with open("bench_results.csv", mode="w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(field_names)

    with open("bench_results.csv", mode="a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=field_names)
        bench_info = {}
        bench_info["model"] = modelname
        bench_info["batch_size"] = str(import_args["batch_size"])
        bench_info["dialect"] = self.mlir_dialect
        bench_info["iterations"] = amdshark_args.num_iterations
        if dynamic == True:
            bench_info["shape_type"] = "dynamic"
        else:
            bench_info["shape_type"] = "static"
        bench_info["device"] = device_str
        if "fp16" in modelname:
            bench_info["data_type"] = "float16"
        else:
            bench_info["data_type"] = inputs[0].dtype

        # Reset the baseline once per call, not once per engine: resetting
        # inside the loop discarded the frontend timing before the other
        # engines could be compared against it.
        self.frontend_result = None
        for e in engines:
            engine_result = {}
            if e == "frontend":
                engine_result["engine"] = frontend
                if check_requirements(frontend):
                    (
                        engine_result["iter/sec"],
                        engine_result["ms/iter"],
                        engine_result["host_memory_mb"],
                        engine_result["device_memory_mb"],
                    ) = self.benchmark_frontend(modelname)
                    self.frontend_result = engine_result["ms/iter"]
                    engine_result["vs. PyTorch/TF"] = "baseline"
                    (
                        engine_result["param_count"],
                        engine_result["tags"],
                        engine_result["notes"],
                    ) = self.get_metadata(modelname)
                else:
                    # Frontend unavailable: no baseline and no row written.
                    self.frontend_result = None
                    continue

            elif e == "amdshark_python":
                engine_result["engine"] = "amdshark_python"
                (
                    engine_result["iter/sec"],
                    engine_result["ms/iter"],
                ) = self.benchmark_python(inputs)

                engine_result["vs. PyTorch/TF"] = self.compare_bench_results(
                    self.frontend_result, engine_result["ms/iter"]
                )

            elif e == "amdshark_iree_c":
                engine_result["engine"] = "amdshark_iree_c"
                (
                    engine_result["iter/sec"],
                    engine_result["ms/iter"],
                    engine_result["host_memory_mb"],
                    engine_result["device_memory_mb"],
                ) = self.benchmark_c()

                engine_result["vs. PyTorch/TF"] = self.compare_bench_results(
                    self.frontend_result, engine_result["ms/iter"]
                )

            elif e == "onnxruntime":
                engine_result["engine"] = "onnxruntime"
                (
                    engine_result["iter/sec"],
                    engine_result["ms/iter"],
                ) = self.benchmark_onnx(modelname, inputs)

            engine_result["datetime"] = str(datetime.now())
            writer.writerow(bench_info | engine_result)
|
||||
@@ -1,241 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import import_with_fx, save_mlir
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
from typing import List, Tuple
|
||||
from io import BytesIO
|
||||
from brevitas_examples.common.generative.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
# fmt: off
|
||||
# Shape function for quant.matmul_rhs_group_quant: the output keeps the
# leading dimensions of `lhs` and takes its last dimension from rhs[0].
# Only a rank-2 `rhs` combined with a rank-2 or rank-3 `lhs` is supported.
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
    if len(rhs) == 2:
        if len(lhs) == 3:
            return [lhs[0], lhs[1], rhs[0]]
        if len(lhs) == 2:
            return [lhs[0], rhs[0]]
    raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
# Dtype function for quant.matmul_rhs_group_quant: each *_rank_dtype argument
# is a (rank, dtype) pair and the result is the dtype of the lhs float input.
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
    output_dtype = lhs_rank_dtype[1]
    return output_dtype
|
||||
|
||||
|
||||
# Companion entry for quant.matmul_rhs_group_quant; the body is intentionally
# empty and only the presence of the function matters to the library list.
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
    return None
|
||||
|
||||
|
||||
# Shape, dtype and value-semantics functions for the custom quantized matmul
# op, passed to torch_mlir.compile as `extra_library`.
brevitas_matmul_rhs_group_quant_library = [
    quant〇matmul_rhs_group_quant〡shape,
    quant〇matmul_rhs_group_quant〡dtype,
    quant〇matmul_rhs_group_quant〡has_value_semantics,
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
    """Load ``<cwd>/<extended_model_name>.vmfb`` if that file exists.

    Returns an ``AMDSharkInference`` with the module loaded, or ``None`` when
    no such file is present in the current working directory.
    """
    vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
    if not os.path.isfile(vmfb_path):
        return None

    loaded_module = AMDSharkInference(
        None,
        device=device,
        mlir_dialect=mlir_dialect,
    )
    print(f"loading existing vmfb from: {vmfb_path}")
    loaded_module.load_module(vmfb_path, extra_args=extra_args)
    return loaded_module
|
||||
|
||||
|
||||
def compile_module(
    amdshark_module, extended_model_name, generate_vmfb, extra_args=[]
):
    """Compile ``amdshark_module``, reusing or producing a .vmfb when asked.

    With ``generate_vmfb`` false the module is compiled in memory. Otherwise
    ``<cwd>/<extended_model_name>.vmfb`` is loaded if it exists, or created
    through ``save_module`` and then loaded. The same module is returned.
    """
    if not generate_vmfb:
        amdshark_module.compile(extra_args)
        return amdshark_module

    vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
    if os.path.isfile(vmfb_path):
        print(f"loading existing vmfb from: {vmfb_path}")
        module_path = vmfb_path
    else:
        print(f"No vmfb found. Compiling and saving to {vmfb_path}")
        module_path = amdshark_module.save_module(
            os.getcwd(), extended_model_name, extra_args
        )
    amdshark_module.load_module(module_path, extra_args=extra_args)
    return amdshark_module
|
||||
|
||||
|
||||
def compile_int_precision(
    model, inputs, precision, device, generate_vmfb, extended_model_name
):
    """Quantize ``model`` weights, lower it to linalg and write the IR to disk.

    Two files are written to the current working directory:
    ``<extended_model_name>_linalg.mlir`` (the module's assembly) and
    ``<extended_model_name>_linalg.mlirbc`` (``str(module)`` encoded as
    UTF-8).

    Args:
        model: Model to quantize in place through ``quantize_model``.
        inputs: Example inputs used for import and compilation.
        precision: "int4" selects 4-bit weights; anything else selects 8-bit.
        device: Unused here; kept for interface compatibility.
        generate_vmfb: Unused here; kept for interface compatibility.
        extended_model_name: Stem for the output file names.

    Returns:
        Path of the written ``.mlirbc`` file.
    """
    weight_bit_width = 4 if precision == "int4" else 8
    weight_group_size = 128
    quantize_model(
        get_model_impl(model),
        dtype=torch.float32,
        weight_quant_type="asym",
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float_scale",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,
        quantize_weight_zero_point=False,
        input_bit_width=None,
        input_scale_type="float",
        input_param_method="stats",
        input_quant_type="asym",
        input_quant_granularity="per_tensor",
        quantize_input_zero_point=False,
        seqlen=2048,
    )
    print("Weight quantization applied.")
    torchscript_module = import_with_fx(
        model,
        inputs,
        precision=precision,
        mlir_type="torchscript",
    )
    mlir_module = torch_mlir.compile(
        torchscript_module,
        inputs,
        output_type="torch",
        backend_legal_ops=["quant.matmul_rhs_group_quant"],
        extra_library=brevitas_matmul_rhs_group_quant_library,
        use_tracing=False,
        verbose=False,
    )
    print("[DEBUG] converting torch to linalg")
    run_pipeline_with_repro_report(
        mlir_module,
        "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
        description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
    )

    mlir_file_path = os.path.join(
        os.getcwd(), f"{extended_model_name}_linalg.mlir"
    )
    with open(mlir_file_path, "w") as f:
        print(mlir_module.operation.get_asm(), file=f)

    bytecode_path = os.path.join(
        os.getcwd(), f"{extended_model_name}_linalg.mlirbc"
    )
    with open(bytecode_path, "wb") as f:
        # Encode straight into the file so no extra copies of the module
        # text stay alive.
        f.write(str(mlir_module).encode("UTF-8"))
    print(f"Elided IR written for {extended_model_name}")
    # The caller builds the AMDSharkInference and compiles it from this
    # path; the statements that used to follow this return were unreachable.
    return bytecode_path
|
||||
|
||||
|
||||
def amdshark_compile_through_fx(
    model,
    inputs,
    extended_model_name,
    precision,
    f16_input_mask=None,
    save_dir=tempfile.gettempdir(),
    debug=False,
    generate_or_load_vmfb=True,
    extra_args=[],
    device=None,
    mlir_dialect="tm_tensor",
):
    """Import ``model`` through FX, compile it and return the compiled module.

    When ``generate_or_load_vmfb`` is true and a matching .vmfb already
    exists in the working directory, it is loaded and returned immediately.

    Args:
        model: Model to import.
        inputs: Example inputs for the import.
        extended_model_name: Stem used for the .vmfb and MLIR file names.
        precision: "fp16" enables f16 import; "int4"/"int8" go through
            ``compile_int_precision``.
        f16_input_mask: Forwarded to ``import_with_fx``.
        save_dir: Forwarded to ``import_with_fx``.
        debug: Forwarded to ``import_with_fx``.
        generate_or_load_vmfb: Reuse or produce a .vmfb on disk.
        extra_args: Extra compiler flags. Replaced by a fixed flag list for
            the "int4"/"int8" precisions.
        device: Target device string; may be ``None``.
        mlir_dialect: Dialect passed to ``save_mlir`` and ``AMDSharkInference``.

    Returns:
        ``(module, mlir_module)``; ``mlir_module`` is ``None`` when an
        existing .vmfb was loaded.
    """
    is_f16 = precision == "fp16"
    if generate_or_load_vmfb:
        amdshark_module = load_vmfb(
            extended_model_name=extended_model_name,
            device=device,
            mlir_dialect=mlir_dialect,
            extra_args=extra_args,
        )
        if amdshark_module:
            return (
                amdshark_module,
                None,
            )

    from amdshark.parser import amdshark_args

    # `device` defaults to None, and `"cuda" in None` raises TypeError.
    if device is not None and "cuda" in device:
        amdshark_args.enable_tf32 = True

    if precision in ["int4", "int8"]:
        mlir_module = compile_int_precision(
            model,
            inputs,
            precision,
            device,
            generate_or_load_vmfb,
            extended_model_name,
        )
        # Integer precisions always compile with this fixed flag set.
        extra_args = [
            "--iree-hal-dump-executable-sources-to=ies",
            "--iree-vm-target-truncate-unsupported-floats",
            "--iree-codegen-check-ir-before-llvm-conversion=false",
            "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
        ]
    else:
        (
            bytecode,
            _,
        ) = import_with_fx(
            model=model,
            inputs=inputs,
            is_f16=is_f16,
            f16_input_mask=f16_input_mask,
            debug=debug,
            model_name=extended_model_name,
            save_dir=save_dir,
        )
        mlir_module = save_mlir(
            mlir_module=bytecode,
            model_name=extended_model_name,
            mlir_dialect=mlir_dialect,
        )

    amdshark_module = AMDSharkInference(
        mlir_module,
        device=device,
        mlir_dialect=mlir_dialect,
    )
    return (
        compile_module(
            amdshark_module,
            extended_model_name,
            generate_vmfb=generate_or_load_vmfb,
            extra_args=extra_args,
        ),
        mlir_module,
    )
|
||||
@@ -1,297 +0,0 @@
|
||||
# Lint as: python3
|
||||
"""AMDSHARK Downloader"""
|
||||
# Requirements : Put amdshark_tank in AMDSHARK directory
|
||||
# /AMDSHARK
|
||||
# /gen_amdshark_tank
|
||||
# /tflite
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
# /tf
|
||||
# /pytorch
|
||||
#
|
||||
#
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
from tqdm.std import tqdm
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from amdshark.parser import amdshark_args
|
||||
from google.cloud import storage
|
||||
|
||||
|
||||
def download_public_file(
    full_gs_url, destination_folder_name, single_file=False
):
    """Download public blobs from a ``gs://bucket/path`` URL.

    Args:
        full_gs_url: ``gs://<bucket>/<prefix>`` URL. With ``single_file`` the
            last path component is the blob's file name.
        destination_folder_name: Target folder, or the full target file path
            when ``single_file`` is true.
        single_file: Download only the blob whose base name matches the last
            component of ``full_gs_url`` and save it under the file name
            taken from ``destination_folder_name``.

    Every blob is written directly into the destination folder under its
    base name; sub-paths inside the prefix are not recreated.
    """
    # bucket_name = "gs://your-bucket-name/path/to/file"
    # destination_file_name = "local/path/to/file"

    storage_client = storage.Client.create_anonymous_client()
    url_parts = full_gs_url.split("/")
    bucket_name = url_parts[2]
    dest_filename = None
    desired_file = None
    if single_file:
        desired_file = url_parts[-1]
        source_blob_name = "/".join(url_parts[3:-1])
        destination_folder_name, dest_filename = os.path.split(
            destination_folder_name
        )
    else:
        source_blob_name = "/".join(url_parts[3:])
    bucket = storage_client.bucket(bucket_name)
    blobs = bucket.list_blobs(prefix=source_blob_name)
    # makedirs creates missing parents and tolerates a concurrent creation.
    # An empty folder (bare file name in single-file mode) means the current
    # directory, so there is nothing to create.
    if destination_folder_name and not os.path.exists(destination_folder_name):
        os.makedirs(destination_folder_name, exist_ok=True)
    for blob in blobs:
        blob_name = blob.name.split("/")[-1]
        if single_file:
            if blob_name != desired_file:
                continue
            destination_filename = os.path.join(
                destination_folder_name, dest_filename
            )
        else:
            destination_filename = os.path.join(
                destination_folder_name, blob_name
            )
            # Skips entries whose target resolves to an existing directory.
            if os.path.isdir(destination_filename):
                continue
        with open(destination_filename, "wb") as f:
            # wrapattr reports progress on every write into the file.
            with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
                storage_client.download_blob_to_file(blob, file_obj)
|
||||
|
||||
|
||||
# Maps an input type name to the corresponding NumPy scalar type.
input_type_to_np_dtype = dict(
    float32=np.float32,
    float64=np.float64,
    bool=np.bool_,
    int32=np.int32,
    int64=np.int64,
    uint8=np.uint8,
    int8=np.int8,
)
|
||||
|
||||
# Save the model in the user's home directory so it needn't be fetched every time in the CI.
|
||||
# Resolve the local amdshark_tank cache directory (WORKDIR), in priority order:
#   1. the --local_tank_cache flag,
#   2. a locally generated ../gen_amdshark_tank/ next to this file,
#   3. ~/.local/amdshark_tank/.
home = str(Path.home())
alt_path = os.path.join(os.path.dirname(__file__), "../gen_amdshark_tank/")
custom_path = amdshark_args.local_tank_cache

if custom_path is not None:
    if not os.path.exists(custom_path):
        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(custom_path, exist_ok=True)

    WORKDIR = custom_path

    print(f"Using {WORKDIR} as local amdshark_tank cache directory.")

elif os.path.exists(alt_path):
    WORKDIR = alt_path
    print(
        f"Using {WORKDIR} as amdshark_tank directory. Delete this directory if you aren't working from locally generated amdshark_tank."
    )
else:
    WORKDIR = os.path.join(home, ".local/amdshark_tank/")
    print(
        f"amdshark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache= flag"
    )
os.makedirs(WORKDIR, exist_ok=True)
|
||||
|
||||
|
||||
# Checks whether the directory and files exists.
|
||||
def check_dir_exists(model_name, frontend="torch", dynamic=""):
    """Return True when all artifacts for ``model_name`` exist under WORKDIR.

    The directory ``WORKDIR/<model_name>`` must contain the MLIR file
    ``<name><dynamic>_<frontend>.mlir`` plus ``function_name.npy``,
    ``inputs.npz``, ``golden_out.npz`` and ``hash.npy``. For the MLIR file
    name, a fixed number of trailing characters is dropped from
    ``model_name`` depending on ``frontend``, except for names containing
    "clip", "unet" or "vae".
    """
    model_dir = os.path.join(WORKDIR, model_name)

    # Remove the _tf keyword from end only for non-SD models.
    is_sd_model = any(tag in model_name for tag in ["clip", "unet", "vae"])
    if not is_sd_model:
        # Number of trailing characters dropped per frontend.
        suffix_lengths = {
            "tf": 3,
            "tensorflow": 3,
            "tflite": 7,
            "torch": 6,
            "pytorch": 6,
        }
        drop = suffix_lengths.get(frontend)
        if drop is not None:
            model_name = model_name[:-drop]

    model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir"

    if not os.path.isdir(model_dir):
        return False

    required_files = (
        model_mlir_file_name,
        "function_name.npy",
        "inputs.npz",
        "golden_out.npz",
        "hash.npy",
    )
    for file_name in required_files:
        if not os.path.isfile(os.path.join(model_dir, file_name)):
            return False
    print(f"""Model artifacts for {model_name} found at {WORKDIR}...""")
    return True
|
||||
|
||||
|
||||
def _internet_connected():
    """Return True if an HTTP GET to http://1.1.1.1 completes, else False.

    The request is bounded by a timeout so an unresponsive network cannot
    block the caller indefinitely. Any ordinary failure counts as "not
    connected".
    """
    import requests as req

    try:
        req.get("http://1.1.1.1", timeout=5)
        return True
    except Exception:
        # Best-effort probe; KeyboardInterrupt/SystemExit still propagate.
        return False
|
||||
|
||||
|
||||
def get_git_revision_short_hash() -> str:
    """Return the tank prefix to look for in the amdshark_tank bucket.

    Despite the name, no git command is run: the value is
    ``amdshark_args.amdshark_prefix`` when that is set, otherwise the
    "version" entry of ``../tank_version.json`` relative to this file.
    """
    if amdshark_args.amdshark_prefix is not None:
        prefix_kw = amdshark_args.amdshark_prefix
    else:
        import json

        dir_path = os.path.dirname(os.path.realpath(__file__))
        src = os.path.join(dir_path, "..", "tank_version.json")
        with open(src, "r") as f:
            data = json.load(f)
        prefix_kw = data["version"]
    print(f"Checking for updates from gs://amdshark_tank/{prefix_kw}")
    return prefix_kw
|
||||
|
||||
|
||||
def get_amdsharktank_prefix():
    """Pick the top-level folder of the amdshark_tank bucket to download from.

    Returns "none" without network access. Otherwise returns the first
    top-level blob folder whose name contains the desired prefix, or
    "nightly" when nothing matches.
    """
    if not _internet_connected():
        print(
            "No internet connection. Using the model already present in the tank."
        )
        return "none"

    desired_prefix = get_git_revision_short_hash()
    anonymous_client = storage.Client.create_anonymous_client()
    tank_bucket = anonymous_client.bucket("amdshark_tank")

    tank_prefix = ""
    for blob in tank_bucket.list_blobs(prefix=f"{desired_prefix}"):
        top_level = blob.name.split("/")[0]
        if desired_prefix in top_level:
            tank_prefix = top_level
            break

    if tank_prefix == "":
        print(
            f"amdshark_tank bucket not found matching ({desired_prefix}). Defaulting to nightly."
        )
        tank_prefix = "nightly"
    return tank_prefix
|
||||
|
||||
|
||||
# Downloads the torch model from gs://amdshark_tank dir.
|
||||
def download_model(
    model_name,
    dynamic=False,
    tank_url=None,
    frontend=None,
    tuned=None,
    import_args={"batch_size": 1},
):
    """Fetch (or reuse) model artifacts and load them from the local tank.

    Artifacts are downloaded when they are missing locally or when
    ``amdshark_args.force_update_tank`` is set. Otherwise the local hash is
    compared with the upstream one and artifacts are refreshed only if they
    differ and ``amdshark_args.update_tank`` is set. If the MLIR file is
    still missing afterwards, local generation is attempted.

    Returns:
        ``(mlir_filename, function_name, inputs_tuple, golden_out_tuple)``.
    """
    model_name = model_name.replace("/", "_")
    dyn_str = "_dynamic" if dynamic else ""
    os.makedirs(WORKDIR, exist_ok=True)
    amdshark_args.amdshark_prefix = get_amdsharktank_prefix()

    # Directory name: "<model>_<frontend>[_BS<n>]", except for SD sub-models.
    if import_args["batch_size"] and import_args["batch_size"] != 1:
        model_dir_name = (
            model_name + "_" + frontend + "_BS" + str(import_args["batch_size"])
        )
    elif any(model in model_name for model in ["clip", "unet", "vae"]):
        # TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation.
        model_dir_name = model_name
    else:
        model_dir_name = model_name + "_" + frontend
    model_dir = os.path.join(WORKDIR, model_dir_name)

    if not tank_url:
        tank_url = "gs://amdshark_tank/" + amdshark_args.amdshark_prefix

    full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name

    artifacts_present = check_dir_exists(
        model_dir_name, frontend=frontend, dynamic=dyn_str
    )
    if not artifacts_present:
        print(
            f"Downloading artifacts for model {model_name} from: {full_gs_url}"
        )
        download_public_file(full_gs_url, model_dir)
    elif amdshark_args.force_update_tank == True:
        print(
            f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
        )
        download_public_file(full_gs_url, model_dir)
    elif not _internet_connected():
        print(
            "No internet connection. Using the model already present in the tank."
        )
    else:
        # Compare the cached hash with the one published upstream.
        local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
        gs_hash_url = tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy"
        upstream_hash_path = os.path.join(model_dir, "upstream_hash.npy")
        download_public_file(
            gs_hash_url,
            os.path.join(model_dir, "upstream_hash.npy"),
            single_file=True,
        )
        try:
            upstream_hash = str(np.load(upstream_hash_path))
        except FileNotFoundError:
            print(f"Model artifact hash not found at {model_dir}.")
            upstream_hash = None

        hashes_differ = local_hash != upstream_hash
        if not hashes_differ:
            print(
                "Local and upstream hashes match. Using cached model artifacts."
            )
        elif amdshark_args.update_tank == True:
            print(f"Updating artifacts for model {model_name}...")
            download_public_file(full_gs_url, model_dir)
        else:
            print(
                "Hash does not match upstream in gs://amdshark_tank/. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
            )

    model_dir = os.path.join(WORKDIR, model_dir_name)
    tuned_str = "" if tuned is None else "_" + tuned
    suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
    mlir_filename = os.path.join(model_dir, model_name + suffix)
    print(
        f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
    )
    if not os.path.exists(mlir_filename):
        from tank.generate_amdsharktank import gen_amdshark_files

        print(
            "The model data was not found. Trying to generate artifacts locally."
        )
        gen_amdshark_files(model_name, frontend, WORKDIR, import_args)

    assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
    function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
    inputs = np.load(os.path.join(model_dir, "inputs.npz"))
    golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))

    inputs_tuple = tuple(inputs[key] for key in inputs)
    golden_out_tuple = tuple(golden_out[key] for key in golden_out)
    return mlir_filename, function_name, inputs_tuple, golden_out_tuple
|
||||
@@ -1,212 +0,0 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
from amdshark.amdshark_importer import import_with_fx, save_mlir
|
||||
import torchvision.models as models
|
||||
import copy
|
||||
import io
|
||||
import numpy as np
|
||||
import sys
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.node import Node
|
||||
from typing import Dict
|
||||
import torch_mlir
|
||||
|
||||
|
||||
def amdshark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
    """Compile an FX graph module into an ``AMDSharkInference`` module.

    The graph is compiled with torch-mlir to linalg-on-tensors, serialized to
    bytecode, saved through ``save_mlir`` and compiled for ``device``.
    """
    linalg_module = torch_mlir.compile(
        fx_g, inputs, output_type="linalg-on-tensors"
    )
    buffer = io.BytesIO()
    linalg_module.operation.write_bytecode(buffer)
    bytecode_path = save_mlir(
        buffer.getvalue(),
        model_name="amdshark_eager_module",
        frontend="torch",
        mlir_dialect="tm_tensor",
    )
    from amdshark.amdshark_inference import AMDSharkInference

    compiled = AMDSharkInference(
        mlir_module=bytecode_path,
        device=device,
        mlir_dialect="tm_tensor",
    )
    compiled.compile(extra_args=[])
    return compiled
|
||||
|
||||
|
||||
def _make_single_op_gm(node, captured_val, compiled_graph):
    """Compile ``node`` on its own and record the result in ``compiled_graph``.

    A fresh FX graph is built with one placeholder per positional argument of
    ``node`` that is truthy and has a ``name`` attribute, followed by a copy
    of ``node`` and an output. The example inputs for compilation are looked
    up by name in ``captured_val``; list/tuple values are flattened into the
    input list.

    Args:
        node: FX node to isolate.
        captured_val: Mapping from node name to a previously captured value.
        compiled_graph: Dict updated in place with an entry keyed by
            ``node.name`` holding "module", "inputs" and "result".
    """
    graph = torch.fx.Graph()
    placeholders = {}
    example_inputs = []
    for operand in node.args:
        if not (operand and hasattr(operand, "name")):
            continue
        placeholders[operand.name] = graph.placeholder(operand.name)
        captured = captured_val[operand.name]
        if isinstance(captured, (list, tuple)):
            example_inputs.extend(captured)
        else:
            example_inputs.append(captured)

    copied = graph.node_copy(node, lambda n: placeholders[n.name])
    graph.output(copied)
    graph.lint()

    single_op_module = torch.fx.GraphModule(torch.nn.Module(), graph)
    compiled_graph[node.name] = {
        "module": amdshark_backend(single_op_module, example_inputs),
        "inputs": list(placeholders),
        "result": None,
    }
|
||||
|
||||
|
||||
def compiled_graph(gm: torch.fx.GraphModule, attr_info):
    """Build a per-node execution table for every ``call_function`` node.

    Args:
        gm: FX graph module whose nodes are walked in graph order.
        attr_info: Mapping from node name to captured value, forwarded to
            ``_make_single_op_gm``.

    Returns:
        Dict keyed by node name. ``aten.empty`` nodes store their target,
        args and kwargs; nodes whose name starts with "getitem" store the
        source node name and index; every other call_function node is
        compiled individually by ``_make_single_op_gm``.
    """
    table = {}
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue

        if node.target in [torch.ops.aten.empty]:
            # Currently torch.aten.empty has a compilation issue, so it is
            # recorded here to be run natively.
            table[node.name] = {
                "target": node.target,
                "args": node.args,
                "kwargs": node.kwargs,
                "result": None,
            }
        elif node.name.startswith("getitem"):
            # getitem is a simple case: it takes a tuple and returns the
            # tensor at a particular index.
            table[node.name] = {
                "input": node.args[0].name,
                "pos": node.args[1],
                "result": None,
            }
        else:
            _make_single_op_gm(node, attr_info, table)

    return table
|
||||
|
||||
|
||||
class ShapeProp:
    """
    Node-by-node interpreter for a `GraphModule` that records shapes.

    `propagate` runs the graph with the supplied arguments. Whenever a node
    evaluates to a tensor, the tensor's shape and element type are stored on
    that node as the `shape` and `dtype` attributes.
    """

    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        """Execute the graph and return a dict of node name -> computed value."""
        pending_inputs = iter(args)
        env: Dict[str, Node] = {}

        def load_arg(a):
            # Swap every Node inside `a` for the value already computed for it.
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target: str):
            # Walk the dotted path one attribute at a time from the root module.
            target_atoms = target.split(".")
            current = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(current, atom):
                    raise RuntimeError(
                        f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
                    )
                current = getattr(current, atom)
            return current

        for node in self.graph.nodes:
            kind = node.op
            if kind == "placeholder":
                result = next(pending_inputs)
            elif kind == "get_attr":
                result = fetch_attr(node.target)
            elif kind == "call_function":
                result = node.target(
                    *load_arg(node.args), **load_arg(node.kwargs)
                )
            elif kind == "call_method":
                receiver, *call_args = load_arg(node.args)
                call_kwargs = load_arg(node.kwargs)
                result = getattr(receiver, node.target)(
                    *call_args, **call_kwargs
                )
            elif kind == "call_module":
                submodule = self.modules[node.target]
                result = submodule(
                    *load_arg(node.args), **load_arg(node.kwargs)
                )
            # Any other op (e.g. "output") leaves `result` at the value
            # computed for the previous node.

            # The only shape-propagation-specific step: without it this is a
            # generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return env
|
||||
|
||||
# return load_arg(self.graph.result)
|
||||
|
||||
|
||||
# Example driver: run ResNet-18 one FX node at a time through the compiled
# per-node modules built above. Everything below executes at import time.
resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
# NOTE(review): `input` shadows the builtin and is rebound by the inner loop
# further down.
input = (torch.randn(1, 3, 224, 224),)

# Reference output from the unmodified model.
print(resnet18(input[0]))

fx_graph = import_with_fx(resnet18, input, mlir_type="fx")

shape_prop = ShapeProp(fx_graph)

# `x` maps every node name to the value computed while interpreting the graph.
x = shape_prop.propagate(input[0])

amdshark_graph = compiled_graph(fx_graph, x)


# Evaluate the table in insertion order; each entry's "result" is filled from
# earlier entries, or from `x` for names that have no entry in the table.
for key in amdshark_graph:
    if key.startswith("getitem"):
        # Index into a previously produced tuple/list.
        input_val = amdshark_graph[key]["input"]
        pos = amdshark_graph[key]["pos"]
        if input_val not in amdshark_graph:
            amdshark_graph[key]["result"] = x[input_val][pos].detach()
        else:
            amdshark_graph[key]["result"] = amdshark_graph[input_val]["result"][
                pos
            ].detach()
    elif key.startswith("empty"):
        # Entries recorded for aten.empty are executed natively.
        operator = amdshark_graph[key]["target"]
        args = amdshark_graph[key]["args"]
        kwargs = amdshark_graph[key]["kwargs"]
        amdshark_graph[key]["result"] = operator(*args, **kwargs).detach()
    else:
        # Compiled single-op module: gather its inputs by name, then invoke it.
        input_val = amdshark_graph[key]["inputs"]
        input_tensors = []
        for input in input_val:
            if input not in amdshark_graph:
                input_tensors.append(x[input].detach())
            else:
                input_tensors.append(amdshark_graph[input]["result"])

        val = amdshark_graph[key]["module"]("forward", input_tensors)
        # Results are passed to torch.from_numpy, so they are presumably
        # numpy arrays - TODO confirm against AMDSharkInference.
        if isinstance(val, (tuple, list)):
            list_val = []
            for v in val:
                list_val.append(torch.from_numpy(v))
            amdshark_graph[key]["result"] = list_val
        else:
            amdshark_graph[key]["result"] = torch.from_numpy(val)


print(amdshark_graph)
|
||||
@@ -1,153 +0,0 @@
|
||||
import re
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
import torch_mlir
|
||||
from iree.compiler import compile_file
|
||||
from amdshark.amdshark_importer import import_with_fx, get_f16_inputs, save_mlir
|
||||
|
||||
|
||||
class GenerateConfigFile:
    """Generate a JSON sharding configuration for a model.

    The config maps either every leaf layer (``split_into_layers``) or every
    flow dispatch (``split_into_dispatches``) to one value per sharding
    stage id and writes it to ``config_file_path``.

    Args:
        model: Model to describe. ``split_into_layers`` needs
            ``named_modules()``.
        num_sharding_stages: Number of sharding stages; must equal
            ``len(sharding_stages_id)``.
        sharding_stages_id: Name of each sharding stage (e.g. "gpu_id").
        units_in_each_stage: Number of units available in each stage, one
            entry per stage id.
        model_input: Example inputs, used only by ``split_into_dispatches``.
        config_file_path: Destination of the generated JSON file.
    """

    def __init__(
        self,
        model,
        num_sharding_stages: int,
        sharding_stages_id: list[str],
        units_in_each_stage: list[int],
        model_input=None,
        config_file_path="model_config.json",
    ):
        self.model = model
        self.num_sharding_stages = num_sharding_stages
        self.sharding_stages_id = sharding_stages_id
        assert self.num_sharding_stages == len(
            self.sharding_stages_id
        ), "Number of sharding stages should be equal to the list of their ID"
        self.model_input = model_input
        self.config_file_path = config_file_path
        # (Nithin) this is a quick fix - revisit and rewrite
        self.units_in_each_stage = np.array(units_in_each_stage)
        # Per-stage counter of layers visited so far; advanced by
        # split_into_layers and kept across calls.
        self.track_loop = np.zeros(len(self.sharding_stages_id)).astype(int)

    def split_into_dispatches(
        self,
        backend,
        fx_tracing_required=False,
        f16_model=False,
        torch_mlir_tracing=True,
    ):
        """Compile the model to the flow stage and write one entry per dispatch.

        Args:
            backend: Target backend name passed to ``compile_file``.
            fx_tracing_required: Run the model through ``import_with_fx``
                before handing it to ``torch_mlir.compile``.
            f16_model: Forwarded to ``import_with_fx`` as ``is_f16``.
            torch_mlir_tracing: Forwarded to ``torch_mlir.compile`` as
                ``use_tracing``.
        """
        graph_for_compilation = self.model
        if fx_tracing_required:
            graph_for_compilation = import_with_fx(
                self.model,
                self.model_input,
                is_f16=f16_model,
                f16_input_mask=[False, False],
                mlir_type="torchscript",
            )

        module = torch_mlir.compile(
            graph_for_compilation,
            (self.model_input),
            torch_mlir.OutputType.LINALG_ON_TENSORS,
            use_tracing=torch_mlir_tracing,
            verbose=False,
        )
        module = module.operation.get_asm(large_elements_limit=4)
        module_file = save_mlir(
            module,
            model_name="module_pre_split",
            frontend="torch",
            mlir_dialect="linalg",
        )
        compiled_module_str = str(
            compile_file(
                module_file,
                target_backends=[backend],
                extra_args=[
                    "--compile-to=flow",
                    "--mlir-elide-elementsattrs-if-larger=4",
                ],
            )
        )

        # Match the marker literally; unescaped, "." would match any character.
        substring_start_idx = [
            m.start()
            for m in re.finditer(
                re.escape("flow.dispatch @"), compiled_module_str
            )
        ]
        dispatch_list = {}

        # dispatch_no is the 'i'th index of a dispatch out of n total
        # dispatches of a model. dispatch_id is the unique id of a dispatch;
        # multiple instances of the same dispatch can occur in a model.
        for dispatch_no, substring_idx in enumerate(substring_start_idx):
            dispatch_idx = (
                compiled_module_str[substring_idx:]
                .split(":")[0]
                .split("@")[-1]
            )
            key = "dispatch_no_" + str(dispatch_no)
            dispatch_list[key] = {n: "None" for n in self.sharding_stages_id}
            dispatch_list[key]["dispatch_id"] = dispatch_idx

        self.generate_json(dispatch_list)

    def split_into_layers(self):
        """Write one entry per leaf module, assigning round-robin unit ids.

        For each stage the assigned id is
        ``layers_seen_so_far % units_in_that_stage``. Non-leaf modules are
        dropped from the config once one of their children is seen, but they
        still advance the round-robin counter.
        """
        model_dictionary = {}

        for name, _ in self.model.named_modules():
            if name == "":
                continue

            # Remove non-leaf nodes from the config as they aren't an operation
            parent_name = ".".join(name.split(".")[:-1])
            if parent_name in model_dictionary:
                del model_dictionary[parent_name]

            # By default embed increasing device ids for each layer.
            wraparound = np.asarray(self.track_loop % self.units_in_each_stage)
            # Take the first element of each stage's slice. This yields the
            # scalar itself for the documented 1-D ``units_in_each_stage``
            # (which previously raised IndexError) and matches the old
            # ``[idx][0][0]`` lookup for 3-D inputs.
            layer_dict = {
                stage_id: int(np.asarray(wraparound[idx]).flat[0])
                for idx, stage_id in enumerate(self.sharding_stages_id)
            }
            self.track_loop += 1
            model_dictionary[name] = layer_dict

        self.generate_json(model_dictionary)

    def generate_json(self, artifacts):
        """Serialize ``artifacts`` as JSON to ``self.config_file_path``."""
        with open(self.config_file_path, "w") as outfile:
            json.dump(artifacts, outfile)
|
||||
|
||||
|
||||
# Script entry point: build a layer-level sharding config for the Vicuna
# CombinedModel wrapper.
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
    # Dummy prompt of 17 "0" characters used only to obtain example input ids.
    compilation_prompt = "".join(["0" for _ in range(17)])
    compilation_input_ids = tokenizer(
        compilation_prompt,
        return_tensors="pt",
    ).input_ids
    # NOTE(review): the reshape assumes the tokenizer yields exactly 19 ids
    # for this prompt - confirm against the tokenizer in use.
    compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
        [1, 19]
    )
    firstVicunaCompileInput = (compilation_input_ids,)
    from apps.language_models.src.model_wrappers.vicuna_model import (
        FirstVicuna,
        SecondVicuna7B,
        CombinedModel,
    )

    model = CombinedModel()
    # NOTE(review): the fourth positional parameter of GenerateConfigFile is
    # `units_in_each_stage`, so the example input tuple is bound to it here
    # and `model_input` stays None - confirm whether that is intended.
    c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
    c.split_into_layers()
|
||||
@@ -1,819 +0,0 @@
|
||||
# Lint as: python3
|
||||
"""AMDSHARK Importer"""
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
def create_hash(file_name):
    """Return the hex BLAKE2b digest (64-byte digest size) of a file.

    The file is streamed in 1 MiB chunks so large artifacts are hashed
    without loading them whole and without the per-call overhead of the
    previous 1 KiB reads. The digest does not depend on the chunk size.

    Args:
        file_name: Path of the file to hash.

    Returns:
        The digest as a hexadecimal string.

    Raises:
        OSError: If the file cannot be opened or read (for example
            FileNotFoundError when it does not exist).
    """
    file_hash = hashlib.blake2b(digest_size=64)
    with open(file_name, "rb") as f:
        while chunk := f.read(2**20):
            file_hash.update(chunk)

    return file_hash.hexdigest()
|
||||
|
||||
|
||||
# List of the supported frontends.
|
||||
# Frontend names accepted by AMDSharkImporter; "tf"/"tensorflow",
# "torch"/"pytorch" and "tflite"/"tf-lite" are alternate spellings that the
# importer handles in the same branches.
supported_frontends = {
    "tensorflow",
    "tf",
    "pytorch",
    "torch",
    "tf-lite",
    "tflite",
}
|
||||
|
||||
|
||||
class AMDSharkImporter:
    """
    AMDSharkImporter converts frontend modules into a
    mlir_module. The supported frameworks are tensorflow,
    pytorch, and tf-lite.

    ...

    Attributes
    ----------
    module :
        torch, tensorflow or tf-lite module.
    inputs :
        inputs to the module, may be required for the shape
        information.
    frontend: str
        frontend to which the module belongs.
    raw_model_file: str
        temp tflite model path
    return_str: bool
        forwarded to get_torch_mlir_module for the torch frontend.

    Methods
    -------
    import_mlir(is_dynamic, tracing_required, func_name):
        is_dynamic: input shapes to be totally dynamic (pytorch specific).
        tracing_required: whether tracing is required (pytorch specific).
        func_name: The function to be traced out or imported to mlir.

    import_debug(is_dynamic, tracing_required, func_name):
        returns the converted (mlir_module,func_name) with inputs and golden
        outputs.
        The inputs and outputs are converted into np array.
    """

    def __init__(
        self,
        module,
        inputs: tuple = (),
        frontend: str = "torch",
        raw_model_file: str = "",
        return_str: bool = False,
    ):
        self.module = module
        # An empty tuple means "no inputs supplied".
        self.inputs = None if len(inputs) == 0 else inputs
        self.frontend = frontend
        if self.frontend not in supported_frontends:
            print(
                f"The frontend is not in the supported_frontends: {supported_frontends}"
            )
            sys.exit(1)
        self.raw_model_file = raw_model_file
        self.return_str = return_str

    # NOTE: The default function for torch is "forward" and tf-lite is "main".

    def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
        """Import the torch module via ``get_torch_mlir_module``."""
        from amdshark.torch_mlir_utils import get_torch_mlir_module

        return get_torch_mlir_module(
            self.module,
            self.inputs,
            is_dynamic,
            tracing_required,
            self.return_str,
            mlir_type,
        )

    def _tf_mlir(self, func_name, save_dir="."):
        """Import the TensorFlow module, exporting only ``func_name``."""
        from iree.compiler import tf as tfc

        return tfc.compile_module(
            self.module,
            exported_names=[func_name],
            import_only=True,
            output_file=save_dir,
        )

    def _tflite_mlir(self, func_name, save_dir="."):
        """Import the .tflite file at ``raw_model_file``; ``func_name`` is unused."""
        from iree.compiler import tflite as tflitec

        self.mlir_model = tflitec.compile_file(
            self.raw_model_file,  # in tflite, it is a path to .tflite file, not a tflite interpreter
            input_type="tosa",
            import_only=True,
            output_file=save_dir,
        )
        return self.mlir_model

    # Adds the conversion of the frontend with the private function.
    def import_mlir(
        self,
        is_dynamic=False,
        tracing_required=False,
        func_name="forward",
        save_dir=cmd_opts.tmp_dir,  # "./amdshark_tmp/",
        mlir_type="linalg",
    ):
        """Convert the module to MLIR using the importer for its frontend.

        Returns:
            A ``(mlir_module, func_name)`` pair. For tf-lite the returned
            function name is always "main".
        """
        if self.frontend in ["torch", "pytorch"]:
            if self.inputs is None:
                print(
                    "Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
                )
                sys.exit(1)
            return (
                self._torch_mlir(is_dynamic, tracing_required, mlir_type),
                func_name,
            )
        if self.frontend in ["tf", "tensorflow"]:
            return self._tf_mlir(func_name, save_dir), func_name
        if self.frontend in ["tflite", "tf-lite"]:
            func_name = "main"
            return self._tflite_mlir(func_name, save_dir), func_name

    # Converts the frontend specific tensors into np array.
    def convert_to_numpy(self, array_tuple: tuple):
        """Return a list of numpy arrays for torch/tf tensors.

        Returns None for the tf-lite frontend, which has no branch here.
        """
        if self.frontend in ["torch", "pytorch"]:
            return [x.detach().cpu().numpy() for x in array_tuple]
        if self.frontend in ["tf", "tensorflow"]:
            return [x.numpy() for x in array_tuple]

    # Saves `function_name.npy`, `inputs.npz`, `golden_out.npz` and `model_name.mlir` in the directory `dir`.
    def save_data(
        self,
        dir,
        model_name,
        mlir_data,
        func_name,
        inputs,
        outputs,
        mlir_type="linalg",
    ):
        """Write the import artifacts and a hash of the MLIR file into ``dir``.

        Raises:
            FileNotFoundError: If the MLIR file is still missing after all
                hashing attempts.
        """
        import numpy as np

        inputs_name = "inputs.npz"
        outputs_name = "golden_out.npz"
        func_file_name = "function_name"
        model_name_mlir = (
            model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
        )
        mlir_path = os.path.join(dir, model_name_mlir)
        print(f"saving {model_name_mlir} to {dir}")
        # Accept torch tensors, tf tensors, or anything np.savez can take.
        try:
            inputs = [x.cpu().detach() for x in inputs]
        except AttributeError:
            try:
                inputs = [x.numpy() for x in inputs]
            except AttributeError:
                inputs = list(inputs)
        np.savez(os.path.join(dir, inputs_name), *inputs)
        np.savez(os.path.join(dir, outputs_name), *outputs)
        np.save(os.path.join(dir, func_file_name), np.array(func_name))
        # Both spellings of the torch frontend go through the same import
        # path, so write the MLIR for either one. For the other frontends
        # import_debug passes the artifact path as the importer's output file.
        if self.frontend in ["torch", "pytorch"]:
            with open(mlir_path, "wb") as mlir_file:
                mlir_file.write(mlir_data)
        hash_gen_attempts = 2
        for attempt in range(hash_gen_attempts):
            try:
                mlir_hash = create_hash(mlir_path)
                break
            except FileNotFoundError:
                # Re-raise on the last attempt instead of falling through
                # with `mlir_hash` unbound.
                if attempt == hash_gen_attempts - 1:
                    raise

        np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
        return

    def import_debug(
        self,
        is_dynamic=False,
        tracing_required=False,
        func_name="forward",
        dir=tempfile.gettempdir(),
        model_name="model",
        golden_values=None,
        mlir_type="linalg",
    ):
        """Import the module, compute golden outputs and save all artifacts.

        Args:
            golden_values: Precomputed outputs for the torch frontend; when
                None the module is run on ``self.inputs``.

        Returns:
            ``(imported_mlir, inputs, golden_out)`` where ``imported_mlir``
            is the pair returned by ``import_mlir``.
        """
        if self.inputs is None:
            print(
                f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
            )
            sys.exit(1)
        model_name_mlir = (
            model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
        )
        artifact_path = os.path.join(dir, model_name_mlir)
        imported_mlir = self.import_mlir(
            is_dynamic,
            tracing_required,
            func_name,
            save_dir=artifact_path,
            mlir_type=mlir_type,
        )
        # TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
        # TODO: Check for multiple outputs.
        if self.frontend in ["torch", "pytorch"]:
            import torch

            if golden_values is not None:
                golden_out = golden_values
            else:
                golden_out = self.module(*self.inputs)
            if torch.is_tensor(golden_out):
                # NOTE(review): tuple(array) iterates the array's first axis
                # rather than wrapping it in a 1-tuple; left unchanged
                # because saved artifacts depend on it - confirm intent.
                golden_out = tuple(
                    golden_out.detach().cpu().numpy(),
                )
            else:
                golden_out = self.convert_to_numpy(golden_out)
            # Save the artifacts in the directory dir.
            self.save_data(
                dir,
                model_name,
                imported_mlir[0],
                imported_mlir[1],
                self.inputs,
                golden_out,
                mlir_type,
            )
            return (
                imported_mlir,
                self.convert_to_numpy(self.inputs),
                golden_out,
            )
        if self.frontend in ["tf", "tensorflow"]:
            import tensorflow as tf

            golden_out = self.module.forward(*self.inputs)
            if tf.is_tensor(golden_out):
                # NOTE(review): same first-axis iteration as the torch branch.
                golden_out = tuple(
                    golden_out.numpy(),
                )
            elif isinstance(golden_out, tuple):
                # Was `golden_out is tuple`, which is never true for a value.
                golden_out = self.convert_to_numpy(golden_out)
            elif hasattr(golden_out, "logits"):
                # from transformers import TFSequenceClassifierOutput
                golden_out = golden_out.logits
            else:
                golden_out = golden_out.last_hidden_state
            # Save the artifacts in the directory dir.
            self.save_data(
                dir,
                model_name,
                imported_mlir[0],
                imported_mlir[1],
                self.inputs,
                golden_out,
            )
            return (
                imported_mlir,
                self.convert_to_numpy(self.inputs),
                golden_out,
            )
        if self.frontend in ["tflite", "tf-lite"]:
            # TODO(Chi): Validate it for tflite models.
            golden_out = self.module.invoke_tflite(self.inputs)
            self.save_data(
                dir,
                model_name,
                imported_mlir[0],
                imported_mlir[1],
                self.inputs,
                golden_out,
            )
            return (
                imported_mlir,
                self.inputs,
                golden_out,
            )
|
||||
|
||||
|
||||
def get_f16_inputs(inputs, is_f16, f16_input_mask):
    """Return ``inputs`` with selected entries converted via ``.half()``.

    Args:
        inputs: Sequence of objects exposing ``half()``.
        is_f16: When falsy, ``inputs`` is returned unchanged.
        f16_input_mask: Optional per-input flags; entry ``i`` decides whether
            ``inputs[i]`` is converted. ``None`` converts every input. Any
            indexable works, including numpy arrays, which the previous
            ``== None`` comparison could not handle.

    Returns:
        ``inputs`` itself when ``is_f16`` is falsy, otherwise a new tuple.

    Raises:
        IndexError: If the mask is shorter than ``inputs``.
    """
    if not is_f16:
        return inputs
    if f16_input_mask is None:
        return tuple(x.half() for x in inputs)

    return tuple(
        x.half() if f16_input_mask[i] else x for i, x in enumerate(inputs)
    )
|
||||
|
||||
|
||||
# Upcasts the block/list of ops.
|
||||
def add_upcast(fx_g):
    """Insert dtype casts around ``mul(_, rsqrt(add(mean(pow(...)))))`` chains.

    For every node whose target is ``aten.mul`` and whose second argument is
    that exact rsqrt <- add <- mean <- pow chain, the first input of the
    ``pow`` node is cast to float32 with ``aten._to_copy`` and the result of
    the ``mul`` node is cast to float16. The graph is modified in place and
    linted at the end; nothing is returned.

    Args:
        fx_g: Object exposing an FX ``graph`` (presumably a
            torch.fx.GraphModule - verify against callers).
    """
    import torch

    for node in fx_g.graph.nodes:
        if node.target in [torch.ops.aten.mul]:
            # This is a very strict check.
            if hasattr(node.args[1], "target"):
                if (
                    node.args[1].target in [torch.ops.aten.rsqrt]
                    and node.args[1].args[0].target in [torch.ops.aten.add]
                    and node.args[1].args[0].args[0].target
                    in [torch.ops.aten.mean]
                    and node.args[1].args[0].args[0].args[0].target
                    in [torch.ops.aten.pow]
                ):
                    print("found an upcasting block let's upcast it.")
                    pow_node = node.args[1].args[0].args[0].args[0]
                    mul_node = node
                    # Feed pow a float32 copy of its first input.
                    with fx_g.graph.inserting_before(pow_node):
                        lhs = pow_node.args[0]
                        upcast_lhs = fx_g.graph.call_function(
                            torch.ops.aten._to_copy,
                            args=(lhs,),
                            kwargs={"dtype": torch.float32},
                        )
                        pow_node.args = (upcast_lhs, pow_node.args[1])
                    with fx_g.graph.inserting_before(mul_node):
                        new_node = fx_g.graph.call_function(
                            torch.ops.aten._to_copy,
                            args=(mul_node,),
                            kwargs={"dtype": torch.float16},
                        )
                        # Move the cast so it sits after the mul it consumes.
                        mul_node.append(new_node)
                        # This also rewrites new_node's own use of mul_node,
                        # so its args/kwargs are restored right after.
                        mul_node.replace_all_uses_with(new_node)
                        new_node.args = (mul_node,)
                        new_node.kwargs = {"dtype": torch.float16}

    fx_g.graph.lint()
|
||||
|
||||
|
||||
def transform_fx(fx_g, quantized=False):
    """Rewrite ``call_function`` nodes of an FX graph in place for float16.

    Args:
        fx_g: Object exposing an FX ``graph`` (presumably a
            torch.fx.GraphModule - verify against callers).
        quantized: When true, only the ``aten.empty`` zero-fill rewrite is
            applied; every other rewrite is skipped for each node.

    The graph is linted at the end; nothing is returned.
    """
    import torch

    # Replacement kwargs for factory ops that requested float32.
    kwargs_dict = {
        "dtype": torch.float16,
        "device": torch.device(type="cpu"),
        "pin_memory": False,
    }
    # Replacement kwargs for aten._to_copy nodes that requested float32.
    kwargs_dict1 = {
        "dtype": torch.float16,
    }
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            # aten.empty should be filled with zeros.
            if node.target in [torch.ops.aten.empty]:
                with fx_g.graph.inserting_after(node):
                    new_node = fx_g.graph.call_function(
                        torch.ops.aten.zero_,
                        args=(node,),
                    )
                    node.append(new_node)
                    # Redirects every user of `node`, including new_node
                    # itself, hence the args reset on the next line.
                    node.replace_all_uses_with(new_node)
                    new_node.args = (node,)
            if quantized:
                continue

            if node.target in [
                torch.ops.aten.arange,
                torch.ops.aten.empty,
                torch.ops.aten.zeros,
                torch.ops.aten.zeros_like,
            ]:
                if node.kwargs.get("dtype") == torch.float32:
                    # Replaces the whole kwargs mapping, not just "dtype".
                    node.kwargs = kwargs_dict

            # Vicuna
            if node.target in [
                torch.ops.aten._to_copy,
            ]:
                if node.kwargs.get("dtype") == torch.float32:
                    node.kwargs = kwargs_dict1

            # Clamp fill values to the range representable in float16.
            # Assumes args[2] / args[1] below are plain Python numbers -
            # TODO confirm they are never graph nodes.
            if node.target in [
                torch.ops.aten.masked_fill,
            ]:
                if node.args[2] > torch.finfo(torch.half).max:
                    max_val = torch.finfo(torch.half).max
                    node.args = (node.args[0], node.args[1], max_val)
                elif node.args[2] < torch.finfo(torch.half).min:
                    min_val = torch.finfo(torch.half).min
                    node.args = (node.args[0], node.args[1], min_val)

            if node.target in [
                torch.ops.aten.full,
            ]:
                if node.args[1] > torch.finfo(torch.half).max:
                    max_val = torch.finfo(torch.half).max
                    node.args = (node.args[0], max_val)
                    node.kwargs = kwargs_dict
                elif node.args[1] < torch.finfo(torch.half).min:
                    min_val = torch.finfo(torch.half).min
                    node.args = (node.args[0], min_val)
                    node.kwargs = kwargs_dict

            # Inputs and outputs of aten.var.mean should be upcasted to fp32.
            if node.target in [torch.ops.aten.var_mean]:
                with fx_g.graph.inserting_before(node):
                    new_node = fx_g.graph.call_function(
                        torch.ops.prims.convert_element_type,
                        args=(node.args[0], torch.float32),
                        kwargs={},
                    )
                    node.args = (new_node, node.args[1])

            # Cast each element taken from a var_mean result back to float16.
            if node.name.startswith("getitem"):
                with fx_g.graph.inserting_before(node):
                    if node.args[0].target in [torch.ops.aten.var_mean]:
                        new_node = fx_g.graph.call_function(
                            torch.ops.aten._to_copy,
                            args=(node,),
                            kwargs={"dtype": torch.float16},
                        )
                        node.append(new_node)
                        node.replace_all_uses_with(new_node)
                        new_node.args = (node,)
                        new_node.kwargs = {"dtype": torch.float16}

    # Required for cuda debugging.
    # for node in fx_g.graph.nodes:
    #     if node.op == "call_function":
    #         if node.kwargs.get("device") == torch.device(type="cpu"):
    #             new_kwargs = node.kwargs.copy()
    #             new_kwargs["device"] = torch.device(type="cuda")
    #             node.kwargs = new_kwargs

    fx_g.graph.lint()
|
||||
|
||||
|
||||
def gptq_transforms(fx_g):
    """Rewrite ``call_function`` nodes of an FX graph in place for GPTQ models.

    Per node, in this order: "cuda:0" device kwargs become "cpu", bfloat16
    ``_to_copy`` targets become float16, inputs of ``native_layer_norm``,
    ``mm`` and ``_softmax`` are converted to float32, and the outputs of
    ``mm`` and ``_softmax`` are cast back to float16 when a consumer is
    visited. The graph is linted at the end; nothing is returned.

    Args:
        fx_g: Object exposing an FX ``graph`` (presumably a
            torch.fx.GraphModule - verify against callers).
    """
    import torch

    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            if node.target in [
                torch.ops.aten.arange,
                torch.ops.aten.empty,
                torch.ops.aten.ones,
                torch.ops.aten._to_copy,
            ]:
                if node.kwargs.get("device") == torch.device(device="cuda:0"):
                    # Copy before editing; only the "device" entry changes.
                    updated_kwargs = node.kwargs.copy()
                    updated_kwargs["device"] = torch.device(device="cpu")
                    node.kwargs = updated_kwargs

            if node.target in [
                torch.ops.aten._to_copy,
            ]:
                if node.kwargs.get("dtype") == torch.bfloat16:
                    updated_kwargs = node.kwargs.copy()
                    updated_kwargs["dtype"] = torch.float16
                    node.kwargs = updated_kwargs

            # Inputs of aten.native_layer_norm should be upcasted to fp32.
            if node.target in [torch.ops.aten.native_layer_norm]:
                with fx_g.graph.inserting_before(node):
                    new_node_arg0 = fx_g.graph.call_function(
                        torch.ops.prims.convert_element_type,
                        args=(node.args[0], torch.float32),
                        kwargs={},
                    )
                    # Only the first argument is converted; the other four
                    # are passed through unchanged.
                    node.args = (
                        new_node_arg0,
                        node.args[1],
                        node.args[2],
                        node.args[3],
                        node.args[4],
                    )

            # Inputs of aten.mm should be upcasted to fp32.
            if node.target in [torch.ops.aten.mm]:
                with fx_g.graph.inserting_before(node):
                    new_node_arg0 = fx_g.graph.call_function(
                        torch.ops.prims.convert_element_type,
                        args=(node.args[0], torch.float32),
                        kwargs={},
                    )
                    new_node_arg1 = fx_g.graph.call_function(
                        torch.ops.prims.convert_element_type,
                        args=(node.args[1], torch.float32),
                        kwargs={},
                    )
                    node.args = (new_node_arg0, new_node_arg1)

            # Outputs of aten.mm should be downcasted to fp16.
            # NOTE(review): node.args[0] is read for every call_function
            # node here, so a node with no positional args would raise
            # IndexError - confirm this cannot occur.
            if type(node.args[0]) == torch.fx.node.Node and node.args[
                0
            ].target in [torch.ops.aten.mm]:
                with fx_g.graph.inserting_before(node):
                    tmp = node.args[0]
                    new_node = fx_g.graph.call_function(
                        torch.ops.aten._to_copy,
                        args=(node.args[0],),
                        kwargs={"dtype": torch.float16},
                    )
                    node.args[0].append(new_node)
                    # Redirects every user of the mm node (this node and
                    # new_node included); new_node is re-pointed at the
                    # original mm node (`tmp`) afterwards.
                    node.args[0].replace_all_uses_with(new_node)
                    new_node.args = (tmp,)
                    new_node.kwargs = {"dtype": torch.float16}

            # Inputs of aten._softmax should be upcasted to fp32.
            if node.target in [torch.ops.aten._softmax]:
                with fx_g.graph.inserting_before(node):
                    new_node_arg0 = fx_g.graph.call_function(
                        torch.ops.prims.convert_element_type,
                        args=(node.args[0], torch.float32),
                        kwargs={},
                    )
                    node.args = (new_node_arg0, node.args[1], node.args[2])

            # Outputs of aten._softmax should be downcasted to fp16.
            # Applied only when the consumer being visited is aten.expand.
            if (
                type(node.args[0]) == torch.fx.node.Node
                and node.args[0].target in [torch.ops.aten._softmax]
                and node.target in [torch.ops.aten.expand]
            ):
                with fx_g.graph.inserting_before(node):
                    tmp = node.args[0]
                    new_node = fx_g.graph.call_function(
                        torch.ops.aten._to_copy,
                        args=(node.args[0],),
                        kwargs={"dtype": torch.float16},
                    )
                    node.args[0].append(new_node)
                    node.args[0].replace_all_uses_with(new_node)
                    new_node.args = (tmp,)
                    new_node.kwargs = {"dtype": torch.float16}

    fx_g.graph.lint()
|
||||
|
||||
|
||||
# Doesn't replace the None type.
|
||||
def change_fx_graph_return_to_tuple(fx_g):
    """Turn list-valued graph outputs into tuples, dropping None entries.

    For each output node whose single argument is a list, the None elements
    are filtered out. If exactly one element remains it becomes the output
    directly; otherwise the output is a tuple of the remaining elements.
    Outputs that are not lists are left as they are. The graph is linted and
    recompiled before ``fx_g`` is returned.
    """
    for node in fx_g.graph.nodes:
        if node.op != "output":
            continue
        # Output nodes always have one argument.
        returned = node.args[0]
        if not isinstance(returned, list):
            continue
        # Don't return NoneType elements.
        kept = [
            item for item in returned if not isinstance(item, type(None))
        ]
        # A single tensor/element is returned as-is, not wrapped in a tuple.
        node.args = kept if len(kept) == 1 else (tuple(kept),)

    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g
|
||||
|
||||
|
||||
def flatten_training_input(inputs):
    """Flatten one level of dict/tuple nesting in ``inputs`` into a tuple.

    Dict entries contribute ``value.detach()`` for each of their values,
    tuple entries contribute their elements unchanged, and anything else is
    kept as a single element.
    """
    flat = []
    for entry in inputs:
        if isinstance(entry, dict):
            flat.extend(value.detach() for value in entry.values())
        elif isinstance(entry, tuple):
            flat.extend(entry)
        else:
            flat.append(entry)
    return tuple(flat)
|
||||
|
||||
|
||||
# TODO: Remove is_f16 and fix all calls with using precision instead
|
||||
# Applies fx conversion to the model and imports the mlir.
|
||||
def import_with_fx(
|
||||
model,
|
||||
inputs,
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
debug=False,
|
||||
training=False,
|
||||
return_str=False,
|
||||
save_dir=tempfile.gettempdir(),
|
||||
model_name="model",
|
||||
mlir_type="linalg",
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
precision="fp32",
|
||||
is_gptq=False,
|
||||
):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
|
||||
golden_values = None
|
||||
if debug:
|
||||
try:
|
||||
golden_values = model(*inputs)
|
||||
except:
|
||||
golden_values = None
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
# TODO: Control the decompositions.
|
||||
decomps_list = [
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
torch.ops.aten.native_layer_norm,
|
||||
torch.ops.aten.masked_fill.Tensor,
|
||||
torch.ops.aten.masked_fill.Scalar,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
||||
torch.ops.aten.index_add,
|
||||
torch.ops.aten.index_add_,
|
||||
]
|
||||
if precision in ["int4", "int8"] and not is_gptq:
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
block_quant_layer_level_manager,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
brevitas_layer_export_mode,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
LinearWeightBlockQuantHandlerFwd,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
replace_call_fn_target,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
matmul_rhs_group_quant_placeholder,
|
||||
)
|
||||
from brevitas.backport.fx.experimental.proxy_tensor import (
|
||||
make_fx as brevitas_make_fx,
|
||||
)
|
||||
|
||||
export_context_manager = brevitas_layer_export_mode
|
||||
export_class = block_quant_layer_level_manager(
|
||||
export_handlers=[LinearWeightBlockQuantHandlerFwd]
|
||||
)
|
||||
with export_context_manager(model, export_class):
|
||||
fx_g = brevitas_make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(decomps_list),
|
||||
)(*inputs)
|
||||
|
||||
transform_fx(fx_g, quantized=True)
|
||||
replace_call_fn_target(
|
||||
fx_g,
|
||||
src=matmul_rhs_group_quant_placeholder,
|
||||
target=torch.ops.quant.matmul_rhs_group_quant,
|
||||
)
|
||||
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
else:
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(decomps_list),
|
||||
)(*inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
if is_f16:
|
||||
fx_g = fx_g.half()
|
||||
transform_fx(fx_g)
|
||||
# TODO: Have to make it more generic.
|
||||
add_upcast(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if is_gptq:
|
||||
gptq_transforms(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if mlir_type == "fx":
|
||||
return fx_g
|
||||
|
||||
if training:
|
||||
change_fx_graph_return_to_tuple(fx_g)
|
||||
inputs = flatten_training_input(inputs)
|
||||
|
||||
ts_graph = torch.jit.script(fx_g)
|
||||
if mlir_type == "torchscript":
|
||||
return ts_graph
|
||||
|
||||
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
|
||||
mlir_importer = AMDSharkImporter(
|
||||
ts_graph,
|
||||
inputs,
|
||||
frontend="torch",
|
||||
return_str=return_str,
|
||||
)
|
||||
|
||||
if debug: # and not is_f16:
|
||||
(mlir_module, func_name), _, _ = mlir_importer.import_debug(
|
||||
dir=save_dir,
|
||||
model_name=model_name,
|
||||
golden_values=golden_values,
|
||||
mlir_type=mlir_type,
|
||||
is_dynamic=is_dynamic,
|
||||
tracing_required=tracing_required,
|
||||
)
|
||||
return mlir_module, func_name
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
|
||||
return mlir_module, func_name
|
||||
|
||||
|
||||
# Saves a .mlir module python object to the directory 'dir' with 'model_name'
# and returns a path to the saved file.
def save_mlir(
    mlir_module,
    model_name,
    mlir_dialect="linalg",
    frontend="torch",
    dir="",
):
    """Write ``mlir_module`` to ``<dir>/<model_name>_<frontend>_<mlir_dialect>.mlir``.

    Args:
        mlir_module: Payload written verbatim to a file opened in binary
            mode, so it must be bytes-like.
        model_name: Prefix of the generated file name.
        mlir_dialect: Dialect tag embedded in the file name.
        frontend: Frontend tag embedded in the file name. The file is only
            written when this is ``"torch"``.
        dir: Target directory. An empty string selects ``cmd_opts.tmp_dir``
            (``cmd_opts`` must be available at module scope). The directory
            is created when it does not exist.

    Returns:
        The path of the (possibly unwritten) ``.mlir`` file.
    """
    model_name_mlir = (
        model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
    )
    if dir == "":
        # The previous code ended this assignment with a stray comma, which
        # bound a 1-tuple and made os.path.join() below raise TypeError.
        dir = cmd_opts.tmp_dir
    mlir_path = os.path.join(dir, model_name_mlir)
    print(f"saving {model_name_mlir} to {dir}")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(dir, exist_ok=True)
    if frontend == "torch":
        with open(mlir_path, "wb") as mlir_file:
            mlir_file.write(mlir_module)

    return mlir_path
|
||||
@@ -1,243 +0,0 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
export_iree_module_to_vmfb,
|
||||
load_flatbuffer,
|
||||
create_dispatch_dirs,
|
||||
compile_benchmark_dirs,
|
||||
)
|
||||
import os
|
||||
from amdshark.amdshark_runner import AMDSharkRunner
|
||||
from amdshark.parser import amdshark_args
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Lookup from the element-type spelling found in an MLIR tensor type
# (e.g. the "f32" in tensor<1x3xf32>) to the matching NumPy scalar type.
dtype_to_np_dtype = dict(
    f32=np.float32,
    f64=np.float64,
    i32=np.int32,
    i64=np.int64,
    i1=np.bool_,
)
|
||||
|
||||
|
||||
class AMDSharkInference:
    """
    Runs prediction or inference on mlir_module.

    ...

    Attributes
    ----------
    mlir_module : str
        mlir_module or path represented in string; modules from torch-mlir are serialized in bytecode format.
    device : str
        device to execute the mlir_module on.
        currently supports cpu, cuda, vulkan, and metal backends.
    mlir_dialect: str
        The dialect in which the given mlir_module is in.
        Refer to {https://mlir.llvm.org/docs/Dialects/}
    is_benchmark: bool
        Whether this AMDSharkInference module should be benchmark-enabled.
    mmap: bool
        Whether to load/run vmfb using mmap. It's `True` by default.
    rt_flags: list
        Runtime flags forwarded to the runner and to `load_flatbuffer`.

    Methods
    -------
    compile(extra_args=None):
        Builds the runner (benchmark runner when `is_benchmark` is set).
    __call__(function_name, inputs, send_to_host=True):
        Runs the function with `function_name` within the compiled module
        with the given inputs (a tuple of numpy arrays).
    forward(inputs, send_to_host=True):
        Shorthand for calling the "forward" function.
    generate_random_inputs(low=0, high=1, function_name="forward"):
        Builds uniformly distributed inputs matching the static signature
        of `function_name`.
    _input_info(function_name):
        Gives the static shapes and dtypes of the inputs required by
        `function_name`. This can be expensive as it does string matching
        over the whole module text.
    """

    def __init__(
        self,
        mlir_module,
        device: str = "none",
        mlir_dialect: str = "linalg",
        is_benchmark: bool = False,
        dispatch_benchmark: str = None,
        dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
        device_idx: int = None,
        mmap: bool = True,
        rt_flags: list = None,
    ):
        self.mlir_module = mlir_module
        # Defined up front so save_module() does not hit AttributeError when
        # the object is created with mlir_module=None (e.g. before
        # load_module()).
        self.compile_str = False
        if mlir_module is not None:
            if mlir_module and not os.path.isfile(mlir_module):
                print(
                    "Warning: Initializing AMDSharkInference with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize AMDSharkInference with a path to a MLIR module on your hard disk instead."
                )
                self.compile_str = True
            else:
                self.compile_str = False
        # "none" means: fall back to the command-line configured device.
        self.device = amdshark_args.device if device == "none" else device
        self.mlir_dialect = mlir_dialect
        self.is_benchmark = is_benchmark
        self.device_idx = device_idx
        self.dispatch_benchmarks = (
            amdshark_args.dispatch_benchmarks
            if dispatch_benchmark is None
            else dispatch_benchmark
        )
        # The literal default acts as a sentinel for "use the CLI value".
        self.dispatch_benchmarks_dir = (
            amdshark_args.dispatch_benchmarks_dir
            if dispatch_benchmark_dir == "temp_dispatch_benchmarks"
            else dispatch_benchmark_dir
        )

        self.amdshark_runner = None
        self.mmap = mmap
        # A fresh list per instance; a `[]` default would be shared by every
        # instance created without explicit flags.
        self.rt_flags = [] if rt_flags is None else rt_flags

    def compile(self, extra_args=None):
        """Create the runner for ``self.mlir_module``.

        Args:
            extra_args: Optional extra compiler flags. The sequence is
                copied, so the dispatch-benchmark flags appended here never
                leak into the caller's list or into later calls.
        """
        extra_args = [] if extra_args is None else list(extra_args)
        if self.dispatch_benchmarks is not None:
            extra_args.append(
                f"--iree-hal-dump-executable-sources-to={self.dispatch_benchmarks_dir}"
            )
            extra_args.append(
                f"--iree-hal-dump-executable-binaries-to={self.dispatch_benchmarks_dir}"
            )
            # Sibling directory whose last component is prefixed with "temp_".
            temp_dir = self.dispatch_benchmarks_dir.split("/")
            temp_dir[-1] = "temp_" + temp_dir[-1]
            temp_dir = "/".join(temp_dir)
            self.temp_dispatch_benchmarks_dir = temp_dir
            extra_args.append(
                f"--iree-hal-dump-executable-benchmarks-to={self.temp_dispatch_benchmarks_dir}"
            )

        if self.is_benchmark:
            from amdshark.amdshark_benchmark_runner import AMDSharkBenchmarkRunner

            self.amdshark_runner = AMDSharkBenchmarkRunner(
                self.mlir_module,
                self.device,
                self.mlir_dialect,
                extra_args=extra_args,
            )

        else:
            self.amdshark_runner = AMDSharkRunner(
                self.mlir_module,
                self.device,
                self.mlir_dialect,
                extra_args=extra_args,
                device_idx=self.device_idx,
                rt_flags=self.rt_flags,
            )

        if self.dispatch_benchmarks is not None:
            create_dispatch_dirs(self.dispatch_benchmarks_dir, self.device)
            compile_benchmark_dirs(
                self.dispatch_benchmarks_dir,
                self.device,
                self.dispatch_benchmarks,
            )
            # Portable replacement for `rm -rf <dir>`: no shell involved, so
            # paths containing spaces or shell metacharacters are safe.
            import shutil

            shutil.rmtree(self.temp_dispatch_benchmarks_dir, ignore_errors=True)

    # inputs are considered to be tuple of np.array.
    def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
        """Run ``function_name`` on the runner created by compile()/load_module()."""
        return self.amdshark_runner.run(
            function_name, inputs, send_to_host, device=self.device
        )

    # forward function.
    def forward(self, inputs: tuple, send_to_host=True):
        """Run the function named "forward"."""
        return self.amdshark_runner.run(
            "forward", inputs, send_to_host, device=self.device
        )

    # Get all function names defined within the compiled module.
    def get_functions_in_module(self):
        return self.amdshark_runner.get_functions_in_module()

    # Captures the static input information from the mlir_module.
    # TODO(pashu123): Generate the input information for dynamic shapes.
    def _input_info(self, function_name):
        """Return ``(shapes, dtypes)`` parsed from the header of ``function_name``.

        The textual form of ``self.mlir_module`` is scanned for the first line
        containing ``func.func @<function_name>``; each comma-separated
        argument inside the first parenthesised group must carry a
        ``<d0xd1x...xdtype>`` type with static integer dimensions.

        Raises:
            ValueError: if no matching function header is found.
        """
        import re

        # func_key to get the line which contains the function.
        func_key = "func.func @" + function_name
        func_header = None
        for line in str(self.mlir_module).splitlines():
            if func_key in line:
                func_header = line
                break
        if func_header is None:
            # Previously this only printed and then crashed with a TypeError
            # inside re.findall(); fail with an explicit message instead.
            raise ValueError(f"Function: {function_name} not found")

        inputs = re.findall(r"\(.*?\)", func_header)[0].split(",")
        shapes = []
        dtype = []
        for inp in inputs:
            # Drop the surrounding "<" and ">" before splitting so that a
            # rank-0 type such as tensor<f32> yields shape () and dtype "f32".
            shape_dtype = re.findall(r"<[^>]*>", inp)[0][1:-1].split("x")
            shapes.append(tuple(int(x) for x in shape_dtype[:-1]))
            dtype.append(shape_dtype[-1])

        return shapes, dtype

    # Generates random input to be feed into the graph.
    def generate_random_inputs(self, low=0, high=1, function_name="forward"):
        """Return a tuple of uniform random arrays matching ``function_name``'s inputs.

        ``function_name`` is new and defaults to "forward"; the previous
        implementation called ``_input_info()`` without its required argument
        and therefore always raised TypeError.
        """
        shapes, dtype = self._input_info(function_name)
        inputs = []
        for i, j in zip(shapes, dtype):
            inputs.append(
                np.random.uniform(low, high, size=i).astype(
                    dtype_to_np_dtype[j]
                )
            )
        return tuple(inputs)

    # TODO: Instead of passing directory and having names decided by the module
    # , user may want to save the module with manual names.
    def save_module(
        self, dir=None, module_name=None, extra_args=None, debug=False
    ):
        """Export the module to a .vmfb via ``export_iree_module_to_vmfb``.

        Args:
            dir: Output directory; ``None`` means the working directory at
                call time (the old default froze it at import time).
            module_name: Forwarded unchanged.
            extra_args: Optional extra compiler flags.
            debug: Forwarded unchanged.
        """
        if dir is None:
            dir = os.getcwd()
        return export_iree_module_to_vmfb(
            self.mlir_module,
            self.device,
            dir,
            self.mlir_dialect,
            module_name=module_name,
            extra_args=[] if extra_args is None else extra_args,
            debug=debug,
            compile_str=self.compile_str,
        )

    # load and return the module.
    def load_module(self, path, extra_args=None):
        """Attach a precompiled flatbuffer at ``path`` to a non-compiling runner."""
        self.amdshark_runner = AMDSharkRunner(
            device=self.device,
            compile_vmfb=False,
            extra_args=[] if extra_args is None else extra_args,
            rt_flags=self.rt_flags,
        )
        params = load_flatbuffer(
            path,
            self.device,
            self.device_idx,
            mmap=self.mmap,
            rt_flags=self.rt_flags,
        )
        self.amdshark_runner.iree_compilation_module = params["vmfb"]
        self.amdshark_runner.iree_config = params["config"]
        self.amdshark_runner.temp_file_to_unlink = params["temp_file_to_unlink"]
        del params
        return
|
||||
@@ -1,127 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
get_iree_compiled_module,
|
||||
get_results,
|
||||
export_iree_module_to_vmfb,
|
||||
load_flatbuffer,
|
||||
)
|
||||
from amdshark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from amdshark.parser import amdshark_args
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
# Dialect names accepted by the amdshark-runtime.
supported_dialects = set(
    (
        "linalg",
        "auto",
        "stablehlo",
        "tosa",
        "tf-lite",
        "tm_tensor",
    )
)
|
||||
|
||||
|
||||
class AMDSharkRunner:
    """
    Compiles (optionally) and executes an mlir_module.

    ...

    Attributes
    ----------
    mlir_module : str
        mlir_module path, string, or bytecode.
    device : str
        device to execute the mlir_module on.
        currently supports cpu, cuda, vulkan, and metal backends.
    mlir_dialect: str
        The dialect in which the given mlir_module is in.
        Refer to {https://mlir.llvm.org/docs/Dialects/}
    compile_str : bool
        True when `mlir_module` was given as an in-memory string/bytecode
        object rather than as a path to an existing file.

    Methods
    -------
    run(function_name, inputs, send_to_host=False, device=None):
        Runs the function with `function_name` within the compiled module
        with the given inputs.
    get_functions_in_module():
        Lists the function names exposed by the compiled module.
    """

    def __init__(
        self,
        mlir_module: bytes = None,
        device: str = "none",
        mlir_dialect: str = "linalg",
        extra_args: list = None,
        compile_vmfb: bool = True,
        device_idx: int = None,
        rt_flags: list = None,
    ):
        self.mlir_module = mlir_module
        # Always defined, so the compile step below cannot hit AttributeError
        # when no module was supplied.
        self.compile_str = False
        if self.mlir_module is not None:
            if not os.path.isfile(mlir_module):
                print(
                    "Warning: Initializing AMDSharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize AMDSharkInference with a path to a MLIR module on your hard disk instead."
                )
                self.compile_str = True
            else:
                self.compile_str = False
        # "none" means: fall back to the command-line configured device.
        self.device = amdshark_args.device if device == "none" else device
        self.mlir_dialect = mlir_dialect
        # Fresh lists per instance; `[]` defaults would be shared across
        # every runner created without explicit arguments.
        self.extra_args = [] if extra_args is None else extra_args
        self.device_idx = device_idx
        self.rt_flags = [] if rt_flags is None else rt_flags

        # A truthy result means the drivers for the device are not usable:
        # report why and abort the process.
        if check_device_drivers(self.device):
            print(device_driver_info(self.device))
            sys.exit(1)

        if compile_vmfb:
            # Compile the module to get the .vmfb.
            params = get_iree_compiled_module(
                self.mlir_module,
                self.device,
                self.mlir_dialect,
                extra_args=self.extra_args,
                device_idx=self.device_idx,
                rt_flags=self.rt_flags,
                compile_str=self.compile_str,
            )
            self.iree_compilation_module = params["vmfb"]
            self.iree_config = params["config"]
            self.temp_file_to_unlink = params["temp_file_to_unlink"]
            del params

    def run(
        self, function_name, inputs: tuple, send_to_host=False, device=None
    ):
        """Execute ``function_name`` of the compiled module via ``get_results``.

        Requires ``iree_compilation_module`` and ``iree_config`` to be set,
        either by compiling in ``__init__`` or by the caller afterwards.
        """
        return get_results(
            self.iree_compilation_module,
            function_name,
            inputs,
            self.iree_config,
            self.mlir_dialect,
            send_to_host,
            device=device,
        )

    # Get all function names defined within the compiled module.
    def get_functions_in_module(self):
        return self.iree_compilation_module._vm_module.function_names
|
||||
@@ -1,154 +0,0 @@
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from torch._decomp import get_decompositions
|
||||
from torch.func import functionalize
|
||||
import io
|
||||
import torch_mlir
|
||||
|
||||
|
||||
# TODO: Control decompositions.
def default_decompositions():
    """Return the decomposition table for the fixed set of aten ops below."""
    aten = torch.ops.aten
    ops_to_decompose = [
        aten.embedding_dense_backward,
        aten.native_layer_norm_backward,
        aten.slice_backward,
        aten.select_backward,
        aten.norm.ScalarOpt_dim,
        aten.native_group_norm,
        aten.upsample_bilinear2d.vec,
        aten.split.Tensor,
        aten.split_with_sizes,
        aten.native_layer_norm,
        aten.masked_fill.Tensor,
        aten.masked_fill.Scalar,
    ]
    return get_decompositions(ops_to_decompose)
|
||||
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
class AMDSharkBackend:
    """Callable that compiles an FX graph through torch-mlir and AMDSharkInference.

    Construction rewrites the graph's return value (``None`` entries removed,
    one-element tuples unwrapped) and compiles it immediately; calling the
    object runs the compiled "forward" function and restores the original
    return structure.
    """

    def __init__(
        self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
    ):
        self.fx_g = fx_g
        self.inputs = inputs
        self.amdshark_module = None
        self.device: str = options.get("device", "cpu")
        self.was_unwrapped: bool = False
        self.none_indices: list = []
        self._modify_fx_g()
        self.compile()

    def _modify_fx_g(self):
        """Normalize the graph output, remembering how to undo it in __call__."""
        self.none_indices = _remove_nones(self.fx_g)
        self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)

    def compile(self):
        """Lower the graph to MLIR bytecode and compile it for ``self.device``."""
        from amdshark.amdshark_inference import AMDSharkInference

        traced = make_fx(
            functionalize(self.fx_g),
            decomposition_table=default_decompositions(),
        )(*self.inputs)
        traced.graph.set_codegen(torch.fx.graph.CodeGen())
        traced.recompile()
        strip_overloads(traced)
        scripted = torch.jit.script(traced)
        lowered = torch_mlir.compile(
            scripted, self.inputs, output_type="linalg-on-tensors"
        )

        # Serialize the module to bytecode in memory.
        buffer = io.BytesIO()
        lowered.operation.write_bytecode(buffer)

        runner = AMDSharkInference(
            mlir_module=buffer.getvalue(),
            device=self.device,
            mlir_dialect="tm_tensor",
        )
        runner.compile(extra_args=[])
        self.amdshark_module = runner

    def __call__(self, *inputs):
        """Run "forward" on numpy copies of ``inputs`` and return torch results."""
        host_inputs = [
            tensor.contiguous().detach().cpu().numpy() for tensor in inputs
        ]
        raw = self.amdshark_module("forward", host_inputs)
        if self.was_unwrapped:
            # The graph's one-element tuple was unwrapped before compilation;
            # wrap the result again.
            raw = [raw]

        if not isinstance(raw, list):
            return torch.from_numpy(raw)

        outputs = [torch.from_numpy(array) for array in raw]
        # Re-insert the None entries at their recorded positions.
        for position in self.none_indices:
            outputs.insert(position, None)
        return tuple(outputs)
|
||||
@@ -1,25 +0,0 @@
|
||||
import torch
|
||||
import amdshark
|
||||
|
||||
|
||||
def foo(x, a):
    """Return ``x + a`` when ``x`` has more than 3 rows, otherwise ``x + 3``."""
    return x + a if x.shape[0] > 3 else x + 3
|
||||
|
||||
|
||||
amdshark_options = {"device": "cpu"}
compiled = torch.compile(foo, backend="amdshark", options=amdshark_options)

# A length of 4 takes the `x + a` branch of foo, a length of 3 the `x + 3`
# branch; each result is printed.
for length in (4, 3):
    input = torch.ones(length)

    x = compiled(input, input)

    print(x)
|
||||
@@ -1,73 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# Reference model: pretrained SqueezeNet 1.0 from torch hub, in eval mode.
model = torch.hub.load(
    "pytorch/vision:v0.10.0",
    "squeezenet1_0",
    pretrained=True,
)
model.eval()

# from PIL import Image
# from torchvision import transforms
# import urllib
#
# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
# try: urllib.URLopener().retrieve(url, filename)
# except: urllib.request.urlretrieve(url, filename)
#
#
# input_image = Image.open(filename)
# preprocess = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])
# input_tensor = preprocess(input_image)
# input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
# print(input_batch.shape) # size = [1, 3, 224, 224]

# The above is code for generating sample inputs from an image. We can just use
# random values for accuracy testing though
input_batch = torch.randn(1, 3, 224, 224)


# Focus on CPU for now
if False and torch.cuda.is_available():
    input_batch = input_batch.to("cuda")
    model.to("cuda")

# Golden run through plain PyTorch.
with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
golden_confidences = output[0]
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
golden_probabilities = torch.nn.functional.softmax(
    golden_confidences, dim=0
).numpy()

golden_confidences = golden_confidences.numpy()

from amdshark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor

# Run the same model on a lockstep-tensor wrapper around a copy of the input.
input_detached_clone = input_batch.clone()
eager_input_batch = TorchMLIRLockstepTensor(input_detached_clone)

print("getting torch-mlir result")

output = model(eager_input_batch)

static_output = output.elem
confidences = static_output[0]
probabilities = torch.nn.functional.softmax(
    torch.from_numpy(confidences), dim=0
).numpy()

print("The obtained result via amdshark is: ", confidences)
print("The golden result is:", golden_confidences)

# Raw scores first, then the softmax probabilities, with identical tolerances.
for expected, actual in (
    (golden_confidences, confidences),
    (golden_probabilities, probabilities),
):
    np.testing.assert_allclose(expected, actual, rtol=1e-02, atol=1e-03)
|
||||
@@ -1,65 +0,0 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import CLIPProcessor, TFCLIPModel
|
||||
import tensorflow as tf
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
# Input signature of CLIPModule.forward, in argument order:
# input_ids, attention_mask, pixel_values.
clip_vit_inputs = [
    tf.TensorSpec(shape=spec_shape, dtype=spec_dtype)
    for spec_shape, spec_dtype in (
        ([2, 7], tf.int32),
        ([2, 7], tf.int32),
        ([1, 3, 224, 224], tf.float32),
    )
]
|
||||
|
||||
|
||||
class CLIPModule(tf.Module):
    """tf.Module wrapper around the pretrained TF CLIP ViT-B/32 model."""

    def __init__(self):
        super().__init__()
        self.m = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

        # Positional adapter stored on the model itself: (x, y, z) are passed
        # on as input_ids, attention_mask and pixel_values.
        def _predict(x, y, z):
            return self.m(input_ids=x, attention_mask=y, pixel_values=z)

        self.m.predict = _predict

    @tf.function(input_signature=clip_vit_inputs, jit_compile=True)
    def forward(self, input_ids, attention_mask, pixel_values):
        """Return ``logits_per_image`` of the wrapped model's output."""
        prediction = self.m.predict(input_ids, attention_mask, pixel_values)
        return prediction.logits_per_image
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Prepping Data
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    inputs = processor(
        text=["a photo of a cat", "a photo of a dog"],
        images=image,
        return_tensors="tf",
        padding=True,
    )

    # The same three tensors are used both to build and to run the module.
    model_inputs = (
        inputs["input_ids"],
        inputs["attention_mask"],
        inputs["pixel_values"],
    )

    # NOTE(review): passing the inputs positionally and calling
    # set_frontend() presumably targets an AMDSharkInference API other than
    # the one visible elsewhere in this repository - confirm before relying
    # on this example.
    amdshark_module = AMDSharkInference(CLIPModule(), model_inputs)
    amdshark_module.set_frontend("tensorflow")
    amdshark_module.compile()

    print(amdshark_module.forward(model_inputs))
|
||||
@@ -1,15 +0,0 @@
|
||||
## Running ESRGAN
|
||||
|
||||
```
|
||||
1. pip install numpy opencv-python
|
||||
2. mkdir InputImages
|
||||
(this is where all the input images will reside)
|
||||
3. mkdir OutputImages
|
||||
(this is where the model will generate all the images)
|
||||
4. mkdir models
|
||||
(save the .pth checkpoint file here)
|
||||
5. python esrgan.py
|
||||
```
|
||||
|
||||
- Download [RRDB_ESRGAN_x4.pth](https://drive.google.com/drive/u/0/folders/17VYV_SoZZesU6mbxz2dMAIccSSlqLecY) and place it in the `models` directory as mentioned above in step 4.
|
||||
- Credits : [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
@@ -1,239 +0,0 @@
|
||||
from ast import arg
|
||||
import os.path as osp
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
    """Stack ``n_layers`` freshly constructed ``block()`` modules in an nn.Sequential."""
    return nn.Sequential(*[block() for _ in range(n_layers)])
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution dense block with a 0.2-scaled residual connection.

    Each of the first four 3x3 convolutions sees the block input concatenated
    (along dim 1) with all previous activations; the fifth maps back to ``nf``
    channels.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        # gc: growth channel, i.e. intermediate channels
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv1 = nn.Conv2d(nf, gc, **conv_kwargs)
        self.conv2 = nn.Conv2d(nf + gc, gc, **conv_kwargs)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, **conv_kwargs)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, **conv_kwargs)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, **conv_kwargs)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        # features grows by one activation per stage and is what the next
        # convolution consumes after concatenation along the channel dim.
        features = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        x5 = self.conv5(torch.cat(features, 1))
        return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
    """Residual in Residual Dense Block"""

    def __init__(self, nf, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        # Chain the three dense blocks, then add the 0.2-scaled result back
        # onto the input.
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
    """RRDB trunk followed by two nearest-neighbour 2x upsampling stages."""

    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super().__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True)

        self.conv_first = nn.Conv2d(in_nc, nf, **conv_kwargs)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, **conv_kwargs)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, **conv_kwargs)
        self.upconv2 = nn.Conv2d(nf, nf, **conv_kwargs)
        self.HRconv = nn.Conv2d(nf, nf, **conv_kwargs)
        self.conv_last = nn.Conv2d(nf, out_nc, **conv_kwargs)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        # Two identical stages: nearest 2x upsample -> conv -> leaky ReLU.
        for upconv in (self.upconv1, self.upconv2):
            upsampled = F.interpolate(fea, scale_factor=2, mode="nearest")
            fea = self.lrelu(upconv(upsampled))

        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out
|
||||
|
||||
|
||||
# ---------------------------- Parsing args ----------------------------
import argparse

p = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description=__doc__,
)

p.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="the device to use",
)
p.add_argument(
    "--mlir_loc",
    type=str,
    default=None,
    help="location of the model's mlir file",
)
# Parsed at import time; `args` is read by the functions below.
args = p.parse_args()
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def inference(input_m):
    """Apply the module-level ``model`` to ``input_m`` and return its output."""
    result = model(input_m)
    return result
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
    """Read a textual MLIR module from disk.

    Args:
        mlir_loc: Path of the ``.mlir`` file, or ``None`` when no
            pre-generated module is available.

    Returns:
        The file contents as a string, or ``None`` when ``mlir_loc`` is
        ``None``.
    """
    # Identity test: `== None` would invoke a custom __eq__ on path-like
    # objects.
    if mlir_loc is None:
        return None
    print(f"Trying to load the model from {mlir_loc}.")
    # The single-argument os.path.join() used previously was a no-op.
    with open(mlir_loc) as f:
        return f.read()
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None):
    """Compile ``model`` into an AMDSHARK module, importing via FX if needed.

    When ``mlir_loc`` names an MLIR file, its contents are used directly.
    Otherwise ``model`` is traced with ``make_fx``, scripted with
    TorchScript and lowered by ``torch_mlir.compile`` to linalg-on-tensors.

    Args:
        model: callable to trace.
        inputs: example input handed to the tracer and to ``torch_mlir``.
        mlir_loc: optional path of a pre-generated MLIR file.

    Returns:
        A compiled ``AMDSharkInference`` for the device in ``args.device``.
    """
    module = load_mlir(mlir_loc)
    if module is None:
        fx_g = make_fx(
            model,
            decomposition_table=get_decompositions(
                [
                    torch.ops.aten.embedding_dense_backward,
                    torch.ops.aten.native_layer_norm_backward,
                    torch.ops.aten.slice_backward,
                    torch.ops.aten.select_backward,
                    torch.ops.aten.norm.ScalarOpt_dim,
                    torch.ops.aten.native_group_norm,
                    torch.ops.aten.upsample_bilinear2d.vec,
                    torch.ops.aten.split.Tensor,
                    torch.ops.aten.split_with_sizes,
                ]
            ),
        )(inputs)

        fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
        fx_g.recompile()

        def strip_overloads(gm):
            """
            Modifies the target of graph nodes in :attr:`gm` to strip overloads.
            Args:
                gm(fx.GraphModule): The input Fx graph module to be modified
            """
            for node in gm.graph.nodes:
                if isinstance(node.target, torch._ops.OpOverload):
                    node.target = node.target.overloadpacket
            gm.recompile()

        strip_overloads(fx_g)

        ts_g = torch.jit.script(fx_g)

        print("Torchscript graph generated successfully")
        module = torch_mlir.compile(
            ts_g,
            inputs,
            torch_mlir.OutputType.LINALG_ON_TENSORS,
            use_tracing=False,
            verbose=False,
        )

    # Both the loaded text and the freshly compiled module go in as a string.
    mlir_model = str(module)
    amdshark_module = AMDSharkInference(
        mlir_model, device=args.device, mlir_dialect="linalg"
    )
    amdshark_module.compile()

    return amdshark_module
|
||||
|
||||
|
||||
model_path = "models/RRDB_ESRGAN_x4.pth"  # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
# device = torch.device('cuda')  # if you want to run on CPU, change 'cuda' -> cpu
device = torch.device("cpu")

test_img_folder = "InputImages/*"

model = RRDBNet(3, 3, 64, 23, gc=32)
# map_location places the checkpoint's tensors on `device` while loading, so
# a checkpoint saved from a GPU still loads on a CPU-only machine.
model.load_state_dict(
    torch.load(model_path, map_location=device), strict=True
)
model.eval()
model = model.to(device)

print("Model path {:s}. \nTesting...".format(model_path))
|
||||
|
||||
if __name__ == "__main__":
    # Upscale every file matched by `test_img_folder` twice -- once through
    # the AMDSHARK-compiled module and once through the eager PyTorch model --
    # and write both results to OutputImages/.
    for idx, path in enumerate(glob.glob(test_img_folder), start=1):
        base = osp.splitext(osp.basename(path))[0]
        print(idx, base)
        # read images
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f"Skipping {path}: could not be read as an image.")
            continue
        img = img * 1.0 / 255
        # Reverse the channel axis and move it first (HWC -> CHW).
        img = torch.from_numpy(
            np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))
        ).float()
        img_LR = img.unsqueeze(0)
        img_LR = img_LR.to(device)

        with torch.no_grad():
            # Honour --mlir_loc so a pre-generated MLIR file is actually used.
            amdshark_module = compile_through_fx(
                inference, img_LR, mlir_loc=args.mlir_loc
            )
            amdshark_output = amdshark_module.forward((img_LR,))
            amdshark_output = torch.from_numpy(amdshark_output)
            amdshark_output = (
                amdshark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
            )
            esrgan_output = (
                model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
            )

        # AMDSHARK OUTPUT
        amdshark_output = np.transpose(
            amdshark_output[[2, 1, 0], :, :], (1, 2, 0)
        )
        amdshark_output = (amdshark_output * 255.0).round()
        cv2.imwrite(
            "OutputImages/{:s}_rlt_amdshark_output.png".format(base),
            amdshark_output,
        )
        print("Generated AMDSHARK's output")

        # ESRGAN OUTPUT
        esrgan_output = np.transpose(esrgan_output[[2, 1, 0], :, :], (1, 2, 0))
        esrgan_output = (esrgan_output * 255.0).round()
        cv2.imwrite(
            "OutputImages/{:s}_rlt_esrgan_output.png".format(base),
            esrgan_output,
        )
        print("Generated ESRGAN's output")
|
||||
@@ -1,86 +0,0 @@
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import torch
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
from iree.compiler import compile_str
|
||||
from iree import runtime as ireert
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
|
||||
class AlbertModule(torch.nn.Module):
    """Wraps the pretrained ``albert-base-v2`` masked-LM; returns its logits."""

    def __init__(self):
        super().__init__()
        pretrained = AutoModelForMaskedLM.from_pretrained("albert-base-v2")
        pretrained.eval()
        self.model = pretrained

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return only the ``logits`` field."""
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        return outputs.logits
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Prepping Data
    tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")

    def encode(sentence):
        """Tokenize ``sentence``, padded/truncated to MAX_SEQUENCE_LENGTH."""
        return tokenizer(
            sentence,
            padding="max_length",
            truncation=True,
            max_length=MAX_SEQUENCE_LENGTH,
            return_tensors="pt",
        )

    def top_5_fillers(encoded):
        """Return the five best token ids for the first mask token.

        Returns an empty list when the encoded sentence has no mask token.
        """
        model_inputs = (encoded["input_ids"], encoded["attention_mask"])
        token_logits = torch.tensor(amdshark_module.forward(model_inputs))
        mask_id = torch.where(
            encoded["input_ids"] == tokenizer.mask_token_id
        )[1]
        if mask_id.numel() == 0:
            return []
        mask_token_logits = token_logits[0, mask_id, :]
        return torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

    text = "This [MASK] is very tasty."
    encoded_inputs = encode(text)
    inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
    mlir_importer = AMDSharkImporter(
        AlbertModule(),
        inputs,
        frontend="torch",
    )
    minilm_mlir, func_name = mlir_importer.import_mlir(
        is_dynamic=False, tracing_required=True
    )
    amdshark_module = AMDSharkInference(minilm_mlir)
    amdshark_module.compile()

    for token in top_5_fillers(encoded_inputs):
        print(
            f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
        )

    while True:
        try:
            new_text = input("Give me a sentence with [MASK] to fill: ")
            top_5_tokens = top_5_fillers(encode(new_text))
            if not top_5_tokens:
                print("No [MASK] token found in the sentence; try again.")
                continue
            for token in top_5_tokens:
                print(
                    f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
                )
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C or end of input both end the session cleanly.
            print("Exiting program.")
            break
|
||||
@@ -1,100 +0,0 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
from iree.compiler import tf as tfc
|
||||
from iree.compiler import compile_str
|
||||
from iree import runtime as ireert
|
||||
import os
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
# Create a set of inputs
|
||||
t5_inputs = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class AlbertModule(tf.Module):
    """``tf.Module`` wrapper exposing ``albert-base-v2`` via ``forward``."""

    def __init__(self):
        super().__init__()
        self.m = TFAutoModelForMaskedLM.from_pretrained("albert-base-v2")

        def _predict(x, y):
            # Replacement for the model's `predict`: call the model directly.
            return self.m(input_ids=x, attention_mask=y)

        self.m.predict = _predict

    @tf.function(input_signature=t5_inputs, jit_compile=True)
    def forward(self, input_ids, attention_mask):
        """Call the wrapped model; the signature is fixed by ``t5_inputs``."""
        return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Prepping Data
    tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
    tokenizer_options = dict(
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        return_tensors="tf",
    )

    text = "This [MASK] is very tasty."
    encoded_inputs = tokenizer(text, **tokenizer_options)
    inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
    mlir_importer = AMDSharkImporter(
        AlbertModule(),
        inputs,
        frontend="tf",
    )
    minilm_mlir, func_name = mlir_importer.import_mlir(
        is_dynamic=False, tracing_required=False
    )
    amdshark_module = AMDSharkInference(minilm_mlir, mlir_dialect="mhlo")
    amdshark_module.compile()
    output_idx = 0
    data_idx = 1

    def rank_mask_candidates(encoded, model_inputs):
        """Return the first five entries of the descending logit ranking."""
        token_logits = amdshark_module.forward(model_inputs)[output_idx][
            data_idx
        ]
        mask_id = np.where(
            tf.squeeze(encoded["input_ids"]) == tokenizer.mask_token_id
        )
        mask_token_logits = token_logits[0, mask_id, :]
        return np.flip(np.argsort(mask_token_logits)).squeeze()[0:5]

    top_5_tokens = rank_mask_candidates(encoded_inputs, inputs)
    for token in top_5_tokens:
        print(
            f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
        )

    while True:
        try:
            new_text = input("Give me a sentence with [MASK] to fill: ")
            encoded_inputs = tokenizer(new_text, **tokenizer_options)
            inputs = (
                encoded_inputs["input_ids"],
                encoded_inputs["attention_mask"],
            )
            top_5_tokens = rank_mask_candidates(encoded_inputs, inputs)
            for token in top_5_tokens:
                print(
                    f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
                )
        except KeyboardInterrupt:
            print("Exiting program.")
            sys.exit()
|
||||
@@ -1,14 +0,0 @@
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"bloom", frontend="torch"
|
||||
)
|
||||
|
||||
amdshark_module = AMDSharkInference(
|
||||
mlir_model, device="cpu", mlir_dialect="tm_tensor"
|
||||
)
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward(inputs)
|
||||
print("The obtained result via amdshark is: ", result)
|
||||
print("The golden result is:", golden_out)
|
||||
@@ -1,18 +0,0 @@
|
||||
# AMDSHARK LLaMA
|
||||
|
||||
## TORCH-MLIR Version
|
||||
|
||||
```
|
||||
https://github.com/nod-ai/torch-mlir.git
|
||||
```
|
||||
Then check out the `complex` branch, run `git submodule update --init`, and build with `.\build_tools\python_deploy\build_windows.ps1`.
|
||||
|
||||
### Setup & Run
|
||||
```
|
||||
git clone https://github.com/nod-ai/llama.git
|
||||
```
|
||||
Then in this repository
|
||||
```
|
||||
pip install -e .
|
||||
python llama/amdshark_model.py
|
||||
```
|
||||
@@ -1,72 +0,0 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_compile import amdshark_compile_through_fx
|
||||
from MEGABYTE_pytorch import MEGABYTE
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class MegaModel(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around a fixed MEGABYTE configuration."""

    def __init__(self):
        super().__init__()
        megabyte_config = {
            # number of tokens
            "num_tokens": 16000,
            # transformer model dimension (512 for coarsest, 256 for fine)
            "dim": (512, 256),
            # sequence length for global and then local; can be more than 2
            "max_seq_len": (1024, 4),
            # layers for global and then local; length must match max_seq_len
            "depth": (6, 4),
            # dimension per head
            "dim_head": 64,
            # number of attention heads
            "heads": 8,
            # use flash attention
            "flash_attn": True,
        }
        self.model = MEGABYTE(**megabyte_config)

    def forward(self, input):
        """Delegate to the wrapped MEGABYTE model."""
        wrapped = self.model
        return wrapped(input)
|
||||
|
||||
|
||||
megaModel = MegaModel()
|
||||
inputs = [torch.randint(0, 16000, (1, 1024, 4))]
|
||||
|
||||
# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
|
||||
# 1. aten.alias
|
||||
amdshark_module, _ = amdshark_compile_through_fx(
|
||||
model=megaModel,
|
||||
inputs=inputs,
|
||||
extended_model_name="mega_amdshark",
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
save_dir=os.getcwd(),
|
||||
debug=False,
|
||||
generate_or_load_vmfb=True,
|
||||
extra_args=[],
|
||||
device="cuda",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
# logits = model(x)
|
||||
|
||||
|
||||
def print_output_info(output, msg):
    """Print ``msg`` and then ``output.shape``, each after a blank line."""
    header = f"\n {msg!s}"
    shape_line = f"\n\t {output.shape!s}"
    print(header)
    print(shape_line)
|
||||
|
||||
|
||||
ans = amdshark_module("forward", inputs)
|
||||
print_output_info(torch.from_numpy(ans), "AMDSHARK's output")
|
||||
|
||||
ans = megaModel.forward(*inputs)
|
||||
print_output_info(ans, "ORIGINAL Model's output")
|
||||
|
||||
# and sample from the logits accordingly
|
||||
# or you can use the generate function
|
||||
|
||||
# NEED TO LOOK AT THIS LATER IF REQUIRED IN AMDSHARK.
|
||||
# sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
|
||||
@@ -1,31 +0,0 @@
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
import numpy as np
|
||||
|
||||
mhlo_ir = r"""builtin.module {
|
||||
func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
|
||||
%1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
return %1 : tensor<4x4xf32>
|
||||
}
|
||||
}"""
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
|
||||
print("Running amdshark on cpu backend")
|
||||
amdshark_module = AMDSharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
|
||||
# Generate the random inputs and feed into the graph.
|
||||
x = amdshark_module.generate_random_inputs()
|
||||
amdshark_module.compile()
|
||||
print(amdshark_module.forward(x))
|
||||
|
||||
print("Running amdshark on cuda backend")
|
||||
amdshark_module = AMDSharkInference(mhlo_ir, device="cuda", mlir_dialect="mhlo")
|
||||
amdshark_module.compile()
|
||||
print(amdshark_module.forward(x))
|
||||
|
||||
print("Running amdshark on vulkan backend")
|
||||
amdshark_module = AMDSharkInference(mhlo_ir, device="vulkan", mlir_dialect="mhlo")
|
||||
amdshark_module.compile()
|
||||
print(amdshark_module.forward(x))
|
||||
@@ -1,73 +0,0 @@
|
||||
from transformers import AutoTokenizer, FlaxAutoModel
|
||||
import torch
|
||||
import jax
|
||||
from typing import Union, Dict, List, Any
|
||||
import numpy as np
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
import io
|
||||
|
||||
NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
|
||||
|
||||
|
||||
def convert_torch_tensor_tree_to_numpy(
    tree: Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor]]
) -> NumpyTree:
    """Convert every torch tensor leaf of ``tree`` into a NumPy array.

    Each leaf is moved to the CPU and detached from autograd before the
    conversion; the container structure of ``tree`` is preserved.
    """
    return jax.tree_util.tree_map(
        lambda torch_tensor: torch_tensor.cpu().detach().numpy(), tree
    )
|
||||
|
||||
|
||||
def convert_int64_to_int32(tree: NumpyTree) -> NumpyTree:
    """Return ``tree`` with every int64 leaf re-created as int32.

    Leaves of any other dtype are passed through untouched.
    """

    def narrow(tensor):
        if tensor.dtype == np.int64:
            return np.array(tensor, dtype=np.int32)
        return tensor

    return jax.tree_util.tree_map(narrow, tree)
|
||||
|
||||
|
||||
def get_sample_input():
    """Tokenize "Hello, World!" and return the result as int32 NumPy arrays."""
    checkpoint = "microsoft/MiniLM-L12-H384-uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inputs_torch = tokenizer("Hello, World!", return_tensors="pt")
    numpy_tree = convert_torch_tensor_tree_to_numpy(inputs_torch.data)
    return convert_int64_to_int32(numpy_tree)
|
||||
|
||||
|
||||
def get_jax_model():
    """Load the pretrained Flax MiniLM checkpoint."""
    checkpoint = "microsoft/MiniLM-L12-H384-uncased"
    return FlaxAutoModel.from_pretrained(checkpoint)
|
||||
|
||||
|
||||
def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree):
    """Lower ``jax_model`` for ``sample_input`` and return MLIR bytecode.

    ``sample_input`` is expanded as keyword arguments when lowering; the
    resulting compiler IR is serialised into an in-memory buffer and
    returned as ``bytes``.
    """
    lowered = jax.jit(jax_model).lower(**sample_input)
    model_mlir = lowered.compiler_ir()
    buffer = io.BytesIO()
    model_mlir.operation.write_bytecode(file=buffer)
    return buffer.getvalue()
|
||||
|
||||
|
||||
def assert_array_list_allclose(x, y, *args, **kwargs):
    """Assert two equally long sequences of arrays are element-wise close.

    Extra positional and keyword arguments are forwarded to
    ``np.testing.assert_allclose`` for every pair.
    """
    assert len(x) == len(y)
    for pair in zip(x, y):
        actual, desired = (np.asarray(item) for item in pair)
        np.testing.assert_allclose(actual, desired, *args, **kwargs)
|
||||
|
||||
|
||||
sample_input = get_sample_input()
|
||||
jax_model = get_jax_model()
|
||||
mlir = export_jax_to_mlir(jax_model, sample_input)
|
||||
|
||||
# Compile and load module.
|
||||
amdshark_inference = AMDSharkInference(mlir_module=mlir, mlir_dialect="mhlo")
|
||||
amdshark_inference.compile()
|
||||
|
||||
# Run main function.
|
||||
result = amdshark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
|
||||
|
||||
# Run JAX model.
|
||||
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]
|
||||
|
||||
# Verify result.
|
||||
assert_array_list_allclose(result, reference_result, atol=1e-5)
|
||||
@@ -1,6 +0,0 @@
|
||||
flax
|
||||
jax[cpu]
|
||||
nodai-AMDSHARK
|
||||
orbax
|
||||
transformers
|
||||
torch
|
||||
@@ -1,23 +0,0 @@
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_downloader import download_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased",
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
|
||||
amdshark_module = AMDSharkInference(mlir_model, device="cpu", mlir_dialect="linalg")
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward(inputs)
|
||||
print("The obtained result via amdshark is: ", result)
|
||||
print("The golden result is:", golden_out)
|
||||
|
||||
|
||||
# Let's generate random inputs, currently supported
|
||||
# for static models.
|
||||
rand_inputs = amdshark_module.generate_random_inputs()
|
||||
rand_results = amdshark_module.forward(rand_inputs)
|
||||
|
||||
print("Running amdshark_module with random_inputs is: ", rand_results)
|
||||
@@ -1,39 +0,0 @@
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
|
||||
torch.hub.list("zhanghang1989/ResNeSt", force_reload=True)
|
||||
|
||||
|
||||
class ResnestModule(torch.nn.Module):
    """Wraps the pretrained ``resnest50`` model loaded through ``torch.hub``."""

    def __init__(self):
        super().__init__()
        resnest = torch.hub.load(
            "zhanghang1989/ResNeSt", "resnest50", pretrained=True
        )
        resnest.eval()
        self.model = resnest

    def forward(self, input):
        """Call the wrapped model's ``forward`` on ``input``."""
        backbone = self.model
        return backbone.forward(input)
|
||||
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
mlir_importer = AMDSharkImporter(
|
||||
ResnestModule(),
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
print(golden_out)
|
||||
|
||||
amdshark_module = AMDSharkInference(vision_mlir, mlir_dialect="linalg")
|
||||
amdshark_module.compile()
|
||||
result = amdshark_module.forward((input,))
|
||||
print("Obtained result", result)
|
||||
@@ -1,74 +0,0 @@
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.parser import amdshark_args
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import sys
|
||||
import torchvision.models as models
|
||||
import torch_mlir
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
    """Wraps a pretrained torchvision ResNet-50 in evaluation mode."""

    def __init__(self):
        super().__init__()
        self.model = models.resnet50(pretrained=True)
        # eval() is the same call as train(False).
        self.eval()

    def forward(self, input):
        """Call the wrapped model's ``forward`` on ``input``."""
        backbone = self.model
        return backbone.forward(input)
|
||||
|
||||
|
||||
model = VisionModule()
|
||||
test_input = torch.randn(1, 3, 224, 224)
|
||||
actual_out = model(test_input)
|
||||
|
||||
test_input_fp16 = test_input.to(device=torch.device("cuda"), dtype=torch.half)
|
||||
model_fp16 = model.half()
|
||||
model_fp16.eval()
|
||||
model_fp16.to("cuda")
|
||||
actual_out_fp16 = model_fp16(test_input_fp16)
|
||||
|
||||
ts_g = torch.jit.trace(model_fp16, [test_input_fp16])
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(test_input_fp16),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=True,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# from contextlib import redirect_stdout
|
||||
|
||||
# with open('resnet50_fp16_linalg_ir.mlir', 'w') as f:
|
||||
# with redirect_stdout(f):
|
||||
# print(module.operation.get_asm())
|
||||
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
amdshark_module = AMDSharkInference(mlir_model, device="cuda", mlir_dialect="linalg")
|
||||
amdshark_module.compile()
|
||||
|
||||
|
||||
def amdshark_result(x):
    """Run ``x`` through the compiled ``amdshark_module``; return a tensor.

    ``x`` is copied to the CPU and converted to NumPy for the call, and the
    NumPy result is wrapped back into a torch tensor.
    """
    host_array = x.cpu().detach().numpy()
    raw_output = amdshark_module.forward((host_array,))
    return torch.from_numpy(raw_output)
|
||||
|
||||
|
||||
observed_out = amdshark_result(test_input_fp16)

print("Golden result:", actual_out_fp16)
print("AMDSHARK result:", observed_out)

actual_out_fp16 = actual_out_fp16.to(device=torch.device("cpu"))

# The assertion returns nothing and raises on mismatch, so report success
# explicitly instead of printing its return value.
torch.testing.assert_allclose(
    actual_out_fp16, observed_out, rtol=1e-2, atol=1e-2
)
print("AMDSHARK result matches the golden result within tolerance.")
|
||||
@@ -1,842 +0,0 @@
|
||||
####################################################################################
|
||||
# Please make sure you have transformers 4.21.2 installed before running this demo
|
||||
#
|
||||
# -p --model_path: the directory in which you want to store the bloom files.
|
||||
# -dl --device_list: the list of device indices you want to use. if you want to only use the first device, or you are running on cpu leave this blank.
|
||||
# Otherwise, please give this argument in this format: "[0, 1, 2]"
|
||||
# -de --device: the device you want to run bloom on. E.G. cpu, cuda
|
||||
# -c, --recompile: set to true if you want to recompile to vmfb.
|
||||
# -d, --download: set to true if you want to redownload the mlir files
|
||||
# -cm, --create_mlirs: set to true if you want to create the mlir files from scratch. please make sure you have transformers 4.21.2 before using this option
|
||||
# -t --token_count: the number of tokens you want to generate
|
||||
# -pr --prompt: the prompt you want to feed to the model
|
||||
# -m --model_name: the name of the model, e.g. bloom-560m
|
||||
#
|
||||
# If you don't specify a prompt when you run this example, you will be able to give prompts through the terminal. Run the
|
||||
# example in this way if you want to run multiple examples without reinitializing the model
|
||||
#####################################################################################
|
||||
|
||||
import os
|
||||
import io
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
import torch_mlir
|
||||
from torch_mlir import TensorPlaceholder
|
||||
import re
|
||||
from transformers.models.bloom.configuration_bloom import BloomConfig
|
||||
import json
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
import urllib.request
|
||||
import subprocess
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_downloader import download_public_file
|
||||
from transformers import (
|
||||
BloomTokenizerFast,
|
||||
BloomForSequenceClassification,
|
||||
BloomForCausalLM,
|
||||
)
|
||||
from transformers.models.bloom.modeling_bloom import (
|
||||
BloomBlock,
|
||||
build_alibi_tensor,
|
||||
)
|
||||
|
||||
IS_CUDA = False
|
||||
|
||||
|
||||
class ShardedBloom:
    """Runs a BLOOM model that is stored as one compiled module per layer.

    ``src_folder`` must contain ``config.json`` and, for every layer
    (``word_embeddings``, ``word_embeddings_layernorm``, ``bloom_block_<i>``,
    ``ln_f``, ``lm_head``), a ``<layer>.mlir`` and/or ``<layer>.vmfb`` file.
    """

    def __init__(self, src_folder):
        """Read the model dimensions from ``<src_folder>/config.json``."""
        with open(f"{src_folder}/config.json") as f:
            config = json.load(f)

        self.layers_initialized = False
        self.src_folder = src_folder

        # Two key spellings are accepted for the embedding size and for the
        # head count; the first is tried and the second is the fallback.
        try:
            self.n_embed = config["n_embed"]
        except KeyError:
            self.n_embed = config["hidden_size"]
        self.vocab_size = config["vocab_size"]
        self.n_layer = config["n_layer"]
        try:
            self.n_head = config["num_attention_heads"]
        except KeyError:
            self.n_head = config["n_head"]

    def _init_layer(self, layer_name, device, replace, device_idx):
        """Create the inference wrapper for one layer.

        When ``replace`` is true or no ``.vmfb`` exists yet, the layer's
        ``.mlir`` file is read and saved as a compiled module; otherwise an
        empty wrapper is created for ``load_layers`` to fill in later.
        """
        vmfb_path = f"{self.src_folder}/{layer_name}.vmfb"
        if replace or not os.path.exists(vmfb_path):
            mlir_path = f"{self.src_folder}/{layer_name}.mlir"
            with open(mlir_path, encoding="utf-8") as f_:
                module = bytes(f_.read(), "utf-8")
            amdshark_module = AMDSharkInference(
                module,
                device=device,
                mlir_dialect="tm_tensor",
                device_idx=device_idx,
            )
            amdshark_module.save_module(
                module_name=f"{self.src_folder}/{layer_name}",
                extra_args=[
                    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                    "--iree-stream-resource-max-allocation-size=1000000000",
                    "--iree-codegen-check-ir-before-llvm-conversion=false",
                ],
            )
        else:
            amdshark_module = AMDSharkInference(
                "",
                device=device,
                mlir_dialect="tm_tensor",
                device_idx=device_idx,
            )

        return amdshark_module

    def init_layers(self, device, replace=False, device_idx=[0]):
        """Create the wrappers for all layers, spread round-robin over devices.

        Args:
            device: device name passed to every ``AMDSharkInference``.
            replace: recompile from ``.mlir`` even if a ``.vmfb`` exists.
            device_idx: list of device indices, or ``None`` to pass ``None``
                to every layer. The default list is only read, never mutated.
        """
        if device_idx is not None:
            n_devices = len(device_idx)

        def device_for(slot):
            # Layers take consecutive slots that wrap around the device list.
            if device_idx is None:
                return None
            return device_idx[slot % n_devices]

        self.word_embeddings_module = self._init_layer(
            "word_embeddings", device, replace, device_for(0)
        )
        self.word_embeddings_layernorm_module = self._init_layer(
            "word_embeddings_layernorm", device, replace, device_for(1)
        )
        self.ln_f_module = self._init_layer(
            "ln_f", device, replace, device_for(2)
        )
        self.lm_head_module = self._init_layer(
            "lm_head", device, replace, device_for(3)
        )
        self.block_modules = [
            self._init_layer(
                f"bloom_block_{i}", device, replace, device_for(i + 4)
            )
            for i in range(self.n_layer)
        ]

        self.layers_initialized = True

    def load_layers(self):
        """Load every layer's ``.vmfb``; ``init_layers`` must have run first."""
        assert self.layers_initialized

        self.word_embeddings_module.load_module(
            f"{self.src_folder}/word_embeddings.vmfb"
        )
        self.word_embeddings_layernorm_module.load_module(
            f"{self.src_folder}/word_embeddings_layernorm.vmfb"
        )
        for i, block_module in enumerate(self.block_modules):
            block_module.load_module(f"{self.src_folder}/bloom_block_{i}.vmfb")
        self.ln_f_module.load_module(f"{self.src_folder}/ln_f.vmfb")
        self.lm_head_module.load_module(f"{self.src_folder}/lm_head.vmfb")

    def forward_pass(self, input_ids, device):
        """Run all layers on ``input_ids`` and return the next token id(s).

        Args:
            input_ids: token-id tensor; ``input_ids.size()`` and
                ``len(input_ids[0])`` are used to size the masks.
            device: accepted for interface compatibility; not used here.

        Returns:
            ``argmax`` over the logits of the last sequence position.
        """
        # NOTE(review): cudaSetDevice is not defined in the visible part of
        # this file; it is only reached when IS_CUDA is true.
        if IS_CUDA:
            cudaSetDevice(self.word_embeddings_module.device_idx)

        input_embeds = self.word_embeddings_module(
            inputs=(input_ids,), function_name="forward"
        )
        input_embeds = torch.tensor(input_embeds).float()

        if IS_CUDA:
            cudaSetDevice(self.word_embeddings_layernorm_module.device_idx)
        hidden_states = self.word_embeddings_layernorm_module(
            inputs=(input_embeds,), function_name="forward"
        )
        hidden_states = torch.tensor(hidden_states).float()

        attention_mask = torch.ones(
            [hidden_states.shape[0], len(input_ids[0])]
        )
        alibi = build_alibi_tensor(
            attention_mask,
            self.n_head,
            hidden_states.dtype,
            hidden_states.device,
        )

        causal_mask = _prepare_attn_mask(
            attention_mask, input_ids.size(), input_embeds, 0
        )
        causal_mask = torch.tensor(causal_mask).float()

        # alibi and the mask are the same for every block: convert them once.
        alibi_np = alibi.detach().numpy()
        causal_mask_np = causal_mask.detach().numpy()

        for block_module in self.block_modules:
            if IS_CUDA:
                cudaSetDevice(block_module.device_idx)

            output = block_module(
                inputs=(
                    hidden_states.detach().numpy(),
                    alibi_np,
                    causal_mask_np,
                ),
                function_name="forward",
            )
            # Only the first output feeds the next block.
            hidden_states = torch.tensor(output[0]).float()

        if IS_CUDA:
            cudaSetDevice(self.ln_f_module.device_idx)
        hidden_states = self.ln_f_module(
            inputs=(hidden_states,), function_name="forward"
        )

        if IS_CUDA:
            cudaSetDevice(self.lm_head_module.device_idx)
        logits = self.lm_head_module(
            inputs=(hidden_states,), function_name="forward"
        )
        logits = torch.tensor(logits).float()

        return torch.argmax(logits[:, -1, :], dim=-1)
|
||||
|
||||
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
past_key_values_length: int = 0,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
|
||||
mask.masked_fill_(intermediate_mask, 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
target_length, past_key_values_length, dtype=dtype
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
expanded_mask = mask[None, None, :, :].expand(
|
||||
batch_size, 1, target_length, target_length + past_key_values_length
|
||||
)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
batch_size, source_length = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else source_length
|
||||
|
||||
expanded_mask = (
|
||||
mask[:, None, None, :]
|
||||
.expand(batch_size, 1, tgt_len, source_length)
|
||||
.to(dtype)
|
||||
)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _prepare_attn_mask(
    attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    """
    Combine the causal mask and the padding mask into one additive mask.

    Args:
        attention_mask: `[bsz, seq_len]` padding mask, or None.
        input_shape: Shape of the input ids; its last entry is the target
            sequence length.
        inputs_embeds: Only its ``dtype`` is used, as the mask dtype.
        past_key_values_length: Forwarded to ``_make_causal_mask``.

    Returns:
        A `[bsz, 1, tgt_seq_len, src_seq_len]` mask, or None when the target
        length is 1 and no ``attention_mask`` was given.
    """
    # create causal mask
    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    combined_attention_mask = None
    if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            past_key_values_length=past_key_values_length,
        )
        # Follow the padding mask's device only when there is a padding mask;
        # reading ``attention_mask.device`` unconditionally raised
        # AttributeError for ``attention_mask=None``.
        if attention_mask is not None:
            combined_attention_mask = combined_attention_mask.to(
                attention_mask.device
            )

    if attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = _expand_mask(
            attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
        )
        if combined_attention_mask is None:
            combined_attention_mask = expanded_attn_mask
        else:
            combined_attention_mask = (
                expanded_attn_mask + combined_attention_mask
            )

    return combined_attention_mask
|
||||
|
||||
|
||||
def download_model(destination_folder, model_name):
    """
    Download `gs://amdshark_tank/sharded_bloom/<model_name>/` into
    `destination_folder` via `download_public_file`.
    """
    bucket_prefix = f"gs://amdshark_tank/sharded_bloom/{model_name}/"
    download_public_file(bucket_prefix, destination_folder)
|
||||
|
||||
|
||||
def compile_embeddings(embeddings_layer, input_ids, path):
    """
    Lower the embedding layer to linalg-on-tensors MLIR and write it to `path`.

    Args:
        embeddings_layer: Module handed to ``torch_mlir.compile``.
        input_ids: Example input; axis 1 is marked as dynamic.
        path: Destination file for the textual MLIR.
    """
    input_ids_placeholder = torch_mlir.TensorPlaceholder.like(
        input_ids, dynamic_axes=[1]
    )
    module = torch_mlir.compile(
        embeddings_layer,
        (input_ids_placeholder),
        torch_mlir.OutputType.LINALG_ON_TENSORS,
        use_tracing=False,
        verbose=False,
    )

    # Only the textual form is persisted. The bytecode that used to be
    # serialised into a throw-away buffer here was never read.
    with open(path, "w", encoding="utf-8") as mlir_file:
        mlir_file.write(str(module))
|
||||
|
||||
|
||||
def compile_word_embeddings_layernorm(
    embeddings_layer_layernorm, embeds, path
):
    """
    Lower the embedding layernorm to linalg-on-tensors MLIR and write it to
    `path`.

    Args:
        embeddings_layer_layernorm: Module handed to ``torch_mlir.compile``.
        embeds: Example input; axis 1 is marked as dynamic.
        path: Destination file for the textual MLIR.
    """
    embeds_placeholder = torch_mlir.TensorPlaceholder.like(
        embeds, dynamic_axes=[1]
    )
    module = torch_mlir.compile(
        embeddings_layer_layernorm,
        (embeds_placeholder),
        torch_mlir.OutputType.LINALG_ON_TENSORS,
        use_tracing=False,
        verbose=False,
    )

    # Only the textual form is persisted. The bytecode that used to be
    # serialised into a throw-away buffer here was never read.
    with open(path, "w", encoding="utf-8") as mlir_file:
        mlir_file.write(str(module))
|
||||
|
||||
|
||||
def strip_overloads(gm):
    """
    Replace every `OpOverload` node target in :attr:`gm` by its overload
    packet, then recompile the graph module.

    Args:
        gm(fx.GraphModule): The input Fx graph module to be modified
    """
    overload_type = torch._ops.OpOverload
    for fx_node in gm.graph.nodes:
        current_target = fx_node.target
        if not isinstance(current_target, overload_type):
            continue
        fx_node.target = current_target.overloadpacket
    gm.recompile()
|
||||
|
||||
|
||||
def compile_to_mlir(
    bblock,
    hidden_states,
    layer_past=None,
    attention_mask=None,
    head_mask=None,
    use_cache=None,
    output_attentions=False,
    alibi=None,
    block_index=0,
    path=".",
):
    """
    Trace one block with `make_fx`, lower it to linalg-on-tensors MLIR, patch
    the textual IR and write it to `path`.

    Args:
        bblock: Callable traced as ``bblock(hidden_states, alibi, attention_mask)``.
        hidden_states: Example input; axis 1 is marked as dynamic.
        attention_mask: Example input; axes 2 and 3 are marked as dynamic.
        alibi: Example input; axis 2 is marked as dynamic.
        path: Destination file for the textual MLIR.
        layer_past, head_mask, use_cache, output_attentions, block_index:
            Accepted for call-site compatibility; not used by this function.
    """
    fx_g = make_fx(
        bblock,
        decomposition_table=get_decompositions(
            [
                torch.ops.aten.split.Tensor,
                torch.ops.aten.split_with_sizes,
            ]
        ),
        tracing_mode="real",
        _allow_non_fake_inputs=False,
    )(hidden_states, alibi, attention_mask)

    fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
    fx_g.recompile()

    strip_overloads(fx_g)

    hidden_states_placeholder = TensorPlaceholder.like(
        hidden_states, dynamic_axes=[1]
    )
    attention_mask_placeholder = TensorPlaceholder.like(
        attention_mask, dynamic_axes=[2, 3]
    )
    alibi_placeholder = TensorPlaceholder.like(alibi, dynamic_axes=[2])

    ts_g = torch.jit.script(fx_g)

    module = torch_mlir.compile(
        ts_g,
        (
            hidden_states_placeholder,
            alibi_placeholder,
            attention_mask_placeholder,
        ),
        torch_mlir.OutputType.LINALG_ON_TENSORS,
        use_tracing=False,
        verbose=False,
    )

    module_placeholder = module
    module_context = module_placeholder.context

    def check_valid_line(line, line_n, mlir_file_len):
        # Drop lines mentioning "private" or "attributes", and the
        # second-to-last line of the dump.
        if "private" in line:
            return False
        if "attributes" in line:
            return False
        if mlir_file_len - line_n == 2:
            return False
        return True

    def remove_constant_dim(line):
        # Rewrite the literal 17 into a dynamic dimension.
        # NOTE(review): presumably 17 is the traced sample sequence length
        # (create_mlirs uses a [1, 17] sample) - confirm.
        # Raw strings keep "\(" a regex escape rather than an invalid
        # string escape.
        if "17x" in line:
            line = re.sub(r"17x", "?x", line)
            line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim)", line)
        if "tensor.empty" in line and "?x?" in line:
            line = re.sub(
                r"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
            )
        if "arith.cmpi eq" in line:
            line = re.sub(r"c17", "dim", line)
        if " 17," in line:
            line = re.sub(r" 17,", " %dim,", line)
        return line

    mlir_lines = str(module).split("\n")
    mlir_file_len = len(mlir_lines)
    patched_mlir = "\n".join(
        remove_constant_dim(line)
        for line_n, line in enumerate(mlir_lines)
        if check_valid_line(line, line_n, mlir_file_len)
    )

    module = module_placeholder.parse(patched_mlir, context=module_context)

    # Only the textual form is persisted. The bytecode that used to be
    # serialised into a throw-away buffer here was never read.
    with open(path, "w", encoding="utf-8") as mlir_file:
        mlir_file.write(str(module))
|
||||
|
||||
|
||||
def compile_ln_f(ln_f, hidden_layers, path):
    """
    Lower the `ln_f` module to linalg-on-tensors MLIR and write it to `path`.

    Args:
        ln_f: Module handed to ``torch_mlir.compile``.
        hidden_layers: Example input; axis 1 is marked as dynamic.
        path: Destination file for the textual MLIR.
    """
    hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like(
        hidden_layers, dynamic_axes=[1]
    )
    module = torch_mlir.compile(
        ln_f,
        (hidden_layers_placeholder),
        torch_mlir.OutputType.LINALG_ON_TENSORS,
        use_tracing=False,
        verbose=False,
    )

    # Only the textual form is persisted. The bytecode that used to be
    # serialised into a throw-away buffer here was never read.
    with open(path, "w", encoding="utf-8") as mlir_file:
        mlir_file.write(str(module))
|
||||
|
||||
|
||||
def compile_lm_head(lm_head, hidden_layers, path):
    """
    Lower the `lm_head` module to linalg-on-tensors MLIR and write it to
    `path`.

    Args:
        lm_head: Module handed to ``torch_mlir.compile``.
        hidden_layers: Example input; axis 1 is marked as dynamic.
        path: Destination file for the textual MLIR.
    """
    hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like(
        hidden_layers, dynamic_axes=[1]
    )
    module = torch_mlir.compile(
        lm_head,
        (hidden_layers_placeholder),
        torch_mlir.OutputType.LINALG_ON_TENSORS,
        use_tracing=False,
        verbose=False,
    )

    # Only the textual form is persisted. The bytecode that used to be
    # serialised into a throw-away buffer here was never read.
    with open(path, "w", encoding="utf-8") as mlir_file:
        mlir_file.write(str(module))
|
||||
|
||||
|
||||
def create_mlirs(destination_folder, model_name):
    """
    Download the model config/tokenizer and emit one `.mlir` file per model
    piece into `destination_folder`.

    Files written: `config.json`, `tokenizer.json`, `word_embeddings.mlir`,
    `word_embeddings_layernorm.mlir`, `bloom_block_<i>.mlir` for every
    transformer block, `ln_f.mlir` and `lm_head.mlir`.

    Args:
        destination_folder: Existing directory that receives all files.
        model_name: Name of the model under the `bigscience/` namespace.
    """
    model_config = "bigscience/" + model_name
    # Sample input used for every compilation below.
    sample_input_ids = torch.ones([1, 17], dtype=torch.int64)

    urllib.request.urlretrieve(
        f"https://huggingface.co/bigscience/{model_name}/resolve/main/config.json",
        filename=f"{destination_folder}/config.json",
    )
    # The tokenizer is always fetched from the `bloom` repository,
    # whatever `model_name` is.
    urllib.request.urlretrieve(
        "https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json",
        filename=f"{destination_folder}/tokenizer.json",
    )

    class HuggingFaceLanguage(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.model = BloomForCausalLM.from_pretrained(model_config)

        def forward(self, tokens):
            return self.model.forward(tokens)[0]

    class HuggingFaceBlock(torch.nn.Module):
        def __init__(self, block):
            super().__init__()
            self.model = block

        def forward(self, tokens, alibi, attention_mask):
            output = self.model(
                hidden_states=tokens,
                alibi=alibi,
                attention_mask=attention_mask,
                use_cache=True,
                output_attentions=False,
            )
            return (output[0], output[1][0], output[1][1])

    model = HuggingFaceLanguage()
    transformer = model.model.transformer

    compile_embeddings(
        transformer.word_embeddings,
        sample_input_ids,
        f"{destination_folder}/word_embeddings.mlir",
    )

    inputs_embeds = transformer.word_embeddings(sample_input_ids)

    compile_word_embeddings_layernorm(
        transformer.word_embeddings_layernorm,
        inputs_embeds,
        f"{destination_folder}/word_embeddings_layernorm.mlir",
    )

    hidden_states = transformer.word_embeddings_layernorm(inputs_embeds)

    input_shape = sample_input_ids.size()

    past_key_values_length = 0
    past_key_values = tuple([None] * len(transformer.h))

    attention_mask = torch.ones(
        (hidden_states.shape[0], hidden_states.shape[1]), device="cpu"
    )

    alibi = build_alibi_tensor(
        attention_mask,
        transformer.n_head,
        hidden_states.dtype,
        "cpu",
    )

    causal_mask = _prepare_attn_mask(
        attention_mask, input_shape, inputs_embeds, past_key_values_length
    )

    head_mask = transformer.get_head_mask(None, transformer.config.n_layer)
    output_attentions = transformer.config.output_attentions

    # Every block is traced with the same sample `hidden_states`; the loop
    # never feeds one block's output into the next.
    for i, (block, layer_past) in enumerate(
        zip(transformer.h, past_key_values)
    ):
        proxy_model = HuggingFaceBlock(block)

        compile_to_mlir(
            proxy_model,
            hidden_states,
            layer_past=layer_past,
            attention_mask=causal_mask,
            head_mask=head_mask[i],
            use_cache=True,
            output_attentions=output_attentions,
            alibi=alibi,
            block_index=i,
            path=f"{destination_folder}/bloom_block_{i}.mlir",
        )

    compile_ln_f(
        transformer.ln_f,
        hidden_states,
        f"{destination_folder}/ln_f.mlir",
    )
    hidden_states = transformer.ln_f(hidden_states)
    compile_lm_head(
        model.model.lm_head,
        hidden_states,
        f"{destination_folder}/lm_head.mlir",
    )
|
||||
|
||||
|
||||
def run_large_model(
    token_count,
    recompile,
    model_path,
    prompt,
    device_list,
    script_path,
    device,
):
    """
    Generate `token_count` tokens by running `script_path` in one subprocess
    per model stage, then print the contents of `<model_path>/prompt.txt`.

    For every token the worker script is invoked with stage "start", then once
    per layer index, then "end". The prompt is re-read from `prompt.txt` at the
    start of each token iteration.
    NOTE(review): presumably the "end" stage of the worker rewrites
    `prompt.txt` with the prompt plus the new token - confirm against
    `script_path`.

    Args:
        token_count: Number of tokens to generate.
        recompile: Passed as the compile flag for the first token only.
        model_path: Directory holding `prompt.txt` and the model files.
        prompt: Initial prompt text.
        device_list: Device indices distributed round-robin over the layers,
            or None.
        script_path: Path of the worker script.
        device: Device name passed to the per-layer stages.

    Relies on the module-level `config` dict for `n_layer`.
    """
    prompt_file = f"{model_path}/prompt.txt"

    def run_stage(stage, will_compile, stage_device, stage_device_idx, text):
        # One worker process per stage.
        subprocess.run(
            [
                "python",
                script_path,
                model_path,
                stage,
                str(will_compile),
                stage_device,
                stage_device_idx,
                text,
            ]
        )

    with open(prompt_file, "w+") as f:
        f.write(prompt)

    for token_idx in range(token_count):
        # Compilation can only be needed on the first pass.
        will_compile = recompile if token_idx == 0 else False
        with open(prompt_file, "r") as f:
            prompt = f.read()

        run_stage("start", will_compile, "cpu", "None", prompt)
        for layer_idx in range(config["n_layer"]):
            if device_list is not None:
                device_idx = str(device_list[layer_idx % len(device_list)])
            else:
                device_idx = "None"
            run_stage(str(layer_idx), will_compile, device, device_idx, prompt)
        run_stage("end", will_compile, "cpu", "None", prompt)

    with open(prompt_file, "r") as f:
        output = f.read()
    print(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def _parse_bool(value):
        # argparse's ``type=bool`` turns every non-empty string (including
        # "False") into True, so the flag values are parsed explicitly.
        normalized = str(value).strip().lower()
        if normalized in ("true", "t", "yes", "y", "1", "on"):
            return True
        if normalized in ("false", "f", "no", "n", "0", "off", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}"
        )

    def _ask_token_count():
        # Only a malformed number falls back to the default; Ctrl-C and EOF
        # are no longer swallowed by a bare ``except``.
        try:
            return int(input("Enter number of tokens you want to generate: "))
        except ValueError:
            print("Invalid integer entered. Using default value of 10")
            return 10

    parser = argparse.ArgumentParser(prog="Bloom-560m")
    parser.add_argument("-p", "--model_path")
    parser.add_argument("-dl", "--device_list", default=None)
    parser.add_argument("-de", "--device", default="cpu")
    parser.add_argument("-c", "--recompile", default=False, type=_parse_bool)
    parser.add_argument("-d", "--download", default=False, type=_parse_bool)
    parser.add_argument("-t", "--token_count", default=10, type=int)
    parser.add_argument("-m", "--model_name", default="bloom-560m")
    parser.add_argument(
        "-cm", "--create_mlirs", default=False, type=_parse_bool
    )

    parser.add_argument(
        "-lm",
        "--large_model_memory_efficient",
        default=False,
        type=_parse_bool,
    )

    parser.add_argument(
        "-pr",
        "--prompt",
        default=None,
    )
    args = parser.parse_args()

    if args.create_mlirs and args.large_model_memory_efficient:
        print(
            "Warning: If you need to use memory efficient mode, you probably want to use 'download' instead"
        )

    # makedirs also creates missing parent directories.
    os.makedirs(args.model_path, exist_ok=True)

    if args.device_list is not None:
        # The device list is given as a JSON array, e.g. "[0, 1]".
        args.device_list = json.loads(args.device_list)

    if args.device == "cuda" and args.device_list is not None:
        IS_CUDA = True
        from cuda.cudart import cudaSetDevice
    if args.download and args.create_mlirs:
        print(
            "WARNING: It is not advised to turn on both download and create_mlirs"
        )
    if args.download:
        download_model(args.model_path, args.model_name)
    if args.create_mlirs:
        create_mlirs(args.model_path, args.model_name)
    from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.prompt is not None:
        input_ids = tokenizer.encode(args.prompt, return_tensors="pt")

    if args.large_model_memory_efficient:
        # `config` is read as a module-level name by run_large_model.
        with open(f"{args.model_path}/config.json") as f:
            config = json.load(f)

        self_path = os.path.dirname(os.path.abspath(__file__))
        script_path = os.path.join(self_path, "sharded_bloom_large_models.py")

        if args.prompt is not None:
            run_large_model(
                args.token_count,
                args.recompile,
                args.model_path,
                args.prompt,
                args.device_list,
                script_path,
                args.device,
            )

        else:
            while True:
                prompt = input("Enter Prompt: ")
                token_count = _ask_token_count()

                run_large_model(
                    token_count,
                    args.recompile,
                    args.model_path,
                    prompt,
                    args.device_list,
                    script_path,
                    args.device,
                )

    else:
        shardedbloom = ShardedBloom(args.model_path)
        shardedbloom.init_layers(
            device=args.device,
            replace=args.recompile,
            device_idx=args.device_list,
        )
        shardedbloom.load_layers()

        def _generate_and_print(input_ids, token_count):
            # Append one predicted token per step, then print the decoded
            # sequence.
            for _ in range(token_count):
                next_token = shardedbloom.forward_pass(
                    torch.tensor(input_ids), device=args.device
                )
                input_ids = torch.cat(
                    [input_ids, next_token.unsqueeze(-1)], dim=-1
                )
            print(tokenizer.decode(input_ids.squeeze()))

        if args.prompt is not None:
            _generate_and_print(input_ids, args.token_count)

        else:
            while True:
                prompt = input("Enter Prompt: ")
                token_count = _ask_token_count()

                input_ids = tokenizer.encode(prompt, return_tensors="pt")
                _generate_and_print(input_ids, token_count)
|
||||
@@ -1,381 +0,0 @@
|
||||
import sys
|
||||
import os
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
|
||||
import re
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
from transformers.models.bloom.modeling_bloom import (
|
||||
BloomBlock,
|
||||
build_alibi_tensor,
|
||||
)
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
batch_size, source_length = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else source_length
|
||||
|
||||
expanded_mask = (
|
||||
mask[:, None, None, :]
|
||||
.expand(batch_size, 1, tgt_len, source_length)
|
||||
.to(dtype)
|
||||
)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _prepare_attn_mask(
    attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    """
    Combine the causal mask and the padding mask into one additive mask.

    Args:
        attention_mask: `[bsz, seq_len]` padding mask, or None.
        input_shape: Shape of the input ids; its last entry is the target
            sequence length.
        inputs_embeds: Only its ``dtype`` is used, as the mask dtype.
        past_key_values_length: Forwarded to ``_make_causal_mask``.

    Returns:
        A `[bsz, 1, tgt_seq_len, src_seq_len]` mask, or None when the target
        length is 1 and no ``attention_mask`` was given.
    """
    # create causal mask
    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    combined_attention_mask = None
    if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            past_key_values_length=past_key_values_length,
        )
        # Follow the padding mask's device only when there is a padding mask;
        # reading ``attention_mask.device`` unconditionally raised
        # AttributeError for ``attention_mask=None``.
        if attention_mask is not None:
            combined_attention_mask = combined_attention_mask.to(
                attention_mask.device
            )

    if attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = _expand_mask(
            attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
        )
        if combined_attention_mask is None:
            combined_attention_mask = expanded_attn_mask
        else:
            combined_attention_mask = (
                expanded_attn_mask + combined_attention_mask
            )

    return combined_attention_mask
|
||||
|
||||
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
past_key_values_length: int = 0,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
|
||||
mask.masked_fill_(intermediate_mask, 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
target_length, past_key_values_length, dtype=dtype
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
expanded_mask = mask[None, None, :, :].expand(
|
||||
batch_size, 1, target_length, target_length + past_key_values_length
|
||||
)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage: <script> working_dir layer_name will_compile device device_idx prompt
    # `layer_name` is "start", a layer index in [0, n_layer), or "end".
    working_dir = sys.argv[1]
    layer_name = sys.argv[2]
    will_compile = sys.argv[3].lower().strip() == "true"
    device = sys.argv[4]
    device_idx = sys.argv[5]
    prompt = sys.argv[6]

    if device_idx.lower().strip() == "none":
        device_idx = None
    else:
        device_idx = int(device_idx)

    with open(f"{working_dir}/config.json") as f:
        config = json.load(f)

    layers_initialized = False
    # Configs name these fields differently; accept either spelling.
    try:
        n_embed = config["n_embed"]
    except KeyError:
        n_embed = config["hidden_size"]
    vocab_size = config["vocab_size"]
    n_layer = config["n_layer"]
    try:
        n_head = config["num_attention_heads"]
    except KeyError:
        n_head = config["n_head"]

    IREE_EXTRA_ARGS = (
        "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
        "--iree-stream-resource-max-allocation-size=1000000000",
        "--iree-codegen-check-ir-before-llvm-conversion=false",
    )

    def _load_stage_module(name, run_device, run_device_idx, as_bytes=True):
        # When compiling, read <name>.mlir and save the compiled module
        # first; in every case load <working_dir>/<name>.vmfb.
        mlir_module = ""
        if will_compile:
            with open(
                f"{working_dir}/{name}.mlir", encoding="utf-8"
            ) as mlir_file:
                mlir_module = mlir_file.read()
            if as_bytes:
                mlir_module = bytes(mlir_module, "utf-8")

        stage_module = AMDSharkInference(
            mlir_module,
            device=run_device,
            mlir_dialect="tm_tensor",
            device_idx=run_device_idx,
        )
        if will_compile:
            stage_module.save_module(
                module_name=f"{working_dir}/{name}",
                extra_args=list(IREE_EXTRA_ARGS),
            )
        stage_module.load_module(f"{working_dir}/{name}.vmfb")
        return stage_module

    if layer_name == "start":
        tokenizer = AutoTokenizer.from_pretrained(working_dir)
        input_ids = tokenizer.encode(prompt, return_tensors="pt")

        amdshark_module = _load_stage_module("word_embeddings", "cpu", None)
        input_embeds = amdshark_module(
            inputs=(input_ids,), function_name="forward"
        )
        input_embeds = torch.tensor(input_embeds).float()

        # NOTE(review): unlike the other stages, the layernorm MLIR is passed
        # as str rather than bytes, as in the original - confirm whether
        # that is intentional.
        amdshark_module = _load_stage_module(
            "word_embeddings_layernorm", "cpu", None, as_bytes=False
        )
        hidden_states = amdshark_module(
            inputs=(input_embeds,), function_name="forward"
        )
        hidden_states = torch.tensor(hidden_states).float()

        torch.save(hidden_states, f"{working_dir}/hidden_states_0.pt")

        attention_mask = torch.ones(
            [hidden_states.shape[0], len(input_ids[0])]
        ).float()

        alibi = build_alibi_tensor(
            attention_mask,
            n_head,
            hidden_states.dtype,
            device="cpu",
        )

        torch.save(alibi, f"{working_dir}/alibi.pt")

        causal_mask = _prepare_attn_mask(
            attention_mask, input_ids.size(), input_embeds, 0
        )
        causal_mask = causal_mask.clone().detach().float()

        torch.save(causal_mask, f"{working_dir}/causal_mask.pt")

    elif layer_name in [str(x) for x in range(n_layer)]:
        hidden_states = torch.load(
            f"{working_dir}/hidden_states_{layer_name}.pt"
        )
        alibi = torch.load(f"{working_dir}/alibi.pt")
        causal_mask = torch.load(f"{working_dir}/causal_mask.pt")

        amdshark_module = _load_stage_module(
            f"bloom_block_{layer_name}", device, device_idx
        )

        output = amdshark_module(
            inputs=(
                hidden_states.detach().numpy(),
                alibi.detach().numpy(),
                causal_mask.detach().numpy(),
            ),
            function_name="forward",
        )

        hidden_states = torch.tensor(output[0]).float()

        # Hand the activations to the next stage through the file system.
        torch.save(
            hidden_states,
            f"{working_dir}/hidden_states_{int(layer_name) + 1}.pt",
        )

    elif layer_name == "end":
        amdshark_module = _load_stage_module("ln_f", "cpu", None)

        hidden_states = torch.load(f"{working_dir}/hidden_states_{n_layer}.pt")

        hidden_states = amdshark_module(
            inputs=(hidden_states,), function_name="forward"
        )

        # Use the resolved `n_embed`: indexing config["n_embed"] directly
        # raised KeyError for configs that only define "hidden_size".
        if n_embed == 14336:
            # For this width the head is rebuilt in PyTorch from a checkpoint
            # shard instead of using a compiled module.

            def get_state_dict():
                d = torch.load(
                    f"{working_dir}/pytorch_model_00001-of-00072.bin"
                )
                return OrderedDict(
                    (k.replace("word_embeddings.", ""), v)
                    for k, v in d.items()
                )

            def load_causal_lm_head():
                linear = nn.utils.skip_init(
                    nn.Linear, 14336, 250880, bias=False, dtype=torch.float
                )
                linear.load_state_dict(get_state_dict(), strict=False)
                return linear.float()

            lm_head = load_causal_lm_head()

            logits = lm_head(torch.tensor(hidden_states).float())
            logits = logits.detach().float()

        else:
            amdshark_module = _load_stage_module("lm_head", "cpu", None)

            logits = amdshark_module(
                inputs=(hidden_states,), function_name="forward"
            )
            logits = torch.tensor(logits).float()

        tokenizer = AutoTokenizer.from_pretrained(working_dir)

        # Greedy decoding: take the arg-max over the last position.
        next_token = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=-1))

        with open(f"{working_dir}/prompt.txt", "w+") as f:
            f.write(prompt + next_token)
|
||||
@@ -1,390 +0,0 @@
|
||||
# Description: an implementation of a deep learning recommendation model (DLRM)
|
||||
# The model input consists of dense and sparse features. The former is a vector
|
||||
# of floating point values. The latter is a list of sparse indices into
|
||||
# embedding tables, which consist of vectors of floating point values.
|
||||
# The selected vectors are passed to mlp networks denoted by triangles,
|
||||
# in some cases the vectors are interacted through operators (Ops).
|
||||
#
|
||||
# output:
#                         vector of values
# model:                        |
#                              /\
#                             /__\
#                               |
#       _____________________> Op  <___________________
#     /                         |                      \
#    /\                        /\                      /\
#   /__\                      /__\           ...      /__\
#    |                          |                       |
#    |                         Op                      Op
#    |                    ____/__\_____           ____/__\____
#    |                   |_Emb_|____|__|    ...  |_Emb_|__|___|
# input:
# [ dense features ]     [sparse indices] , ..., [sparse indices]
|
||||
#
|
||||
# More precise definition of model layers:
|
||||
# 1) fully connected layers of an mlp
|
||||
# z = f(y)
|
||||
# y = Wx + b
|
||||
#
|
||||
# 2) embedding lookup (for a list of sparse indices p=[p1,...,pk])
|
||||
# z = Op(e1,...,ek)
|
||||
# obtain vectors e1=E[:,p1], ..., ek=E[:,pk]
|
||||
#
|
||||
# 3) Operator Op can be one of the following
|
||||
# Sum(e1,...,ek) = e1 + ... + ek
|
||||
# Dot(e1,...,ek) = [e1'e1, ..., e1'ek, ..., ek'e1, ..., ek'ek]
|
||||
# Cat(e1,...,ek) = [e1', ..., ek']'
|
||||
# where ' denotes transpose operation
|
||||
#
|
||||
# References:
|
||||
# [1] Maxim Naumov, Dheevatsa Mudigere, Hao-Jun Michael Shi, Jianyu Huang,
|
||||
# Narayanan Sundaram, Jongsoo Park, Xiaodong Wang, Udit Gupta, Carole-Jean Wu,
|
||||
# Alisson G. Azzolini, Dmytro Dzhulgakov, Andrey Mallevich, Ilia Cherniavskii,
|
||||
# Yinghai Lu, Raghuraman Krishnamoorthi, Ansha Yu, Volodymyr Kondratenko,
|
||||
# Stephanie Pereira, Xianjie Chen, Wenlin Chen, Vijay Rao, Bill Jia, Liang Xiong,
|
||||
# Misha Smelyanskiy, "Deep Learning Recommendation Model for Personalization and
|
||||
# Recommendation Systems", CoRR, arXiv:1906.00091, 2019
|
||||
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
### define dlrm in PyTorch ###
|
||||
class DLRM_Net(nn.Module):
|
||||
def create_mlp(self, ln, sigmoid_layer):
    """
    Build an MLP as an `nn.Sequential` of Linear + activation pairs.

    `ln` is a numpy array of layer widths; layer `i` maps `ln[i]` inputs to
    `ln[i + 1]` outputs. The activation after layer `sigmoid_layer` is a
    Sigmoid, every other one a ReLU. Weights and biases are drawn from
    numpy normal distributions (std `sqrt(2 / (m + n))` and `sqrt(1 / m)`).
    """
    modules = []
    for idx in range(ln.size - 1):
        fan_in = ln[idx]
        fan_out = ln[idx + 1]

        # fully connected operator
        linear = nn.Linear(int(fan_in), int(fan_out), bias=True)

        # custom Xavier-style initialisation; weight is drawn before bias
        weight_std = np.sqrt(2 / (fan_out + fan_in))
        weight = np.random.normal(
            0.0, weight_std, size=(fan_out, fan_in)
        ).astype(np.float32)
        bias_std = np.sqrt(1 / fan_out)
        bias = np.random.normal(0.0, bias_std, size=fan_out).astype(
            np.float32
        )
        linear.weight.data = torch.tensor(weight, requires_grad=True)
        linear.bias.data = torch.tensor(bias, requires_grad=True)
        modules.append(linear)

        # sigmoid or relu operator
        activation = nn.Sigmoid() if idx == sigmoid_layer else nn.ReLU()
        modules.append(activation)

    # wrap all layers in a Sequential container
    return torch.nn.Sequential(*modules)
|
||||
|
||||
def create_emb(self, m, ln, weighted_pooling=None):
    """
    Create one sum-mode `nn.EmbeddingBag` of width `m` per entry of `ln`.

    Returns the bags as an `nn.ModuleList` together with a parallel list
    holding, per table, either None (no `weighted_pooling`) or a vector of
    ones with one entry per table row.
    """
    tables = nn.ModuleList()
    pooling_weights = []
    for idx in range(ln.size):
        rows = ln[idx]

        # embedding operator
        bag = nn.EmbeddingBag(rows, m, mode="sum")
        # uniform initialisation in [-sqrt(1/rows), sqrt(1/rows)]
        bound = np.sqrt(1 / rows)
        table = np.random.uniform(
            low=-bound, high=bound, size=(rows, m)
        ).astype(np.float32)
        print(table)
        bag.weight.data = torch.tensor(table, requires_grad=True)

        per_row_weight = (
            None
            if weighted_pooling is None
            else torch.ones(rows, dtype=torch.float32)
        )
        pooling_weights.append(per_row_weight)
        tables.append(bag)
    return tables, pooling_weights
|
||||
|
||||
def __init__(
    self,
    m_spa=None,
    ln_emb=None,
    ln_bot=None,
    ln_top=None,
    arch_interaction_op=None,
    arch_interaction_itself=False,
    sigmoid_bot=-1,
    sigmoid_top=-1,
    weighted_pooling=None,
):
    """Assemble the embedding tables and the bottom/top MLPs.

    Operators are only created when ``m_spa``, ``ln_emb``, ``ln_bot``,
    ``ln_top`` and ``arch_interaction_op`` are all provided; otherwise the
    module is left without them.

    Args:
        m_spa: Embedding dimension handed to ``create_emb``.
        ln_emb: Per-table row counts handed to ``create_emb``.
        ln_bot: Layer description handed to ``create_mlp`` for ``bot_l``.
        ln_top: Layer description handed to ``create_mlp`` for ``top_l``.
        arch_interaction_op: Interaction mode stored on the instance.
        arch_interaction_itself: Stored on the instance unchanged.
        sigmoid_bot: Forwarded to ``create_mlp`` for the bottom MLP.
        sigmoid_top: Forwarded to ``create_mlp`` for the top MLP.
        weighted_pooling: ``None`` and ``"fixed"`` are stored as given; any
            other value is stored as ``"learned"``.
    """
    super(DLRM_Net, self).__init__()

    mandatory = (m_spa, ln_emb, ln_bot, ln_top, arch_interaction_op)
    if any(argument is None for argument in mandatory):
        # Incomplete configuration: nothing else is set up.
        return

    # save arguments
    self.output_d = 0
    self.arch_interaction_op = arch_interaction_op
    self.arch_interaction_itself = arch_interaction_itself
    if weighted_pooling is None or weighted_pooling == "fixed":
        self.weighted_pooling = weighted_pooling
    else:
        self.weighted_pooling = "learned"

    # create operators
    embeddings, pooling_weights = self.create_emb(
        m_spa, ln_emb, weighted_pooling
    )
    self.emb_l = embeddings
    if self.weighted_pooling == "learned":
        # Learned pooling weights must be registered as parameters.
        self.v_W_l = nn.ParameterList(
            [nn.Parameter(weight) for weight in pooling_weights]
        )
    else:
        self.v_W_l = pooling_weights
    self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
    self.top_l = self.create_mlp(ln_top, sigmoid_top)
|
||||
|
||||
def apply_mlp(self, x, layers):
    """Call ``layers`` on ``x`` and return whatever it produces."""
    output = layers(x)
    return output
|
||||
|
||||
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
    """Look up every sparse feature in its embedding table.

    For each position ``k`` the index group ``lS_i[k]`` and offset group
    ``lS_o[k]`` are passed to ``emb_l[k]``. When ``v_W_l[k]`` is not
    ``None`` it is gathered along dim 0 at the same indices and supplied
    as ``per_sample_weights``.

    TORCH-MLIR: all embeddings are passed in as arguments for easy parsing.

    Returns:
        A list with one lookup result per entry of ``lS_i``.
    """
    pooled = []
    for table, indices in enumerate(lS_i):
        offsets = lS_o[table]
        table_weights = v_W_l[table]
        if table_weights is None:
            sample_weights = None
        else:
            sample_weights = table_weights.gather(0, indices)

        bag = emb_l[table]
        pooled.append(
            bag(indices, offsets, per_sample_weights=sample_weights)
        )
    return pooled
|
||||
|
||||
def interact_features(self, x, ly):
    """Combine the dense vector ``x`` with the sparse lookups ``ly``.

    ``"cat"`` concatenates ``x`` and every entry of ``ly`` along dim 1.
    ``"dot"`` stacks them into ``(batch, -1, d)``, multiplies the stack by
    its own transpose, keeps the lower-triangular entries (including the
    diagonal only when ``arch_interaction_itself`` is true) and appends
    them to ``x``. Any other mode terminates the process via ``sys.exit``.
    """
    mode = self.arch_interaction_op
    if mode == "cat":
        # concatenation features (into a row vector)
        return torch.cat([x] + ly, dim=1)

    if mode != "dot":
        sys.exit(
            "ERROR: --arch-interaction-op="
            + self.arch_interaction_op
            + " is not supported"
        )

    # concatenate dense and sparse features
    batch_size, d = x.shape
    stacked = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
    # pairwise dot products between all feature vectors
    products = torch.bmm(stacked, torch.transpose(stacked, 1, 2))
    _, ni, nj = products.shape

    # Row i contributes its first (i + offset) columns.
    offset = 1 if self.arch_interaction_itself else 0
    row_ids = []
    for i in range(ni):
        row_ids.extend([i] * (i + offset))
    col_ids = []
    for i in range(nj):
        col_ids.extend(range(i + offset))
    li = torch.tensor(row_ids)
    lj = torch.tensor(col_ids)
    flat = products[:, li, lj]

    # concatenate dense features and interactions
    return torch.cat([x] + [flat], dim=1)
|
||||
|
||||
def forward(self, dense_x, lS_o, *lS_i):
    """Collect the variadic index tensors and delegate to ``sequential_forward``."""
    index_groups = lS_i
    return self.sequential_forward(dense_x, lS_o, index_groups)
|
||||
|
||||
def sequential_forward(self, dense_x, lS_o, lS_i):
    """Run the full pipeline: bottom MLP, embeddings, interaction, top MLP.

    Returns:
        The output of the top MLP applied to the interacted features.
    """
    # dense features go through the bottom MLP
    dense_out = self.apply_mlp(dense_x, self.bot_l)
    # sparse features are looked up in the embedding tables
    sparse_out = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
    # interact features (dense and sparse)
    interacted = self.interact_features(dense_out, sparse_out)
    # the top MLP produces the final output
    return self.apply_mlp(interacted, self.top_l)
|
||||
|
||||
|
||||
def dash_separated_ints(value):
    """Validate that ``value`` is a dash separated list of ints.

    Returns:
        ``value`` unchanged when every dash-delimited piece parses as ``int``.

    Raises:
        argparse.ArgumentTypeError: If any piece is not an integer.
    """
    try:
        for piece in value.split("-"):
            int(piece)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "%s is not a valid dash separated list of ints" % value
        )
    return value
|
||||
|
||||
|
||||
# Driver script: parse the architecture flags, build a small DLRM_Net, run it
# eagerly for a reference output, then import/compile it and compare results.

# model related parameters
parser = argparse.ArgumentParser(
    description="Train Deep Learning Recommendation Model (DLRM)"
)
parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
parser.add_argument(
    "--arch-embedding-size", type=dash_separated_ints, default="4-3-2"
)
# NOTE(review): the original comment here read "j will be replaced with the
# table number"; nothing in this script performs such a replacement.
parser.add_argument(
    "--arch-mlp-bot", type=dash_separated_ints, default="4-3-2"
)
parser.add_argument(
    "--arch-mlp-top", type=dash_separated_ints, default="8-2-1"
)
parser.add_argument(
    "--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot"
)
# NOTE(review): --arch-interaction-itself and --weighted-pooling are parsed
# but are not forwarded to DLRM_Net below.
parser.add_argument(
    "--arch-interaction-itself", action="store_true", default=False
)
parser.add_argument("--weighted-pooling", type=str, default=None)

args = parser.parse_args()

# Turn the dash separated strings into integer arrays.
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
ln_top = np.fromstring(args.arch_mlp_top, dtype=int, sep="-")
m_den = ln_bot[0]
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
m_spa = args.arch_sparse_feature_size
ln_emb = np.asarray(ln_emb)
num_fea = ln_emb.size + 1  # num sparse + num dense features

# Initialize the model.
dlrm_model = DLRM_Net(
    m_spa=m_spa,
    ln_emb=ln_emb,
    ln_bot=ln_bot,
    ln_top=ln_top,
    arch_interaction_op=args.arch_interaction_op,
)

# Inputs to the model: one dense row, one offset row per table, and one
# index tensor per table.
dense_inp = torch.tensor([[0.6965, 0.2861, 0.2269, 0.5513]])
vs0 = torch.tensor([[0], [0], [0]], dtype=torch.int64)
vsi = torch.tensor([1, 2, 3]), torch.tensor([1]), torch.tensor([1])

input_dlrm = (dense_inp, vs0, *vsi)

# Eager PyTorch result used as the reference.
golden_output = dlrm_model(dense_inp, vs0, *vsi)

mlir_importer = AMDSharkImporter(
    dlrm_model,
    input_dlrm,
    frontend="torch",
)

# NOTE(review): `inputs` and `golden_out` returned here are not used below;
# the comparison uses `input_dlrm` and `golden_output` instead.
(dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
    tracing_required=True
)

amdshark_module = AMDSharkInference(
    dlrm_mlir, device="vulkan", mlir_dialect="linalg"
)
amdshark_module.compile()
result = amdshark_module.forward(input_dlrm)
# The compiled result must match the eager result within tolerance.
np.testing.assert_allclose(
    golden_output.detach().numpy(), result, rtol=1e-02, atol=1e-03
)
|
||||
|
||||
|
||||
# Verified via torch-mlir.
|
||||
# import torch_mlir
|
||||
# from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
|
||||
# module = torch_mlir.compile(
|
||||
# dlrm_model, inputs, use_tracing=True, output_type="linalg-on-tensors"
|
||||
# )
|
||||
# backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
# compiled = backend.compile(module)
|
||||
# jit_module = backend.load(compiled)
|
||||
|
||||
# dense_numpy = dense_inp.numpy()
|
||||
# vs0_numpy = vs0.numpy()
|
||||
# vsi_numpy = [inp.numpy() for inp in vsi]
|
||||
|
||||
# numpy_inp = (dense_numpy, vs0_numpy, *vsi_numpy)
|
||||
|
||||
# print(jit_module.forward(*numpy_inp))
|
||||
@@ -1,311 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchrec.datasets.utils import Batch
|
||||
from torchrec.modules.crossnet import LowRankCrossNet
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from torchrec.models.dlrm import (
|
||||
choose,
|
||||
DenseArch,
|
||||
DLRM,
|
||||
InteractionArch,
|
||||
SparseArch,
|
||||
OverArch,
|
||||
)
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
def calculate_offsets(tensor_list, prev_values, prev_offsets):
    """Flatten a list of 1-D tensors into values plus bag start offsets.

    Args:
        tensor_list: Tensors to concatenate; each one becomes one bag.
        prev_values: Optional tensor of values accumulated so far; when
            given it is prepended to the new values and the new offsets are
            shifted by its length.
        prev_offsets: Optional tensor of offsets accumulated so far; when
            given it is prepended to the new offsets.

    Returns:
        ``(values, offsets)``: the concatenated values tensor and a tensor
        holding the start position of every bag.
    """
    # New bags start after any previously accumulated values. The shift is
    # keyed on ``prev_values`` because that is the tensor whose length is
    # read and which is prepended below.
    offset_init = 0
    if prev_values is not None:
        offset_init = prev_values.shape[-1]

    offset_list = []
    for tensor in tensor_list:
        offset_list.append(offset_init)
        offset_init += tensor.shape[0]

    concatenated_values = torch.cat(tensor_list)
    if prev_values is not None:
        concatenated_values = torch.cat([prev_values, concatenated_values])

    concatenated_offsets = torch.tensor(offset_list)
    if prev_offsets is not None:
        concatenated_offsets = torch.cat([prev_offsets, concatenated_offsets])

    return concatenated_values, concatenated_offsets
|
||||
|
||||
|
||||
# Have to make combined_keys as dict as to which embedding bags they
|
||||
# point to. {f1: 0, f3: 0, f2: 1}
|
||||
# The result will be a triple containing values, indices and pointer tensor.
|
||||
def to_list(key_jagged, combined_keys):
    """Flatten selected features of ``key_jagged`` into a positional list.

    Args:
        key_jagged: Object exposing ``to_dict()``; each selected entry must
            expose ``to_dense()``.
        combined_keys: Mapping from feature name to the index of the
            embedding bag it points to, e.g. ``{"f1": 0, "f3": 0, "f2": 1}``.

    Returns:
        A flat list with three consecutive entries per feature: values,
        offsets, and a tensor holding the embedding bag index.
    """
    per_feature = key_jagged.to_dict()
    flattened = []

    for feature, bag_index in combined_keys.items():
        values, offsets = calculate_offsets(
            per_feature[feature].to_dense(), None, None
        )
        print(values)
        print(offsets)
        flattened.extend((values, offsets, torch.tensor(bag_index)))

    return flattened
|
||||
|
||||
|
||||
class SparseArchAMDShark(nn.Module):
    """Sparse arch built from plain ``nn.EmbeddingBag`` tables.

    ``forward`` takes a flat sequence of ``(values, offsets, table index)``
    triples instead of a keyed jagged tensor.
    """

    def create_emb(self, embedding_dim, num_embeddings_list):
        """Create one sum-mode embedding bag per entry of ``num_embeddings_list``.

        Each table is initialised uniformly in
        ``[-sqrt(1 / rows), sqrt(1 / rows)]`` using NumPy's global RNG.
        """
        tables = nn.ModuleList()
        for position in range(num_embeddings_list.size):
            rows = num_embeddings_list[position]
            bag = nn.EmbeddingBag(rows, embedding_dim, mode="sum")
            bound = np.sqrt(1 / rows)
            initial = np.random.uniform(
                low=-bound, high=bound, size=(rows, embedding_dim)
            ).astype(np.float32)
            bag.weight.data = torch.tensor(initial, requires_grad=True)
            tables.append(bag)
        return tables

    def __init__(
        self,
        embedding_dim,
        total_features,
        num_embeddings_list,
    ):
        """Store the sizes and build the embedding tables.

        Args:
            embedding_dim: Width of every embedding row.
            total_features: Number of features in the reshaped output.
            num_embeddings_list: NumPy array of per-table row counts.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_features = total_features
        self.embedding_list = self.create_emb(
            embedding_dim, num_embeddings_list
        )

    def forward(self, *batched_inputs):
        """Pool every triple and reshape to ``(-1, num_features, embedding_dim)``.

        Inputs are consumed three at a time: values, offsets, and the index
        of the table to use. Trailing inputs that do not form a full triple
        are ignored.
        """
        pooled = []
        for group in range(len(batched_inputs) // 3):
            values, offsets, pointer = batched_inputs[3 * group : 3 * group + 3]
            table = self.embedding_list[int(pointer)]
            pooled.append(table(values, offsets))

        joined = torch.cat(pooled, dim=1)
        return joined.reshape(-1, self.num_features, self.embedding_dim)
|
||||
|
||||
|
||||
def test_sparse_arch() -> None:
    """Check ``SparseArchAMDShark`` against torchrec's ``SparseArch``.

    Both modules share the same embedding weights and are fed the same
    features; their outputs must agree within tolerance.

    Raises:
        AssertionError: If the two outputs differ beyond the tolerance.
    """
    D = 3
    eb1_config = EmbeddingBagConfig(
        name="t1",
        embedding_dim=D,
        num_embeddings=10,
        feature_names=["f1", "f3"],
    )
    eb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=D,
        num_embeddings=10,
        feature_names=["f2"],
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])

    w1 = ebc.embedding_bags["t1"].weight
    w2 = ebc.embedding_bags["t2"].weight

    sparse_arch = SparseArch(ebc)

    keys = ["f1", "f2", "f3", "f4", "f5"]
    offsets = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19])
    features = KeyedJaggedTensor.from_offsets_sync(
        keys=keys,
        values=torch.tensor(
            [1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]
        ),
        offsets=offsets,
    )
    # Share the reference weights so both modules compute the same thing.
    sparse_archi = SparseArchAMDShark(D, 3, np.array([10, 10]))
    sparse_archi.embedding_list[0].weight = w1
    sparse_archi.embedding_list[1].weight = w2
    inputs = to_list(features, {"f1": 0, "f3": 0, "f2": 1})

    test_results = sparse_archi(*inputs)
    sparse_features = sparse_arch(features)

    # The comparison result used to be discarded, so the test could never
    # fail; raise explicitly so the check also survives ``python -O``.
    if not torch.allclose(
        sparse_features,
        test_results,
        rtol=1e-4,
        atol=1e-4,
    ):
        raise AssertionError(
            "SparseArchAMDShark output does not match torchrec SparseArch"
        )
|
||||
|
||||
|
||||
test_sparse_arch()
|
||||
|
||||
|
||||
class DLRMAMDShark(nn.Module):
    """DLRM variant whose sparse arch is ``SparseArchAMDShark``.

    The dense, interaction and over architectures are the imported
    ``DenseArch``, ``InteractionArch`` and ``OverArch`` modules.
    """

    def __init__(
        self,
        embedding_dim,
        total_features,
        num_embeddings_list,
        dense_in_features: int,
        dense_arch_layer_sizes: List[int],
        over_arch_layer_sizes: List[int],
    ) -> None:
        """Build the four sub-architectures.

        Args:
            embedding_dim: Width of every embedding row.
            total_features: Number of sparse features.
            num_embeddings_list: Per-table row counts for the sparse arch.
            dense_in_features: Input width of the dense arch.
            dense_arch_layer_sizes: Layer sizes of the dense arch.
            over_arch_layer_sizes: Layer sizes of the over arch.
        """
        super().__init__()

        self.sparse_arch = SparseArchAMDShark(
            embedding_dim, total_features, num_embeddings_list
        )
        self.dense_arch = DenseArch(
            in_features=dense_in_features,
            layer_sizes=dense_arch_layer_sizes,
        )
        self.inter_arch = InteractionArch(
            num_sparse_features=total_features,
        )

        # Over-arch input width: embedding_dim + C(F, 2) + F.
        pairwise_terms = choose(total_features, 2)
        over_width = embedding_dim + pairwise_terms + total_features
        self.over_arch = OverArch(
            in_features=over_width,
            layer_sizes=over_arch_layer_sizes,
        )

    def forward(
        self, dense_features: torch.Tensor, *sparse_features
    ) -> torch.Tensor:
        """Return the logits for one batch.

        Args:
            dense_features: Input of the dense arch.
            *sparse_features: Flat triples forwarded to the sparse arch.
        """
        dense_embedding = self.dense_arch(dense_features)
        sparse_embedding = self.sparse_arch(*sparse_features)
        interacted = self.inter_arch(
            dense_features=dense_embedding, sparse_features=sparse_embedding
        )
        return self.over_arch(interacted)
|
||||
|
||||
|
||||
def test_dlrm() -> None:
    """Check ``DLRMAMDShark`` against torchrec's ``DLRM`` and the compiled module.

    The torchrec weights are copied into the AMDShark variant, both are run
    on the same inputs, and the AMDShark variant is then imported, compiled
    for CPU and compared against the importer's golden output.

    Raises:
        AssertionError: If the compiled result or the eager logits differ
            beyond their tolerances.
    """
    B = 2
    D = 8
    dense_in_features = 100

    eb1_config = EmbeddingBagConfig(
        name="t1",
        embedding_dim=D,
        num_embeddings=100,
        feature_names=["f1", "f3"],
    )
    eb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=D,
        num_embeddings=100,
        feature_names=["f2"],
    )

    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f3", "f2"],
        values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
        offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
    )
    # Built once; an earlier duplicate construction was a dead store.
    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    sparse_nn = DLRM(
        embedding_bag_collection=ebc,
        dense_in_features=dense_in_features,
        dense_arch_layer_sizes=[20, D],
        over_arch_layer_sizes=[5, 1],
    )
    sparse_nn_nod = DLRMAMDShark(
        embedding_dim=D,
        total_features=3,
        num_embeddings_list=np.array([100, 100]),
        dense_in_features=dense_in_features,
        dense_arch_layer_sizes=[20, D],
        over_arch_layer_sizes=[5, 1],
    )

    dense_features = torch.rand((B, dense_in_features))

    x = to_list(sparse_features, {"f1": 0, "f3": 0, "f2": 1})

    # Copy every weight from the reference model into the AMDShark variant.
    w1 = ebc.embedding_bags["t1"].weight
    w2 = ebc.embedding_bags["t2"].weight

    sparse_nn_nod.sparse_arch.embedding_list[0].weight = w1
    sparse_nn_nod.sparse_arch.embedding_list[1].weight = w2

    sparse_nn_nod.dense_arch.load_state_dict(sparse_nn.dense_arch.state_dict())
    sparse_nn_nod.inter_arch.load_state_dict(sparse_nn.inter_arch.state_dict())
    sparse_nn_nod.over_arch.load_state_dict(sparse_nn.over_arch.state_dict())

    logits = sparse_nn(
        dense_features=dense_features,
        sparse_features=sparse_features,
    )
    logits_nod = sparse_nn_nod(dense_features, *x)

    # Import the module, compile it and compare with the golden output.
    mlir_importer = AMDSharkImporter(
        sparse_nn_nod,
        (dense_features, *x),
        frontend="torch",
    )

    (dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
        tracing_required=True
    )

    amdshark_module = AMDSharkInference(
        dlrm_mlir, device="cpu", mlir_dialect="linalg"
    )
    amdshark_module.compile()
    result = amdshark_module.forward(inputs)
    np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)

    # The comparison result used to be discarded, so a mismatch between the
    # two eager models went unnoticed; raise explicitly instead.
    if not torch.allclose(
        logits,
        logits_nod,
        rtol=1e-4,
        atol=1e-4,
    ):
        raise AssertionError(
            "DLRMAMDShark logits do not match torchrec DLRM logits"
        )
|
||||
|
||||
|
||||
test_dlrm()
|
||||
@@ -1,39 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
|
||||
|
||||
class UnetModule(torch.nn.Module):
    """Wrapper around the pretrained U-Net fetched through ``torch.hub``."""

    def __init__(self):
        """Load the pretrained model and switch it to eval mode."""
        super().__init__()
        hub_options = dict(
            in_channels=3,
            out_channels=1,
            init_features=32,
            pretrained=True,
        )
        self.model = torch.hub.load(
            "mateuszbuda/brain-segmentation-pytorch", "unet", **hub_options
        )
        self.model.eval()

    def forward(self, input):
        """Return the wrapped model's output for ``input``."""
        prediction = self.model(input)
        return prediction
|
||||
|
||||
|
||||
# Driver script: import the U-Net wrapper, compile it and compare the compiled
# result with the importer's golden output.

# Random sample input. NOTE: this module-level name shadows the builtin
# ``input``.
input = torch.randn(1, 3, 224, 224)

mlir_importer = AMDSharkImporter(
    UnetModule(),
    (input,),
    frontend="torch",
)

# import_debug also returns the inputs it used and a golden output.
(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
    tracing_required=False
)

amdshark_module = AMDSharkInference(vision_mlir, mlir_dialect="linalg")
amdshark_module.compile()
result = amdshark_module.forward((input,))
# The compiled result must match the golden output within tolerance.
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
@@ -1,21 +0,0 @@
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from pipeline_amdshark_stable_diffusion_upscale import (
|
||||
AMDSharkStableDiffusionUpscalePipeline,
|
||||
)
|
||||
import torch
|
||||
|
||||
# Example: upscale a low-resolution cat picture with the AMDShark pipeline.
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipeline = AMDSharkStableDiffusionUpscalePipeline(model_id)

# let's download an image
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
# Bound the wait and fail fast on HTTP errors, so a stalled connection does
# not hang the script and PIL is never handed an error page to decode.
response = requests.get(url, timeout=60)
response.raise_for_status()
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
low_res_img = low_res_img.resize((128, 128))

prompt = "a white cat"

upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image.save("upsampled_cat.png")
|
||||
@@ -1,98 +0,0 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from utils import compile_through_fx
|
||||
import torch
|
||||
|
||||
# Hugging Face model repository from which all sub-models are loaded.
model_id = "stabilityai/stable-diffusion-x4-upscaler"

# Sample inputs handed to compile_through_fx for each sub-model.
model_input = {
    "clip": (torch.randint(1, 2, (1, 77)),),
    "vae": (torch.randn(1, 4, 128, 128),),
    "unet": (
        torch.randn(2, 7, 128, 128),  # latents
        torch.tensor([1]).to(torch.float32),  # timestep
        torch.randn(2, 77, 1024),  # embedding
        # NOTE(review): casting normal samples to int64 can yield negative
        # values; confirm the UNet accepts those as noise levels.
        torch.randn(2).to(torch.int64),  # noise_level
    ),
}
|
||||
|
||||
|
||||
def get_clip_mlir(model_name="clip_text", extra_args=None):
    """Load the CLIP text encoder and compile it through ``compile_through_fx``.

    Args:
        model_name: Name forwarded to ``compile_through_fx``.
        extra_args: Optional list of extra flags forwarded to
            ``compile_through_fx``; defaults to an empty list.

    Returns:
        Whatever ``compile_through_fx`` returns for the wrapped encoder.
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    extra_args = [] if extra_args is None else extra_args

    text_encoder = CLIPTextModel.from_pretrained(
        model_id,
        subfolder="text_encoder",
    )

    class CLIPText(torch.nn.Module):
        """Expose only the first output of the text encoder."""

        def __init__(self):
            super().__init__()
            self.text_encoder = text_encoder

        def forward(self, input):
            return self.text_encoder(input)[0]

    clip_model = CLIPText()
    amdshark_clip = compile_through_fx(
        clip_model,
        model_input["clip"],
        model_name=model_name,
        extra_args=extra_args,
    )
    return amdshark_clip
|
||||
|
||||
|
||||
def get_vae_mlir(model_name="vae", extra_args=None):
    """Load the VAE and compile its decoder through ``compile_through_fx``.

    Args:
        model_name: Name forwarded to ``compile_through_fx``.
        extra_args: Optional list of extra flags forwarded to
            ``compile_through_fx``; defaults to an empty list.

    Returns:
        Whatever ``compile_through_fx`` returns for the wrapped decoder.
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    extra_args = [] if extra_args is None else extra_args

    class VaeModel(torch.nn.Module):
        """Expose only the first output of ``vae.decode``."""

        def __init__(self):
            super().__init__()
            self.vae = AutoencoderKL.from_pretrained(
                model_id,
                subfolder="vae",
            )

        def forward(self, input):
            x = self.vae.decode(input, return_dict=False)[0]
            return x

    vae = VaeModel()
    amdshark_vae = compile_through_fx(
        vae,
        model_input["vae"],
        model_name=model_name,
        extra_args=extra_args,
    )
    return amdshark_vae
|
||||
|
||||
|
||||
def get_unet_mlir(model_name="unet", extra_args=None):
    """Load the UNet and compile it through ``compile_through_fx`` in f16.

    Args:
        model_name: Name forwarded to ``compile_through_fx``.
        extra_args: Optional list of extra flags forwarded to
            ``compile_through_fx``; defaults to an empty list.

    Returns:
        Whatever ``compile_through_fx`` returns for the wrapped UNet.
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    extra_args = [] if extra_args is None else extra_args

    class UnetModel(torch.nn.Module):
        """Expose only the first output of the UNet forward pass."""

        def __init__(self):
            super().__init__()
            self.unet = UNet2DConditionModel.from_pretrained(
                model_id,
                subfolder="unet",
            )
            self.in_channels = self.unet.in_channels
            self.train(False)

        def forward(self, latent, timestep, text_embedding, noise_level):
            unet_out = self.unet.forward(
                latent,
                timestep,
                text_embedding,
                noise_level,
                return_dict=False,
            )[0]
            return unet_out

    unet = UnetModel()
    # One flag per sample input: the first three are marked for f16, the
    # last (noise_level) is not.
    f16_input_mask = (True, True, True, False)
    amdshark_unet = compile_through_fx(
        unet,
        model_input["unet"],
        model_name=model_name,
        is_f16=True,
        f16_input_mask=f16_input_mask,
        extra_args=extra_args,
    )
    return amdshark_unet
|
||||
@@ -1,48 +0,0 @@
|
||||
import sys
|
||||
from model_wrappers import (
|
||||
get_vae_mlir,
|
||||
get_unet_mlir,
|
||||
get_clip_mlir,
|
||||
)
|
||||
from upscaler_args import args
|
||||
from utils import get_amdshark_model
|
||||
|
||||
# Only a single prompt is supported; exit immediately otherwise.
BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
    sys.exit("Only batch size 1 is supported.")


# Per-model flag lists; each is passed along with the model name when the
# corresponding model is built or fetched below.
unet_flag = [
    "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]

vae_flag = [
    "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]

clip_flag = [
    "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]

# Location handed to get_amdshark_model when models are not imported locally.
bucket = "gs://amdshark_tank/stable_diffusion/"
|
||||
|
||||
|
||||
def get_unet():
    """Return the upscaler UNet, imported locally or fetched from the bucket."""
    name = "upscaler_unet"
    if not args.import_mlir:
        return get_amdshark_model(bucket, name, unet_flag)
    return get_unet_mlir(name, unet_flag)
|
||||
|
||||
|
||||
def get_vae():
    """Return the upscaler VAE, imported locally or fetched from the bucket."""
    name = "upscaler_vae"
    if not args.import_mlir:
        return get_amdshark_model(bucket, name, vae_flag)
    return get_vae_mlir(name, vae_flag)
|
||||
|
||||
|
||||
def get_clip():
    """Return the upscaler CLIP model, imported locally or fetched from the bucket."""
    name = "upscaler_clip"
    if not args.import_mlir:
        return get_amdshark_model(bucket, name, clip_flag)
    return get_clip_mlir(name, clip_flag)
|
||||
@@ -1,489 +0,0 @@
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers import logging
|
||||
from diffusers.pipeline_utils import ImagePipelineOutput
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def preprocess(image):
    """Convert the input image(s) into a single tensor.

    A tensor is returned unchanged. A single PIL image is treated as a
    one-element list. A list of PIL images is resized so both sides are
    multiples of 64 (taken from the first image), scaled to ``[-1, 1]`` and
    moved to channels-first layout. A list of tensors is concatenated along
    dim 0. Any other list is returned as is.
    """
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        # resize to integer multiple of 64
        width, height = (side - side % 64 for side in (width, height))

        frames = [
            np.array(picture.resize((width, height)))[None, :]
            for picture in image
        ]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        batch = 2.0 * batch - 1.0
        image = torch.from_numpy(batch)
    elif isinstance(first, torch.Tensor):
        image = torch.cat(image, dim=0)
    return image
|
||||
|
||||
|
||||
def amdshark_run_wrapper(model, *args):
    """Call ``model("forward", inputs)`` with NumPy inputs and wrap the result.

    Every tensor in ``args`` is detached and converted to NumPy; the model's
    return value is converted back with ``torch.from_numpy``.
    """
    numpy_inputs = tuple(tensor.detach().numpy() for tensor in args)
    raw_output = model("forward", numpy_inputs)
    return torch.from_numpy(raw_output)
|
||||
|
||||
|
||||
class AMDSharkStableDiffusionUpscalePipeline:
|
||||
def __init__(
    self,
    model_id,
):
    """Load the tokenizer and schedulers and obtain the compiled models.

    Args:
        model_id: Model identifier passed to every ``from_pretrained`` call.
    """
    self.tokenizer = CLIPTokenizer.from_pretrained(
        model_id, subfolder="tokenizer"
    )
    self.low_res_scheduler = DDPMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    self.scheduler = DDIMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    self.vae = get_vae()
    self.unet = get_unet()
    self.text_encoder = get_clip()
    # A plain int: the previous ``(350,)`` was a one-element tuple, which
    # cannot be ordered against an integer noise level.
    self.max_noise_level = 350
    self._execution_device = "cpu"
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `list`):
            prompt to be encoded
        device: (`torch.device`):
            torch device the token ids are moved to before encoding
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`):
            The prompt or prompts not to guide the image generation. Only
            read when `do_classifier_free_guidance` is true; `None` is
            replaced by empty strings.

    Returns:
        The text embeddings; with classifier free guidance the
        unconditional embeddings are concatenated in front of them.

    Raises:
        TypeError: If `negative_prompt` is given with a different type
            than `prompt`.
        ValueError: If `negative_prompt` is a list whose length differs
            from the prompt batch size.
    """
    batch_size = len(prompt) if isinstance(prompt, list) else 1

    # Tokenize to the tokenizer's maximum length (truncating) ...
    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    # ... and once more without truncation, only to detect dropped text.
    untruncated_ids = self.tokenizer(
        prompt, padding="longest", return_tensors="pt"
    ).input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[
        -1
    ] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = self.tokenizer.batch_decode(
            untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
        )
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {self.tokenizer.model_max_length} tokens: {removed_text}"
        )

    # No attention mask is passed: the compiled text encoder is called with
    # the token ids only.
    text_embeddings = amdshark_run_wrapper(
        self.text_encoder, text_input_ids.to(device)
    )

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    bs_embed, seq_len, _ = text_embeddings.shape
    text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
    text_embeddings = text_embeddings.view(
        bs_embed * num_images_per_prompt, seq_len, -1
    )

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # Pad the unconditional tokens to the same length as the prompt ids.
        max_length = text_input_ids.shape[-1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        # As above, the compiled text encoder receives the token ids only.
        uncond_embeddings = amdshark_run_wrapper(
            self.text_encoder,
            uncond_input.input_ids.to(device),
        )
        # NOTE(review): the next assignment is a no-op.
        uncond_embeddings = uncond_embeddings

        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = uncond_embeddings.shape[1]
        uncond_embeddings = uncond_embeddings.repeat(
            1, num_images_per_prompt, 1
        )
        uncond_embeddings = uncond_embeddings.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    return text_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments accepted by ``self.scheduler.step``.

    Scheduler ``step`` signatures differ, so ``eta`` and ``generator`` are only
    forwarded when the signature names them. eta (η) corresponds to η in the
    DDIM paper (https://arxiv.org/abs/2010.02502) and should be within [0, 1].

    Args:
        generator: value forwarded as ``generator`` when supported.
        eta: value forwarded as ``eta`` when supported.

    Returns:
        dict: may contain the keys "eta" and "generator"; empty when the
        scheduler's ``step`` accepts neither.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
|
||||
def decode_latents(self, latents):
    """Decode latents with ``self.vae`` and return the result as a NumPy array.

    The latents are divided by the 0.08333 scaling factor, decoded through
    ``amdshark_run_wrapper``, mapped with ``x / 2 + 0.5`` and clamped to
    [0, 1]. The axes are then reordered with ``permute(0, 2, 3, 1)``.
    """
    scaled_latents = 1 / 0.08333 * latents
    decoded = amdshark_run_wrapper(self.vae, scaled_latents)
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # Always cast to float32: the overhead is negligible and it stays
    # compatible with bfloat16 inputs.
    decoded = decoded.cpu().permute(0, 2, 3, 1)
    return decoded.float().numpy()
|
||||
|
||||
def check_inputs(self, prompt, image, noise_level, callback_steps):
    """Validate the types and batch sizes of ``prompt`` and ``image``.

    ``noise_level`` and ``callback_steps`` are accepted but not inspected.

    Raises:
        ValueError: if ``prompt`` is neither ``str`` nor ``list``; if ``image``
            is not a ``torch.Tensor``, ``PIL.Image.Image`` or ``list``; or if
            a list/tensor ``image`` has a batch size different from ``prompt``.
    """
    if not (isinstance(prompt, str) or isinstance(prompt, list)):
        raise ValueError(
            f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
        )

    image_type_ok = (
        isinstance(image, torch.Tensor)
        or isinstance(image, PIL.Image.Image)
        or isinstance(image, list)
    )
    if not image_type_ok:
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
        )

    # Batch sizes are only compared for list and tensor images.
    if not (isinstance(image, list) or isinstance(image, torch.Tensor)):
        return

    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    image_batch_size = len(image) if isinstance(image, list) else image.shape[0]
    if batch_size != image_batch_size:
        raise ValueError(
            f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
            " Please make sure that passed `prompt` matches the batch size of `image`."
        )
|
||||
|
||||
@staticmethod
def numpy_to_pil(images):
    """Convert a NumPy image, or a batch of images, to a list of PIL images.

    A 3-dimensional input is treated as a single image and given a leading
    batch axis. Values are multiplied by 255, rounded and cast to ``uint8``.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    batch = (batch * 255).round().astype("uint8")

    if batch.shape[-1] == 1:
        # Grayscale (single channel): drop the channel axis and use mode "L".
        return [Image.fromarray(frame.squeeze(), mode="L") for frame in batch]
    return [Image.fromarray(frame) for frame in batch]
|
||||
|
||||
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Return initial latents scaled by ``self.scheduler.init_noise_sigma``.

    When ``latents`` is None, noise of shape
    ``(batch_size, num_channels_latents, height, width)`` is sampled with
    ``torch.randn``; otherwise the supplied tensor is checked against that
    shape and moved to ``device``.

    Raises:
        ValueError: if supplied ``latents`` do not have the expected shape.
    """
    shape = (batch_size, num_channels_latents, height, width)

    if latents is not None:
        if latents.shape != shape:
            raise ValueError(
                f"Unexpected latents shape, got {latents.shape}, expected {shape}"
            )
        latents = latents.to(device)
    elif device == "mps":
        # randn does not work reproducibly on mps: sample on CPU, then move.
        cpu_noise = torch.randn(
            shape, generator=generator, device="cpu", dtype=dtype
        )
        latents = cpu_noise.to(device)
    else:
        latents = torch.randn(
            shape, generator=generator, device=device, dtype=dtype
        )

    # Scale the initial noise by the standard deviation the scheduler requires.
    return latents * self.scheduler.init_noise_sigma
|
||||
|
||||
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Union[
        torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]
    ],
    num_inference_steps: int = 75,
    guidance_scale: float = 9.0,
    noise_level: int = 20,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[
        Union[torch.Generator, List[torch.Generator]]
    ] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[
        Callable[[int, int, torch.FloatTensor], None]
    ] = None,
    callback_steps: Optional[int] = 1,
):
    """Run the pipeline on ``image`` guided by ``prompt``.

    The prompt is encoded, the input image is preprocessed and noised with
    ``self.low_res_scheduler``, and latents are denoised over
    ``self.scheduler.timesteps`` with ``self.unet`` (invoked through
    ``amdshark_run_wrapper``). The final latents are decoded with
    ``decode_latents``.

    Args:
        prompt: a prompt string or list of prompts; its length sets the
            batch size.
        image: input image(s), passed through ``preprocess``.
        num_inference_steps: forwarded to ``self.scheduler.set_timesteps``.
        guidance_scale: classifier-free guidance is applied when this is
            greater than 1.0.
        noise_level: timestep passed to ``low_res_scheduler.add_noise`` and
            to the UNet.
        negative_prompt: forwarded to ``_encode_prompt``.
        num_images_per_prompt: number of outputs generated per prompt.
        eta: forwarded to the scheduler step when it accepts ``eta``.
        generator: random generator used for noise sampling.
        latents: optional pre-generated latents.
        output_type: ``"pil"`` converts the result with ``numpy_to_pil``;
            any other value leaves the NumPy array from ``decode_latents``.
        return_dict: when False a one-element tuple is returned.
        callback: currently unused; the code that invoked it is commented
            out below.
        callback_steps: only passed to ``check_inputs``.

    Returns:
        ``ImagePipelineOutput(images=...)``, or ``(images,)`` when
        ``return_dict`` is False.
    """
    # 1. Check inputs
    self.check_inputs(prompt, image, noise_level, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_embeddings = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
    )

    # 4. Preprocess image
    image = preprocess(image)
    image = image.to(dtype=text_embeddings.dtype, device=device)

    # 5. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5b. Add noise to image
    noise_level = torch.tensor(
        [noise_level], dtype=torch.long, device=device
    )
    if device == "mps":
        # randn does not work reproducibly on mps
        noise = torch.randn(
            image.shape,
            generator=generator,
            device="cpu",
            dtype=text_embeddings.dtype,
        ).to(device)
    else:
        noise = torch.randn(
            image.shape,
            generator=generator,
            device=device,
            dtype=text_embeddings.dtype,
        )
    image = self.low_res_scheduler.add_noise(image, noise, noise_level)

    # Repeat the noised image (and its noise level) so its batch matches the
    # UNet input: doubled under classifier-free guidance, times images per prompt.
    batch_multiplier = 2 if do_classifier_free_guidance else 1
    image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
    noise_level = torch.cat([noise_level] * image.shape[0])

    # 6. Prepare latent variables
    height, width = image.shape[2:]
    # num_channels_latents = self.vae.config.latent_channels
    num_channels_latents = 4
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )

    # 7. Check that sizes of image and latents match
    # NOTE: the check itself is commented out, so `num_channels_image` is
    # currently only referenced by the disabled code below.
    num_channels_image = image.shape[1]
    # if (
    #     num_channels_latents + num_channels_image
    #     != self.unet.config.in_channels
    # ):
    #     raise ValueError(
    #         f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
    #         f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
    #         f" `num_channels_image`: {num_channels_image} "
    #         f" = {num_channels_latents+num_channels_image}. Please verify the config of"
    #         " `pipeline.unet` or your `image` input."
    #     )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Denoising loop
    # NOTE: `num_warmup_steps` is only referenced by the commented-out
    # progress/callback code at the end of the loop.
    num_warmup_steps = (
        len(timesteps) - num_inference_steps * self.scheduler.order
    )
    for i, t in tqdm(enumerate(timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )

        # scale the latents for this timestep, then concatenate the noised
        # input image along dim=1
        latent_model_input = self.scheduler.scale_model_input(
            latent_model_input, t
        )
        latent_model_input = torch.cat([latent_model_input, image], dim=1)

        timestep = torch.tensor([t]).to(torch.float32)

        # predict the noise residual (latents and text embeddings are cast
        # to half precision before being handed to the UNet)
        noise_pred = amdshark_run_wrapper(
            self.unet,
            latent_model_input.half(),
            timestep,
            text_embeddings.half(),
            noise_level,
        )

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        # # call the callback, if provided
        # if i == len(timesteps) - 1 or (
        #     (i + 1) > num_warmup_steps
        #     and (i + 1) % self.scheduler.order == 0
        # ):
        #     progress_bar.update()
        #     if callback is not None and i % callback_steps == 0:
        #         callback(i, t, latents)

    # 10. Post-processing
    # make sure the VAE is in float32 mode, as it overflows in float16
    # self.vae.to(dtype=torch.float32)
    image = self.decode_latents(latents.float())

    # 11. Convert to PIL
    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
|
||||
@@ -1,98 +0,0 @@
|
||||
import argparse
|
||||
|
||||
# Command-line options for the upscaler scripts. The arguments are parsed at
# import time and exposed as the module-level ``args`` namespace.
p = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description=__doc__,
)

##############################################################################
### Stable Diffusion Params
##############################################################################

p.add_argument(
    "--prompts",
    help="text of which images to be generated.",
    default=["cyberpunk forest by Salvador Dali"],
    nargs="+",
)
p.add_argument(
    "--negative-prompts",
    help="text you don't want to see in the generated image.",
    default=[""],
    nargs="+",
)
p.add_argument(
    "--steps", help="the no. of steps to do the sampling.", default=50, type=int
)
p.add_argument("--seed", help="the seed to use.", default=42, type=int)
p.add_argument(
    "--guidance_scale",
    help="the value to be used for guidance scaling.",
    default=7.5,
    type=float,
)

##############################################################################
### Model Config and Usage Params
##############################################################################

p.add_argument(
    "--device", help="device to run the model.", default="vulkan", type=str
)
p.add_argument(
    "--precision", help="precision to run the model.", default="fp16", type=str
)
p.add_argument(
    "--import_mlir",
    help="imports the model from torch module to amdshark_module otherwise downloads the model from amdshark_tank.",
    action=argparse.BooleanOptionalAction,
    default=False,
)
p.add_argument(
    "--load_vmfb",
    help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
    action=argparse.BooleanOptionalAction,
    default=True,
)
p.add_argument(
    "--save_vmfb",
    help="saves the compiled flatbuffer to the local directory",
    action=argparse.BooleanOptionalAction,
    default=False,
)

##############################################################################
### IREE - Vulkan supported flags
##############################################################################

p.add_argument(
    "--iree-vulkan-target-triple",
    help="Specify target triple for vulkan",
    default="",
    type=str,
)
p.add_argument(
    "--vulkan_debug_utils",
    help="Profiles vulkan device and collects the .rdc info",
    action=argparse.BooleanOptionalAction,
    default=False,
)

args = p.parse_args()
|
||||
@@ -1,230 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from upscaler_args import args
|
||||
from amdshark.amdshark_importer import import_with_fx
|
||||
from amdshark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
|
||||
|
||||
def _compile_module(amdshark_module, model_name, extra_args=None):
    """Compile ``amdshark_module`` or load it from a ``.vmfb`` file.

    When neither ``args.load_vmfb`` nor ``args.save_vmfb`` is set the module
    is compiled in place. Otherwise the flatbuffer
    ``<cwd>/<model_name>_<device>.vmfb`` is loaded if it already exists (and
    ``args.save_vmfb`` is off), or saved via ``save_module`` and then loaded.

    Args:
        amdshark_module: module exposing ``compile``, ``save_module`` and
            ``load_module``.
        model_name: base name used for the ``.vmfb`` file.
        extra_args: extra compile arguments; ``None`` means no extra
            arguments (a fresh list per call rather than a shared default).

    Returns:
        The same ``amdshark_module`` that was passed in.
    """
    extra_args = [] if extra_args is None else extra_args

    if not (args.load_vmfb or args.save_vmfb):
        amdshark_module.compile(extra_args)
        return amdshark_module

    # "driver://path" is not usable in a file name; join the parts with "-".
    device = args.device.replace("://", "-")
    extended_name = "{}_{}".format(model_name, device)
    vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")

    if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
        print(f"loading existing vmfb from: {vmfb_path}")
        amdshark_module.load_module(vmfb_path, extra_args=extra_args)
        return amdshark_module

    if args.save_vmfb:
        print("Saving to {}".format(vmfb_path))
    else:
        print("No vmfb found. Compiling and saving to {}".format(vmfb_path))
    path = amdshark_module.save_module(os.getcwd(), extended_name, extra_args)
    amdshark_module.load_module(path, extra_args=extra_args)
    return amdshark_module
|
||||
|
||||
|
||||
# Downloads the model from amdshark_tank and returns the amdshark_module.
|
||||
def get_amdshark_model(tank_url, model_name, extra_args=None):
    """Download ``model_name`` from the tank and return it as a compiled module.

    Args:
        tank_url: forwarded to ``download_model`` as ``tank_url``.
        model_name: model to download; also names the cached ``.vmfb``.
        extra_args: extra compile arguments; ``None`` means no extra
            arguments (a fresh list per call rather than a shared default).

    Returns:
        The module produced by ``_compile_module``.
    """
    from amdshark.amdshark_downloader import download_model

    # NOTE(review): imported but unused here; kept in case the import itself
    # is needed for its side effects — confirm before removing.
    from amdshark.parser import amdshark_args

    # Set local amdshark_tank cache directory.
    # amdshark_args.local_tank_cache = args.local_tank_cache

    extra_args = [] if extra_args is None else extra_args

    # Only the MLIR module is needed; the other three values are ignored.
    mlir_model, _func_name, _inputs, _golden_out = download_model(
        model_name,
        tank_url=tank_url,
        frontend="torch",
    )
    amdshark_module = AMDSharkInference(
        mlir_model, device=args.device, mlir_dialect="linalg"
    )
    return _compile_module(amdshark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into a amdshark_module.
|
||||
def compile_through_fx(
    model, inputs, model_name, is_f16=False, f16_input_mask=None, extra_args=None
):
    """Import a torch module with ``import_with_fx`` and compile it.

    Args:
        model: the torch module to import.
        inputs: example inputs forwarded to ``import_with_fx``.
        model_name: name used for the cached ``.vmfb``.
        is_f16: forwarded to ``import_with_fx``.
        f16_input_mask: forwarded to ``import_with_fx``.
        extra_args: extra compile arguments; ``None`` means no extra
            arguments (a fresh list per call rather than a shared default).

    Returns:
        The module produced by ``_compile_module``.
    """
    extra_args = [] if extra_args is None else extra_args

    mlir_module, _func_name = import_with_fx(
        model, inputs, is_f16, f16_input_mask
    )
    amdshark_module = AMDSharkInference(
        mlir_module,
        device=args.device,
        mlir_dialect="linalg",
    )
    return _compile_module(amdshark_module, model_name, extra_args)
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
    """Apply the Vulkan runtime flags, adding RGP flags when requested.

    ``enable_rgp`` is read with a default of False so the function also works
    with argument parsers that do not declare that option.
    """
    vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
    if getattr(args, "enable_rgp", False):
        vulkan_runtime_flags += [
            "--enable_rgp=true",
            "--vulkan_debug_utils=true",
        ]
    set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
    """Return the devices available for ``driver_name``, sorted by path.

    The list comes from the IREE driver's ``query_available_devices`` and is
    sorted in place on each entry's ``"path"`` value.
    """
    from iree.runtime import get_driver

    devices = get_driver(driver_name).query_available_devices()
    devices.sort(key=lambda entry: entry["path"])
    return devices
|
||||
|
||||
|
||||
def get_device_mapping(driver, key_combination=3):
    """Build a consistent mapping from user-facing device names to devices.

    Args:
        driver (str): execution driver (vulkan, cuda, rocm, etc)
        key_combination (int, optional): choice of mapped value.
            1 : "<driver>://<path>"
            2 : device name
            3 : (device name, "<driver>://<path>")
            Defaults to 3. Any other value maps every key to None.

    Returns:
        dict: keys are "<driver>" (the first device), "<driver>://<index>"
        and "<driver>://<path>". The dict is empty when the driver reports
        no devices.
    """
    from amdshark.iree_utils._common import iree_device_map

    driver = iree_device_map(driver)
    device_list = get_all_devices(driver)
    device_map = dict()

    def get_output_value(dev_dict):
        if key_combination == 1:
            return f"{driver}://{dev_dict['path']}"
        if key_combination == 2:
            return dev_dict["name"]
        if key_combination == 3:
            return (dev_dict["name"], f"{driver}://{dev_dict['path']}")

    # mapping driver name to default device (driver://0); skipped when there
    # are no devices so an empty list yields an empty map, not an IndexError
    if device_list:
        device_map[f"{driver}"] = get_output_value(device_list[0])
    for i, device in enumerate(device_list):
        # mapping with index
        device_map[f"{driver}://{i}"] = get_output_value(device)
        # mapping with full path
        device_map[f"{driver}://{device['path']}"] = get_output_value(device)
    return device_map
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
    """Look up the name/path data for a user-selected execution device.

    Args:
        device (str): device string chosen by the user; the text before
            "://" is used as the driver.
        key_combination (int, optional): forwarded to ``get_device_mapping``
            to select the mapped value (1: path, 2: name, 3: (name, path)).
            Defaults to 3.

    Raises:
        ValueError: if ``device`` is not a key of the driver's device map.

    Returns:
        str / tuple: the mapped value for ``device``, shaped according to
        ``key_combination``.
    """
    driver_name = device.split("://")[0]
    known_devices = get_device_mapping(driver_name, key_combination)
    try:
        return known_devices[device]
    except KeyError:
        raise ValueError(f"Device '{device}' is not a valid device.")
|
||||
|
||||
|
||||
def set_init_device_flags():
    """Normalise ``args.device`` and derive the dependent settings.

    For Vulkan devices the runtime flags are set, ``args.device`` is replaced
    by the resolved device path and the target triple is filled in when not
    given. ``max_length`` and ``use_tuned`` are then adjusted from the model
    variant, precision and device.

    ``variant``, ``use_base_vae`` and ``use_tuned`` are read with defaults so
    the function also works with parsers that do not declare them.
    """
    if "vulkan" in args.device:
        # set runtime flags for vulkan.
        set_iree_runtime_flags()

        # set triple flag to avoid multiple calls to get_vulkan_triple_flag
        device_name, args.device = map_device_to_name_path(args.device)
        if not args.iree_vulkan_target_triple:
            triple = get_vulkan_target_triple(device_name)
            if triple is not None:
                args.iree_vulkan_target_triple = triple
        print(
            f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
        )
    elif "cuda" in args.device:
        args.device = "cuda"
    elif "cpu" in args.device:
        args.device = "cpu"

    variant = getattr(args, "variant", None)
    use_base_vae = getattr(args, "use_base_vae", False)

    # set max_length based on availability.
    if variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
        args.max_length = 77
    elif variant == "openjourney":
        args.max_length = 64

    # use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
    if (
        variant in ["openjourney", "dreamlike"]
        or args.precision != "fp16"
        or "vulkan" not in args.device
        or "rdna3" not in args.iree_vulkan_target_triple
    ):
        args.use_tuned = False
        print("Tuned models are currently not supported for this setting.")
    elif use_base_vae and variant != "stablediffusion":
        args.use_tuned = False
        print("Tuned models are currently not supported for this setting.")

    if getattr(args, "use_tuned", False):
        print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
|
||||
|
||||
|
||||
# Utility to get list of devices available.
|
||||
def get_available_devices():
    """Return display strings for the available Vulkan and CUDA devices.

    Each entry has the form "<driver>://<index> => <device name>"; the list
    always ends with "cpu". Vulkan runtime flags are set before querying.
    """

    def get_devices_by_name(driver_name):
        """List the devices of one driver, or [] when it cannot be queried."""
        from amdshark.iree_utils._common import iree_device_map

        device_list = []
        try:
            driver_name = iree_device_map(driver_name)
            device_list_dict = get_all_devices(driver_name)
            print(f"{driver_name} devices are available.")
        except Exception:
            # A missing or unusable driver is expected; unlike a bare
            # ``except:``, Ctrl-C and SystemExit still propagate.
            print(f"{driver_name} devices are not available.")
        else:
            for i, device in enumerate(device_list_dict):
                device_list.append(f"{driver_name}://{i} => {device['name']}")
        return device_list

    set_iree_runtime_flags()

    available_devices = []
    vulkan_devices = get_devices_by_name("vulkan")
    available_devices.extend(vulkan_devices)
    cuda_devices = get_devices_by_name("cuda")
    available_devices.extend(cuda_devices)
    available_devices.append("cpu")
    return available_devices
|
||||
@@ -1,15 +0,0 @@
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.amdshark_downloader import download_model
|
||||
|
||||
|
||||
# Download the "v_diffusion" torch model together with its sample inputs and
# reference ("golden") outputs.
mlir_model, func_name, inputs, golden_out = download_model(
    "v_diffusion",
    frontend="torch",
)

# Compile for Vulkan and run one forward pass on the sample inputs.
amdshark_module = AMDSharkInference(
    mlir_model,
    device="vulkan",
    mlir_dialect="linalg",
)
amdshark_module.compile()
result = amdshark_module.forward(inputs)

# Print both results so they can be compared by eye.
print("The obtained result via amdshark is: ", result)
print("The golden result is:", golden_out)
|
||||
@@ -1,60 +0,0 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
import tensorflow as tf
|
||||
|
||||
from amdshark.amdshark_trainer import AMDSharkTrainer
|
||||
from amdshark.parser import parser
|
||||
from urllib import request
|
||||
|
||||
# Extra option on top of the shared amdshark parser: where the downloaded
# MLIR file is stored and loaded from.
parser.add_argument(
    "--download_mlir_path",
    help="Specifies path to target mlir file that will be loaded.",
    default="bert_tf_training.mlir",
    type=str,
)
load_args, unknown = parser.parse_known_args()

tf.random.set_seed(0)
vocab_size = 100
NUM_CLASSES = 5
SEQUENCE_LENGTH = 512
BATCH_SIZE = 1

# Download BERT model from tank and train.
if __name__ == "__main__":
    # Three random integer inputs of shape (BATCH_SIZE, SEQUENCE_LENGTH).
    predict_sample_input = [
        np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
        for _ in range(3)
    ]
    file_link = "https://storage.googleapis.com/amdshark_tank/users/stanley/bert_tf_training.mlir"
    response = request.urlretrieve(file_link, load_args.download_mlir_path)
    sample_input_tensors = [
        tf.convert_to_tensor(val, dtype=tf.int32)
        for val in predict_sample_input
    ]
    num_iter = 10

    if not os.path.isfile(load_args.download_mlir_path):
        raise ValueError(
            f"Tried looking for target mlir in {load_args.download_mlir_path}, but cannot be found."
        )
    with open(load_args.download_mlir_path, "rb") as input_file:
        bert_mlir = input_file.read()

    # The trainer receives the MLIR bytes plus (inputs, labels).
    amdshark_module = AMDSharkTrainer(
        bert_mlir,
        (
            sample_input_tensors,
            tf.convert_to_tensor(
                np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32
            ),
        ),
    )
    amdshark_module.set_frontend("mhlo")
    amdshark_module.compile()

    # Time the whole training run and report the per-iteration average.
    start = time.time()
    print(amdshark_module.train(num_iter))
    end = time.time()
    total_time = end - start
    print(f"time: {total_time}")
    print(f"time/iter: {total_time / num_iter}")
|
||||
@@ -1,41 +0,0 @@
|
||||
# Stable Diffusion Img2Img model
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary>Installation (Linux)</summary>
|
||||
|
||||
### Activate amdshark.venv Virtual Environment
|
||||
|
||||
```shell
|
||||
source amdshark.venv/bin/activate
|
||||
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
python -m pip install --upgrade pip
|
||||
```
|
||||
|
||||
### Install dependencies
|
||||
|
||||
Run the `setup.sh` script:
|
||||
|
||||
```shell
|
||||
./setup.sh
|
||||
```
|
||||
|
||||
### Run the Stable diffusion Img2Img model
|
||||
|
||||
To run the model with the default set of images and params, run:
|
||||
```shell
|
||||
python stable_diffusion_img2img.py
|
||||
```
|
||||
To run the model with your set of images, and parameters you need to specify the following params:
|
||||
1.) Input images directory with the arg `--input_dir` containing 3-5 images.
|
||||
2.) What to teach the model? Using the arg `--what_to_teach`, allowed values are `object` or `style`.
|
||||
3.) Placeholder token using the arg `--placeholder_token`, which represents your new concept. It should be passed with opening and closing angle brackets. For example, if the token is `cat-toy`, pass it as `<cat-toy>`.
|
||||
4.) Initializer token using the arg `--initializer_token`: a word that summarizes your new concept.
|
||||
|
||||
For the result, you need to pass the text prompt with the arg: `--prompt`. The prompt string should contain a "*s" in it, which will be replaced by the placeholder token during the inference.
|
||||
|
||||
By default the result images will go into the `sd_result` dir. To specify your output dir use the arg: `--output_dir`.
|
||||
|
||||
The default number of training steps is `3000`, which takes several hours to complete. You can pass a smaller value with the arg `--training_steps`. Specify the number of images to be sampled for the result with the `--num_inference_samples` arg.
|
||||
@@ -1,25 +0,0 @@
|
||||
#!/bin/bash
# Downloads the sample training images into ./input_images and installs the
# Python packages needed by the example. Safe to re-run.

TD="$(cd "$(dirname "$0")" && pwd)"
if [ -z "$PYTHON" ]; then
  PYTHON="$(which python3)"
fi

# Print the failing command description and abort.
function die() {
  echo "Error executing command: $*"
  exit 1
}

PYTHON_VERSION_X_Y="$("${PYTHON}" -c 'import sys; version=sys.version_info[:2]; print("{0}.{1}".format(*version))')" \
  || die "${PYTHON} version query"

echo "Python: $PYTHON"
echo "Python version: $PYTHON_VERSION_X_Y"

# -p: do not fail when the directory is left over from a previous run.
mkdir -p input_images || die "mkdir input_images"

# -nc: skip files that already exist instead of saving duplicates such as
# 2.jpeg.1, which would end up in the training set.
for n in 2 3 5 6; do
  wget -nc "https://huggingface.co/datasets/valhalla/images/resolve/main/${n}.jpeg" -P input_images/ \
    || die "wget ${n}.jpeg"
done

# Quote the requirement so the shell never treats [training] as a glob, and
# install with the interpreter reported above.
"${PYTHON}" -m pip install "diffusers[training]==0.4.1" transformers ftfy opencv-python \
  || die "pip install"
|
||||
@@ -1,600 +0,0 @@
|
||||
# Textual-inversion fine-tuning for Stable Diffusion using diffusers
|
||||
# This script shows how to "teach" Stable Diffusion a new concept via
|
||||
# textual-inversion using 🤗 Hugging Face [🧨 Diffusers library](https://github.com/huggingface/diffusers).
|
||||
# By using just 3-5 images you can teach new concepts to Stable Diffusion
|
||||
# and personalize the model on your own images.
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.hub_utils import init_git_repo, push_to_hub
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# SECURITY: a Hugging Face access token used to be hard-coded on this line.
# Credentials must not live in source control; the token is now read from the
# HF_TOKEN environment variable (empty when unset). The previously committed
# token should be treated as leaked and revoked.
YOUR_TOKEN = os.environ.get("HF_TOKEN", "")

# Command-line options for the fine-tuning and inference run.
p = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
p.add_argument(
    "--input_dir",
    type=str,
    default="input_images/",
    help="the directory contains the images used for fine tuning",
)
p.add_argument(
    "--output_dir",
    type=str,
    default="sd_result",
    help="the directory where the result images are written",
)
p.add_argument(
    "--training_steps",
    type=int,
    default=3000,
    help="the maximum number of training steps",
)
p.add_argument("--seed", type=int, default=42, help="the random seed")
p.add_argument(
    "--what_to_teach",
    type=str,
    choices=["object", "style"],
    default="object",
    help="what is it that you are teaching?",
)
p.add_argument(
    "--placeholder_token",
    type=str,
    default="<cat-toy>",
    help="It is the token you are going to use to represent your new concept",
)
p.add_argument(
    "--initializer_token",
    type=str,
    default="toy",
    help="It is a word that can summarise what is your new concept",
)
p.add_argument(
    "--inference_steps",
    type=int,
    default=50,
    help="the number of steps for inference",
)
p.add_argument(
    "--num_inference_samples",
    type=int,
    default=4,
    help="the number of samples for inference",
)
p.add_argument(
    "--prompt",
    type=str,
    default="a grafitti in a wall with a *s on it",
    help="the text prompt to use",
)
args = p.parse_args()

if "*s" not in args.prompt:
    raise ValueError(
        'The prompt should have a "*s" which will be replaced by a placeholder token.'
    )

# Text before/after the first "*s"; kept as module-level names in case code
# further down the file reads them.
prompt1, _, prompt2 = args.prompt.partition("*s")
# Replace every "*s": unpacking split("*s") into two names raised ValueError
# as soon as the prompt contained the marker more than once.
args.prompt = args.prompt.replace("*s", args.placeholder_token)

pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"

# Load input images; files that cv2 cannot read (imread returns None) are skipped.
images = []
for filename in os.listdir(args.input_dir):
    img = cv2.imread(os.path.join(args.input_dir, filename))
    if img is not None:
        images.append(img)
|
||||
|
||||
# Setup the prompt templates for training
|
||||
# Caption templates used when teaching an object; "{}" is filled in with the
# placeholder token via str.format.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Caption templates used when teaching a style.
# NOTE(review): "a cropped painting in the style of {}" and "a close-up
# painting in the style of {}" each appear twice, so a uniform random pick
# selects them twice as often — confirm this is intended.
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
|
||||
|
||||
|
||||
# Setup the dataset
|
||||
class TextualInversionDataset(Dataset):
    """Dataset pairing concept images with randomly templated prompts.

    Every regular file found in ``data_root`` is treated as a training image.
    Each item is a dict holding ``input_ids`` (the tokenized prompt) and
    ``pixel_values`` (a float32, channels-first tensor scaled to [-1, 1]).
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        """Index the images under ``data_root`` and store preprocessing options.

        Args:
            data_root: Directory containing the training images.
            tokenizer: Callable tokenizer exposing ``model_max_length``.
            learnable_property: ``"style"`` selects the style templates, any
                other value selects the object templates.
            size: Edge length (pixels) of the square output image.
            repeats: How many times each image is repeated when ``set`` is
                ``"train"``.
            interpolation: One of ``"linear"``, ``"bilinear"``, ``"bicubic"``
                or ``"lanczos"``.
            flip_p: Probability of a random horizontal flip.
            set: Dataset split name; only ``"train"`` applies ``repeats``.
            placeholder_token: Token substituted into the prompt templates.
            center_crop: Whether to crop to a centered square before resizing.

        Raises:
            KeyError: If ``interpolation`` is not a supported name.
        """
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        # Sort for a reproducible index -> image mapping (os.listdir order is
        # arbitrary) and skip sub-directories, which can never be opened as
        # images.
        candidates = (
            os.path.join(self.data_root, name)
            for name in sorted(os.listdir(self.data_root))
        )
        self.image_paths = [path for path in candidates if os.path.isfile(path)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        # ``PIL.Image.LINEAR`` was only an alias of ``BILINEAR`` and was
        # removed in Pillow 10; referencing it made this table fail to build
        # regardless of the interpolation actually requested.
        self.interpolation = {
            "linear": PIL.Image.BILINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]

        self.templates = (
            imagenet_style_templates_small
            if learnable_property == "style"
            else imagenet_templates_small
        )
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        """Return the number of items, including training repeats."""
        return self._length

    def __getitem__(self, i):
        """Return the tokenized prompt and preprocessed image for index ``i``.

        Indices wrap around the image list, so repeated items reuse the same
        files with a freshly sampled prompt template and random flip.
        """
        example = {}

        # Use a context manager so the underlying file handle is released;
        # ``convert`` returns a fully loaded copy that outlives the file.
        with Image.open(self.image_paths[i % self.num_images]) as source:
            image = source.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            # Crop the largest centered square out of the image.
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[
                (h - crop) // 2 : (h + crop) // 2,
                (w - crop) // 2 : (w + crop) // 2,
            ]

        image = Image.fromarray(img)
        image = image.resize(
            (self.size, self.size), resample=self.interpolation
        )

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example
|
||||
|
||||
|
||||
# Setting up the model
|
||||
# Load the tokenizer and add the placeholder token as an additional special token.
|
||||
# Please read and if you agree accept the LICENSE
|
||||
# [here](https://huggingface.co/CompVis/stable-diffusion-v1-4) if you see an error
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Get token ids for our placeholder and initializer token.
|
||||
# This code block will complain if initializer string is not a single token
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# We have added the `placeholder_token` in the `tokenizer` so we resize the token embeddings here,
|
||||
# this will add a new embedding vector in the token embeddings for our `placeholder_token`
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
# In Textual-Inversion we only train the newly added embedding vector,
|
||||
# so lets freeze rest of the model parameters here.
|
||||
|
||||
|
||||
def freeze_params(params):
    """Turn off gradient tracking for every parameter yielded by ``params``."""
    for weight in params:
        weight.requires_grad = False
|
||||
|
||||
|
||||
# Freeze vae and unet
|
||||
freeze_params(vae.parameters())
|
||||
freeze_params(unet.parameters())
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
params_to_freeze = itertools.chain(
|
||||
text_encoder.text_model.encoder.parameters(),
|
||||
text_encoder.text_model.final_layer_norm.parameters(),
|
||||
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
||||
)
|
||||
freeze_params(params_to_freeze)
|
||||
|
||||
# Creating our training data
|
||||
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=args.input_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=512,
|
||||
placeholder_token=args.placeholder_token,
|
||||
repeats=100,
|
||||
learnable_property=args.what_to_teach, # Option selected above between object and style
|
||||
center_crop=False,
|
||||
set="train",
|
||||
)
|
||||
|
||||
|
||||
def create_dataloader(train_batch_size=1):
    """Build a shuffling ``DataLoader`` over the module-level ``train_dataset``."""
    loader_options = {"batch_size": train_batch_size, "shuffle": True}
    return torch.utils.data.DataLoader(train_dataset, **loader_options)
|
||||
|
||||
|
||||
# Create noise_scheduler for training.
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
tensor_format="pt",
|
||||
)
|
||||
|
||||
# Define hyperparameters for our training
|
||||
hyperparameters = {
|
||||
"learning_rate": 5e-04,
|
||||
"scale_lr": True,
|
||||
"max_train_steps": args.training_steps,
|
||||
"train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"seed": args.seed,
|
||||
"output_dir": "sd-concept-output",
|
||||
}
|
||||
|
||||
|
||||
def training_function(text_encoder, vae, unet):
    """Train the placeholder-token embedding and save the resulting pipeline.

    Only the input-embedding row at ``placeholder_token_id`` is meant to be
    learned; all other rows are restored after every optimizer step. The
    function reads the module-level ``hyperparameters``, ``tokenizer``,
    ``noise_scheduler``, ``train_dataset``, ``placeholder_token_id`` and
    ``args``.

    Args:
        text_encoder: CLIP text model whose input embeddings are optimized.
        vae: Autoencoder used to encode images into latents (not trained).
        unet: Denoising UNet used to predict the noise (not trained).
    """
    logger = get_logger(__name__)

    train_batch_size = hyperparameters["train_batch_size"]
    gradient_accumulation_steps = hyperparameters[
        "gradient_accumulation_steps"
    ]
    learning_rate = hyperparameters["learning_rate"]
    max_train_steps = hyperparameters["max_train_steps"]
    output_dir = hyperparameters["output_dir"]

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
    )

    train_dataloader = create_dataloader(train_batch_size)

    if hyperparameters["scale_lr"]:
        learning_rate = (
            learning_rate
            * gradient_accumulation_steps
            * train_batch_size
            * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=learning_rate,
    )

    text_encoder, optimizer, train_dataloader = accelerator.prepare(
        text_encoder, optimizer, train_dataloader
    )

    # Move vae and unet to device
    vae.to(accelerator.device)
    unet.to(accelerator.device)

    # Keep vae and unet in eval mode as we don't train these
    vae.eval()
    unet.eval()

    # Resolve the (possibly wrapped) embedding layer once; unwrap_model works
    # for both single- and multi-process runs.
    embedding_layer = accelerator.unwrap_model(
        text_encoder
    ).get_input_embeddings()
    # Rows that must stay untouched: every token except the placeholder.
    frozen_rows = (
        torch.arange(len(tokenizer), device=embedding_layer.weight.device)
        != placeholder_token_id
    )
    # Snapshot used to undo any drift of the frozen rows. Zeroing their
    # gradients is not sufficient because AdamW applies decoupled weight
    # decay to the whole embedding matrix on every step.
    original_embeds = embedding_layer.weight.data.clone()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = (
        train_batch_size
        * accelerator.num_processes
        * gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Instantaneous batch size per device = {train_batch_size}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f" Gradient Accumulation steps = {gradient_accumulation_steps}"
    )
    logger.info(f" Total optimization steps = {max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(max_train_steps), disable=not accelerator.is_local_main_process
    )
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = (
                    vae.encode(batch["pixel_values"])
                    .latent_dist.sample()
                    .detach()
                )
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn(latents.shape).to(latents.device)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    noise_scheduler.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(
                    latents, noise, timesteps
                )

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(
                    noisy_latents, timesteps, encoder_hidden_states
                ).sample

                loss = (
                    F.mse_loss(noise_pred, noise, reduction="none")
                    .mean([1, 2, 3])
                    .mean()
                )
                accelerator.backward(loss)

                # Zero out the gradients for all token embeddings except the newly added
                # embeddings for the concept, as we only want to optimize the concept embeddings
                grads = embedding_layer.weight.grad
                grads.data[frozen_rows, :] = 0

                optimizer.step()
                optimizer.zero_grad()

                # Undo weight-decay drift on the frozen rows, but only when a
                # real optimizer step happened (not on accumulation steps).
                if accelerator.sync_gradients:
                    with torch.no_grad():
                        embedding_layer.weight[frozen_rows] = original_embeds[
                            frozen_rows
                        ]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline(
            text_encoder=accelerator.unwrap_model(text_encoder),
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            scheduler=PNDMScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                skip_prk_steps=True,
            ),
            safety_checker=StableDiffusionSafetyChecker.from_pretrained(
                "CompVis/stable-diffusion-safety-checker"
            ),
            feature_extractor=CLIPFeatureExtractor.from_pretrained(
                "openai/clip-vit-base-patch32"
            ),
        )
        pipeline.save_pretrained(output_dir)
        # Also save the newly trained embeddings
        learned_embeds = (
            accelerator.unwrap_model(text_encoder)
            .get_input_embeddings()
            .weight[placeholder_token_id]
        )
        learned_embeds_dict = {
            args.placeholder_token: learned_embeds.detach().cpu()
        }
        torch.save(
            learned_embeds_dict, os.path.join(output_dir, "learned_embeds.bin")
        )
|
||||
|
||||
|
||||
import accelerate
|
||||
|
||||
accelerate.notebook_launcher(
|
||||
training_function, args=(text_encoder, vae, unet), num_processes=1
|
||||
)
|
||||
|
||||
# Set up the pipeline
|
||||
# Reload the fine-tuned pipeline that training wrote to disk.
pipe = StableDiffusionPipeline.from_pretrained(
    hyperparameters["output_dir"],
    # torch_dtype=torch.float16,
)

# Generate the requested number of samples, one pipeline call per sample.
all_images = []
for _ in range(args.num_inference_samples):
    images = pipe(
        [args.prompt],
        num_inference_steps=args.inference_steps,
        guidance_scale=7.5,
    ).images
    all_images.extend(images)

# makedirs also creates missing parent directories and tolerates an
# already-existing target, unlike the previous isdir()/mkdir() pair.
os.makedirs(args.output_dir, exist_ok=True)

# Plain loop: saving is a side effect, so no throwaway list is built.
for i, image in enumerate(all_images):
    image.save(f"{args.output_dir}/{i}.jpeg")
|
||||
@@ -1,43 +0,0 @@
|
||||
# Stable Diffusion Fine Tuning
|
||||
|
||||
## Installation (Linux)
|
||||
|
||||
### Activate amdshark.venv Virtual Environment
|
||||
|
||||
```shell
|
||||
source amdshark.venv/bin/activate
|
||||
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
python -m pip install --upgrade pip
|
||||
```
|
||||
|
||||
## Install dependencies
|
||||
|
||||
### Run the following installation commands:
|
||||
```
|
||||
pip install -U git+https://github.com/huggingface/diffusers.git
|
||||
pip install accelerate transformers ftfy
|
||||
```
|
||||
|
||||
### Build torch-mlir with the following branch:
|
||||
|
||||
Please cherry-pick this branch of torch-mlir: https://github.com/vivekkhandelwal1/torch-mlir/tree/sd-ops
|
||||
and build it locally. You can find the instructions for using a locally built Torch-MLIR
|
||||
here: https://github.com/nod-ai/AMD-SHARK-Studio#how-to-use-your-locally-built-iree--torch-mlir-with-amdshark
|
||||
|
||||
## Run the Stable diffusion fine tuning
|
||||
|
||||
To run the model with the default set of images and params, run:
|
||||
```shell
|
||||
python stable_diffusion_fine_tuning.py
|
||||
```
|
||||
By default the training is run through the PyTorch path. If you want to train the model using the Torchdynamo path of Torch-MLIR, you need to specify `--use_torchdynamo=True`.
|
||||
|
||||
The default number of training steps is `2000`, which can take many hours to complete depending on your system configuration. You can pass a smaller value with the `--training_steps` argument. You can specify the number of images to be sampled for the result with the `--num_inference_samples` argument, and the number of inference steps with the `--inference_steps` flag.
|
||||
|
||||
For example, you can run the training for a limited set of steps via the dynamo path by using the following command:
|
||||
```
|
||||
python stable_diffusion_fine_tuning.py --training_steps=1 --inference_steps=1 --num_inference_samples=1 --train_batch_size=1 --use_torchdynamo=True
|
||||
```
|
||||
|
||||
You can also specify the device to be used via the `--device` flag. The default value is `cpu`; for GPU execution, specify `--device="cuda"`.
|
||||
@@ -1,914 +0,0 @@
|
||||
# Install the required libs
|
||||
# pip install -U git+https://github.com/huggingface/diffusers.git
|
||||
# pip install accelerate transformers ftfy
|
||||
|
||||
# Import required libraries
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
import logging
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
import torch._dynamo as dynamo
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
|
||||
torch._dynamo.config.verbose = True
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import (
|
||||
CLIPFeatureExtractor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
|
||||
# Enter your HuggingFace Token
|
||||
# Note: You can comment this prompt and just set your token instead of passing it through cli for every execution.
|
||||
hf_token = input("Please enter your huggingface token here: ")
|
||||
YOUR_TOKEN = hf_token
|
||||
|
||||
|
||||
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row into a single ``rows`` x ``cols`` RGB image.

    Every cell has the size of the first image.

    Args:
        imgs: Non-empty sequence of PIL images.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image containing all input images.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks instead of ``assert`` so validation survives ``-O`` and
    # an empty list fails with a clear message rather than an IndexError.
    if not imgs:
        raise ValueError("image_grid requires at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, "
            f"got {len(imgs)}"
        )

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # Column index selects x, row index selects y.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
|
||||
|
||||
|
||||
# `pretrained_model_name_or_path` which Stable Diffusion checkpoint you want to use
|
||||
# Options: 1.) "stabilityai/stable-diffusion-2"
|
||||
# 2.) "stabilityai/stable-diffusion-2-base"
|
||||
# 3.) "CompVis/stable-diffusion-v1-4"
|
||||
# 4.) "runwayml/stable-diffusion-v1-5"
|
||||
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2"
|
||||
|
||||
# Add here the URLs to the images of the concept you are adding. 3-5 should be fine
|
||||
urls = [
|
||||
"https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
|
||||
"https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
|
||||
"https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
|
||||
"https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg",
|
||||
## You can add additional images here
|
||||
]
|
||||
|
||||
# Downloading Images
|
||||
import requests
|
||||
import glob
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def download_image(url, timeout=30):
    """Download ``url`` and return it as an RGB PIL image, or ``None`` on failure.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded image converted to RGB, or ``None`` when the request
        fails, the server answers with an error status, or the payload cannot
        be decoded as an image.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Treat HTTP error pages as failures instead of trying to decode them.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError) as err:
        # OSError covers PIL's UnidentifiedImageError for non-image payloads.
        print(f"Skipping {url}: {err}")
        return None
|
||||
|
||||
|
||||
# Fetch the concept images, dropping any download that failed (None).
images = [image for image in map(download_image, urls) if image is not None]
save_path = "./my_concept"
# exist_ok avoids the check-then-create race of exists()/mkdir().
os.makedirs(save_path, exist_ok=True)
# Plain loop: saving is a side effect, so no throwaway list is built.
for i, image in enumerate(images):
    image.save(f"{save_path}/{i}.jpeg")
|
||||
|
||||
def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``
    because any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "on", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "off", "0"}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}"
    )


# Command-line interface of the fine-tuning script.
p = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
p.add_argument(
    "--input_dir",
    type=str,
    default="my_concept/",
    help="the directory contains the images used for fine tuning",
)
p.add_argument(
    "--output_dir",
    type=str,
    default="sd_result",
    help="the directory where the generated images are saved",
)
p.add_argument(
    "--training_steps",
    type=int,
    default=2000,
    help="the maximum number of training steps",
)
p.add_argument(
    "--train_batch_size",
    type=int,
    default=4,
    help="The batch size for training",
)
p.add_argument(
    "--save_steps",
    type=int,
    default=250,
    help="the number of steps after which to save the learned concept",
)
p.add_argument("--seed", type=int, default=42, help="the random seed")
p.add_argument(
    "--what_to_teach",
    type=str,
    choices=["object", "style"],
    default="object",
    help="what is it that you are teaching?",
)
p.add_argument(
    "--placeholder_token",
    type=str,
    default="<cat-toy>",
    help="It is the token you are going to use to represent your new concept",
)
p.add_argument(
    "--initializer_token",
    type=str,
    default="toy",
    help="It is a word that can summarise what is your new concept",
)
p.add_argument(
    "--inference_steps",
    type=int,
    default=50,
    help="the number of steps for inference",
)
p.add_argument(
    "--num_inference_samples",
    type=int,
    default=4,
    help="the number of samples for inference",
)
p.add_argument(
    "--prompt",
    type=str,
    default="a grafitti in a wall with a *s on it",
    help="the text prompt to use",
)
p.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="The device to use",
)
p.add_argument(
    "--use_torchdynamo",
    # Explicit parser so that ``--use_torchdynamo=False`` really disables it.
    type=_str_to_bool,
    default=False,
    help="This flag is used to determine whether the training has to be done through the torchdynamo path or not.",
)
|
||||
args = p.parse_args()
|
||||
torch.manual_seed(args.seed)

# The prompt must mark where the learned concept goes with "*s".
if "*s" not in args.prompt:
    raise ValueError(
        'The prompt should have a "*s" which will be replaced by a placeholder token.'
    )

# Substitute every "*s" marker. The previous two-way unpacking of
# ``split("*s")`` raised "too many values to unpack" as soon as the prompt
# mentioned the concept more than once.
args.prompt = args.prompt.replace("*s", args.placeholder_token)
|
||||
|
||||
# `images_path` is a path to directory containing the training images.
|
||||
images_path = args.input_dir
|
||||
while not os.path.exists(str(images_path)):
|
||||
print(
|
||||
"The images_path specified does not exist, use the colab file explorer to copy the path :"
|
||||
)
|
||||
images_path = input("")
|
||||
save_path = images_path
|
||||
|
||||
# Setup and check the images you have just added
|
||||
# Load every file in ``save_path`` as a 512x512 preview, reporting the ones
# that cannot be read as images.
images = []
for file_path in os.listdir(save_path):
    image_path = os.path.join(save_path, file_path)
    try:
        # The context manager releases the file handle; ``resize`` returns an
        # independent, fully loaded image.
        with Image.open(image_path) as opened:
            images.append(opened.resize((512, 512)))
    except Exception:
        # Deliberately broad best-effort check, but no longer a bare
        # ``except`` that would also swallow KeyboardInterrupt/SystemExit.
        print(
            f"{image_path} is not a valid image, please make sure to remove this file from the directory otherwise the training could fail."
        )
image_grid(images, 1, len(images))
|
||||
|
||||
########### Create Dataset ##########
|
||||
|
||||
# Setup the prompt templates for training
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
# Setup the dataset
|
||||
class TextualInversionDataset(Dataset):
    """Dataset pairing concept images with randomly templated prompts.

    Every regular file found in ``data_root`` is treated as a training image.
    Each item is a dict holding ``input_ids`` (the tokenized prompt) and
    ``pixel_values`` (a float32, channels-first tensor scaled to [-1, 1]).
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        """Index the images under ``data_root`` and store preprocessing options.

        Args:
            data_root: Directory containing the training images.
            tokenizer: Callable tokenizer exposing ``model_max_length``.
            learnable_property: ``"style"`` selects the style templates, any
                other value selects the object templates.
            size: Edge length (pixels) of the square output image.
            repeats: How many times each image is repeated when ``set`` is
                ``"train"``.
            interpolation: One of ``"linear"``, ``"bilinear"``, ``"bicubic"``
                or ``"lanczos"``.
            flip_p: Probability of a random horizontal flip.
            set: Dataset split name; only ``"train"`` applies ``repeats``.
            placeholder_token: Token substituted into the prompt templates.
            center_crop: Whether to crop to a centered square before resizing.

        Raises:
            KeyError: If ``interpolation`` is not a supported name.
        """
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        # Sort for a reproducible index -> image mapping (os.listdir order is
        # arbitrary) and skip sub-directories, which can never be opened as
        # images.
        candidates = (
            os.path.join(self.data_root, name)
            for name in sorted(os.listdir(self.data_root))
        )
        self.image_paths = [path for path in candidates if os.path.isfile(path)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        # ``PIL.Image.LINEAR`` was only an alias of ``BILINEAR`` and was
        # removed in Pillow 10; referencing it made this table fail to build
        # regardless of the interpolation actually requested.
        self.interpolation = {
            "linear": PIL.Image.BILINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]

        self.templates = (
            imagenet_style_templates_small
            if learnable_property == "style"
            else imagenet_templates_small
        )
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        """Return the number of items, including training repeats."""
        return self._length

    def __getitem__(self, i):
        """Return the tokenized prompt and preprocessed image for index ``i``.

        Indices wrap around the image list, so repeated items reuse the same
        files with a freshly sampled prompt template and random flip.
        """
        example = {}

        # Use a context manager so the underlying file handle is released;
        # ``convert`` returns a fully loaded copy that outlives the file.
        with Image.open(self.image_paths[i % self.num_images]) as source:
            image = source.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            # Crop the largest centered square out of the image.
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[
                (h - crop) // 2 : (h + crop) // 2,
                (w - crop) // 2 : (w + crop) // 2,
            ]

        image = Image.fromarray(img)
        image = image.resize(
            (self.size, self.size), resample=self.interpolation
        )

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example
|
||||
|
||||
|
||||
########## Setting up the model ##########
|
||||
|
||||
# Load the tokenizer and add the placeholder token as an additional special token.
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Get token ids for our placeholder and initializer token.
|
||||
# This code block will complain if initializer string is not a single token
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
# Load models and create wrapper for stable diffusion
|
||||
# pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path)
|
||||
# del pipeline
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="text_encoder"
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="vae"
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="unet"
|
||||
)
|
||||
|
||||
# We have added the placeholder_token in the tokenizer so we resize the token embeddings here
|
||||
# this will add a new embedding vector in the token embeddings for our placeholder_token
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
# In Textual-Inversion we only train the newly added embedding vector
|
||||
# so lets freeze rest of the model parameters here
|
||||
|
||||
|
||||
def freeze_params(params):
    """Turn off gradient tracking for every parameter yielded by ``params``."""
    for weight in params:
        weight.requires_grad = False
|
||||
|
||||
|
||||
# Freeze vae and unet
|
||||
freeze_params(vae.parameters())
|
||||
freeze_params(unet.parameters())
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
params_to_freeze = itertools.chain(
|
||||
text_encoder.text_model.encoder.parameters(),
|
||||
text_encoder.text_model.final_layer_norm.parameters(),
|
||||
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
||||
)
|
||||
freeze_params(params_to_freeze)
|
||||
|
||||
|
||||
# Move vae and unet to device
|
||||
# For the dynamo path default compilation device is `cpu`, since torch-mlir
|
||||
# supports only that. Therefore, convert to device only for PyTorch path.
|
||||
if not args.use_torchdynamo:
|
||||
vae.to(args.device)
|
||||
unet.to(args.device)
|
||||
|
||||
# Keep vae in eval mode as we don't train it
|
||||
vae.eval()
|
||||
# Keep unet in train mode to enable gradient checkpointing
|
||||
unet.train()
|
||||
|
||||
|
||||
class VaeModel(torch.nn.Module):
    """Module wrapper around the module-level ``vae`` exposing only its encoder."""

    def __init__(self):
        super().__init__()
        # Registered as a submodule so its parameters travel with the wrapper.
        self.vae = vae

    def forward(self, input):
        """Encode ``input`` and return the first element of the tuple output."""
        encoded = self.vae.encode(input, return_dict=False)
        return encoded[0]
|
||||
|
||||
|
||||
class UnetModel(torch.nn.Module):
    """``nn.Module`` wrapper whose forward pass runs the UNet with tuple output."""

    def __init__(self):
        super().__init__()
        # Wraps the module-level ``unet`` loaded earlier in the script.
        self.unet = unet

    def forward(self, x, y, z):
        """Forward ``x``, ``y``, ``z`` to the UNet and return element 0 of its output."""
        outputs = self.unet.forward(x, y, z, return_dict=False)
        return outputs[0]
|
||||
|
||||
|
||||
amdshark_vae = VaeModel()
|
||||
amdshark_unet = UnetModel()
|
||||
|
||||
####### Creating our training data ########
|
||||
|
||||
# Let's create the Dataset and Dataloader
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=save_path,
|
||||
tokenizer=tokenizer,
|
||||
size=vae.sample_size,
|
||||
placeholder_token=args.placeholder_token,
|
||||
repeats=100,
|
||||
learnable_property=args.what_to_teach, # Option selected above between object and style
|
||||
center_crop=False,
|
||||
set="train",
|
||||
)
|
||||
|
||||
|
||||
def create_dataloader(train_batch_size=1):
    """Return a shuffling ``DataLoader`` over the module-level ``train_dataset``.

    Args:
        train_batch_size: Number of samples per batch.
    """
    loader_options = {"batch_size": train_batch_size, "shuffle": True}
    return torch.utils.data.DataLoader(train_dataset, **loader_options)
|
||||
|
||||
|
||||
# Create noise_scheduler for training
|
||||
noise_scheduler = DDPMScheduler.from_config(
|
||||
pretrained_model_name_or_path, subfolder="scheduler"
|
||||
)
|
||||
|
||||
######## Training ###########
|
||||
|
||||
# Define hyperparameters for our training. If you are not happy with your results,
|
||||
# you can tune the `learning_rate` and the `max_train_steps`
|
||||
|
||||
# Setting up all training args
|
||||
hyperparameters = {
|
||||
"learning_rate": 5e-04,
|
||||
"scale_lr": True,
|
||||
"max_train_steps": args.training_steps,
|
||||
"save_steps": args.save_steps,
|
||||
"train_batch_size": args.train_batch_size,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": True,
|
||||
"mixed_precision": "fp16",
|
||||
"seed": 42,
|
||||
"output_dir": "sd-concept-output",
|
||||
}
|
||||
# creating output directory
|
||||
cwd = os.getcwd()
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
# A single makedirs call replaces the former `while not exists: mkdir` loop,
# which never terminated (printing on every pass) when the directory could
# not be created, e.g. on a read-only location.
try:
    os.makedirs(out_dir, exist_ok=True)
except OSError as error:
    print(f"Output directory not created: {error}")
    # Training writes checkpoints into this directory, so there is no point
    # in continuing without it.
    raise
|
||||
|
||||
###### Torch-MLIR Compilation ######
|
||||
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def transform_fx(fx_g):
    """Insert an ``aten.zero_`` call after every ``aten.empty`` call in ``fx_g``.

    All users of each ``aten.empty`` node are redirected to the new
    ``aten.zero_`` node. The graph is linted afterwards; recompiling the
    module is left to the caller.
    """
    graph = fx_g.graph
    for node in graph.nodes:
        if node.op != "call_function":
            continue
        if node.target not in [torch.ops.aten.empty]:
            continue
        # aten.empty should be filled with zeros.
        with graph.inserting_after(node):
            zeroed = graph.call_function(
                torch.ops.aten.zero_,
                args=(node,),
            )
            node.append(zeroed)
            node.replace_all_uses_with(zeroed)
            # replace_all_uses_with also rewrote zero_'s own argument;
            # point it back at the empty node.
            zeroed.args = (node,)

    graph.lint()
|
||||
|
||||
|
||||
@make_simple_dynamo_backend
def refbackend_torchdynamo_backend(
    fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
):
    """TorchDynamo backend that compiles ``fx_graph`` via torch-mlir and AMDShark.

    Returns ``fx_graph`` unchanged when it returns nothing; otherwise returns
    a callable that runs the compiled module on NumPy copies of its inputs
    and converts the results back to torch tensors.
    """
    # handling usage of empty tensor without initializing
    transform_fx(fx_graph)
    fx_graph.recompile()
    if _returns_nothing(fx_graph):
        return fx_graph
    none_positions = _remove_nones(fx_graph)
    single_output = _unwrap_single_tuple_return(fx_graph)

    compiled = torch_mlir.compile(
        fx_graph, example_inputs, output_type="linalg-on-tensors"
    )

    buffer = BytesIO()
    compiled.operation.write_bytecode(buffer)

    runner = AMDSharkInference(
        mlir_module=buffer.getvalue(),
        device=args.device,
        mlir_dialect="tm_tensor",
    )
    runner.compile()

    def compiled_callable(*inputs):
        numpy_inputs = [tensor.numpy() for tensor in inputs]
        outputs = runner("forward", numpy_inputs)
        if single_output:
            outputs = [outputs]
        if not isinstance(outputs, list):
            return torch.from_numpy(outputs)
        tensors = [torch.from_numpy(item) for item in outputs]
        # Re-insert the None placeholders stripped before compilation.
        for position in none_positions:
            tensors.insert(position, None)
        return tuple(tensors)

    return compiled_callable
|
||||
|
||||
|
||||
def predictions(torch_func, jit_func, batchA, batchB):
    """Call ``jit_func`` on the NumPy views of both batches and return its result.

    Args:
        torch_func: Unused; kept in the signature for callers.
        jit_func: Callable taking two NumPy arrays.
        batchA: Object exposing ``.numpy()``.
        batchB: Object exposing ``.numpy()``.

    Returns:
        Whatever ``jit_func`` returns (``None`` included).
    """
    return jit_func(batchA.numpy(), batchB.numpy())
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# def save_progress(text_encoder, placeholder_token_id, accelerator, save_path):
|
||||
def save_progress(text_encoder, placeholder_token_id, save_path):
    """Save the embedding row of the placeholder token to ``save_path``.

    The file holds a one-entry dict keyed by ``args.placeholder_token`` whose
    value is the detached CPU copy of that row of the input-embedding weight.
    """
    logger.info("Saving embeddings")
    embedding_table = text_encoder.get_input_embeddings().weight
    learned_row = embedding_table[placeholder_token_id]
    payload = {args.placeholder_token: learned_row.detach().cpu()}
    torch.save(payload, save_path)
|
||||
|
||||
|
||||
train_batch_size = hyperparameters["train_batch_size"]
|
||||
gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
|
||||
learning_rate = hyperparameters["learning_rate"]
|
||||
if hyperparameters["scale_lr"]:
|
||||
learning_rate = (
|
||||
learning_rate
|
||||
* gradient_accumulation_steps
|
||||
* train_batch_size
|
||||
# * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=learning_rate,
|
||||
)
|
||||
|
||||
|
||||
# Training function
|
||||
def train_func(batch_pixel_values, batch_input_ids):
    """Run one training step and return its loss.

    Encodes the images with ``amdshark_vae``, adds scheduler noise at random
    timesteps, predicts it with ``amdshark_unet`` conditioned on the text
    encoder output, back-propagates the MSE loss, zeroes the gradient of
    every token embedding except ``placeholder_token_id`` and steps the
    module-level ``optimizer``.

    Args:
        batch_pixel_values: Image batch passed to ``amdshark_vae``.
        batch_input_ids: Token ids passed to ``text_encoder``.

    Returns:
        The loss tensor computed for this step.

    Raises:
        ValueError: If the scheduler's ``prediction_type`` is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    # Convert images to latent space
    latents = amdshark_vae(batch_pixel_values).sample().detach()
    # NOTE(review): 0.18215 is presumably the Stable Diffusion VAE latent
    # scaling factor — confirm against the model config.
    latents = latents * 0.18215

    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(
        0,
        noise_scheduler.num_train_timesteps,
        (bsz,),
        device=latents.device,
    ).long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Get the text embedding for conditioning
    encoder_hidden_states = text_encoder(batch_input_ids)[0]

    # Predict the noise residual
    noise_pred = amdshark_unet(
        noisy_latents,
        timesteps,
        encoder_hidden_states,
    )

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(
            f"Unknown prediction type {noise_scheduler.config.prediction_type}"
        )

    # Per-sample mean over dims 1-3, then mean over the batch.
    loss = (
        F.mse_loss(noise_pred, target, reduction="none").mean([1, 2, 3]).mean()
    )
    loss.backward()

    # Zero out the gradients for all token embeddings except the newly added
    # embeddings for the concept, as we only want to optimize the concept embeddings
    grads = text_encoder.get_input_embeddings().weight.grad
    # Get the index for tokens that we want to zero the grads for
    index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
    grads.data[index_grads_to_zero, :] = grads.data[
        index_grads_to_zero, :
    ].fill_(0)

    optimizer.step()
    optimizer.zero_grad()

    return loss
|
||||
|
||||
|
||||
def training_function():
    """Run the training loop, then save the pipeline and learned embeddings.

    Reads its settings from the module-level ``hyperparameters`` dict and
    ``args``. Each step calls ``train_func`` directly or, when
    ``args.use_torchdynamo`` is set, through a TorchDynamo-wrapped callable.
    Every ``save_steps`` steps the placeholder embedding is written to
    ``learned_embeds-step-<n>.bin``; after training the full pipeline and
    ``learned_embeds.bin`` are written to ``output_dir``.
    """
    max_train_steps = hyperparameters["max_train_steps"]
    output_dir = hyperparameters["output_dir"]
    # NOTE(review): read but not used anywhere in this function.
    gradient_checkpointing = hyperparameters["gradient_checkpointing"]

    train_dataloader = create_dataloader(train_batch_size)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = (
        train_batch_size
        * gradient_accumulation_steps
        # train_batch_size * accelerator.num_processes * gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Instantaneous batch size per device = {train_batch_size}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f" Gradient Accumulation steps = {gradient_accumulation_steps}"
    )
    logger.info(f" Total optimization steps = {max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        # range(max_train_steps), disable=not accelerator.is_local_main_process
        range(max_train_steps)
    )
    progress_bar.set_description("Steps")
    global_step = 0

    # Snapshot of the embedding parameters, printed before training starts.
    params_ = [i for i in text_encoder.get_input_embeddings().parameters()]
    if args.use_torchdynamo:
        print("******** TRAINING STARTED - TORCHYDNAMO PATH ********")
    else:
        print("******** TRAINING STARTED - PYTORCH PATH ********")
    print("Initial weights:")
    print(params_, params_[0].shape)

    for epoch in range(num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            if args.use_torchdynamo:
                # NOTE(review): the dynamo wrapper is rebuilt on every step;
                # it looks hoistable out of the loop — confirm.
                dynamo_callable = dynamo.optimize(
                    refbackend_torchdynamo_backend
                )(train_func)
                # predictions() hands NumPy arrays to the callable, so they
                # are converted back to tensors here.
                lam_func = lambda x, y: dynamo_callable(
                    torch.from_numpy(x), torch.from_numpy(y)
                )
                loss = predictions(
                    train_func,
                    lam_func,
                    batch["pixel_values"],
                    batch["input_ids"],
                    # params[0].detach(),
                )
            else:
                loss = train_func(batch["pixel_values"], batch["input_ids"])
            print(loss)

            # One optimization step has been performed by train_func.
            progress_bar.update(1)
            global_step += 1
            if global_step % hyperparameters["save_steps"] == 0:
                save_path = os.path.join(
                    output_dir,
                    f"learned_embeds-step-{global_step}.bin",
                )
                save_progress(
                    text_encoder,
                    placeholder_token_id,
                    save_path,
                )

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)

            # Leaves only the inner (per-batch) loop.
            if global_step >= max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    params__ = [i for i in text_encoder.get_input_embeddings().parameters()]
    print("******** TRAINING PROCESS FINISHED ********")
    print("Updated weights:")
    print(params__, params__[0].shape)
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        # text_encoder=accelerator.unwrap_model(text_encoder),
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        unet=unet,
    )
    pipeline.save_pretrained(output_dir)
    # Also save the newly trained embeddings
    save_path = os.path.join(output_dir, f"learned_embeds.bin")
    save_progress(text_encoder, placeholder_token_id, save_path)
|
||||
|
||||
|
||||
training_function()
|
||||
|
||||
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
|
||||
if param.grad is not None:
|
||||
del param.grad # free some memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Set up the pipeline
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
hyperparameters["output_dir"],
|
||||
scheduler=DPMSolverMultistepScheduler.from_pretrained(
|
||||
hyperparameters["output_dir"], subfolder="scheduler"
|
||||
),
|
||||
)
|
||||
if not args.use_torchdynamo:
|
||||
pipe.to(args.device)
|
||||
|
||||
# Run the Stable Diffusion pipeline
|
||||
# Don't forget to use the placeholder token in your prompt
|
||||
|
||||
all_images = []
|
||||
for _ in range(args.num_inference_samples):
|
||||
images = pipe(
|
||||
[args.prompt],
|
||||
num_inference_steps=args.inference_steps,
|
||||
guidance_scale=7.5,
|
||||
).images
|
||||
all_images.extend(images)
|
||||
|
||||
output_path = os.path.abspath(os.path.join(os.getcwd(), args.output_dir))
# makedirs(exist_ok=True) creates missing parent directories and avoids the
# check-then-create race of the former isdir()/mkdir() pair.
os.makedirs(output_path, exist_ok=True)

# A plain loop replaces the list comprehension that was used only for its
# side effects and built a throwaway list.
for i, image in enumerate(all_images):
    image.save(os.path.join(output_path, f"{i}.jpeg"))
|
||||
@@ -1,164 +0,0 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
## Common utilities to be shared by iree utilities.
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
|
||||
def run_cmd(cmd, debug=False, raise_err=False):
    """
    Run `cmd` through the shell and return its decoded output.

    Inputs:
    cmd : cli command string.
    debug : if True, prints debug info
    raise_err : if True, raise exception to caller

    Returns:
    (stdout, stderr) of the command, both decoded to str.

    Raises:
    Exception (chained from subprocess.CalledProcessError) when the command
    exits non-zero and raise_err is True. When raise_err is False the
    captured output is printed and the process exits via sys.exit.
    """
    if debug:
        print("IREE run command: \n\n")
        print(cmd)
        print("\n\n")
    try:
        # NOTE: shell=True is required because callers pass whole command
        # strings (some with pipes). Never build `cmd` from untrusted input.
        result = subprocess.run(
            cmd,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        captured_out = (e.stdout or b"").decode(errors="replace")
        captured_err = (e.stderr or b"").decode(errors="replace")
        if raise_err:
            # The exception now carries the command, exit code and stderr;
            # the bare `raise Exception from e` had an empty message.
            raise Exception(
                f"Command `{cmd}` failed with exit code {e.returncode}:\n"
                f"{captured_err}"
            ) from e
        # Print both streams as text; previously only the raw bytes of
        # stdout were shown and stderr was dropped.
        print(captured_out)
        print(captured_err)
        sys.exit(f"Exiting program due to error running {cmd}")
    return result.stdout.decode(), result.stderr.decode()
|
||||
|
||||
|
||||
def iree_device_map(device):
    """Translate a Studio device string into an IREE driver (URI) string.

    A bare name is mapped through ``_IREE_DEVICE_MAP`` (unknown names pass
    through). For ``name://id`` the mapped driver is re-joined with the id,
    except that any URI with a ``rocm`` component collapses to ``"rocm"``.
    """
    uri_parts = device.split("://", 2)
    scheme = uri_parts[0]
    iree_driver = _IREE_DEVICE_MAP.get(scheme, scheme)
    if len(uri_parts) == 1:
        return iree_driver
    if "rocm" in uri_parts:
        return "rocm"
    return f"{iree_driver}://{uri_parts[1]}"


def get_supported_device_list():
    """Return the device names known to ``_IREE_DEVICE_MAP``, in table order."""
    return [*_IREE_DEVICE_MAP]


# Studio device name -> IREE runtime driver name.
_IREE_DEVICE_MAP = {
    "cpu": "local-task",
    "cpu-task": "local-task",
    "cpu-sync": "local-sync",
    "cuda": "cuda",
    "vulkan": "vulkan",
    "metal": "metal",
    "rocm": "rocm",
    "hip": "hip",
    "intel-gpu": "level_zero",
}
|
||||
|
||||
|
||||
def iree_target_map(device):
    """Return the IREE compiler target backend for ``device``.

    Any ``://...`` suffix is ignored; names missing from
    ``_IREE_TARGET_MAP`` are returned unchanged.
    """
    backend_key = device.split("://")[0] if "://" in device else device
    return _IREE_TARGET_MAP.get(backend_key, backend_key)


# Studio device name -> IREE compiler target backend.
_IREE_TARGET_MAP = {
    "cpu": "llvm-cpu",
    "cpu-task": "llvm-cpu",
    "cpu-sync": "llvm-cpu",
    "cuda": "cuda",
    "vulkan": "vulkan-spirv",
    "metal": "metal",
    "rocm": "rocm",
    "hip": "rocm",
    "intel-gpu": "opencl-spirv",
}
|
||||
|
||||
|
||||
# Finds whether the required drivers are installed for the given device.
|
||||
@functools.cache
def check_device_drivers(device):
    """
    Ask the IREE runtime for the driver backing `device`.

    Returns True when the driver lookup fails (the error is printed) and
    False when the lookup succeeds. Results are cached per `device` string.
    """
    base_device = device.split("://")[0] if "://" in device else device

    from iree.runtime import get_driver

    driver_name = iree_device_map(base_device)

    try:
        get_driver(driver_name)
    except ValueError as lookup_error:
        print(
            f"[ERR] device `{base_device}` not registered with IREE. "
            "Ensure IREE is configured for use with this device.\n"
            f"Full Error: \n {repr(lookup_error)}"
        )
        return True
    except RuntimeError as driver_error:
        print(
            f"[ERR] Failed to get driver for {base_device} with error:\n{repr(driver_error)}"
        )
        return True

    # The lookup did not raise, so the driver is available.
    return False
|
||||
|
||||
|
||||
# Installation info for the missing device drivers.
|
||||
def device_driver_info(device):
    """Return a help message about missing drivers for ``device``.

    Devices without an entry in the hint table get
    ``"<device> is not supported."`` instead.
    """
    rocm_tool = "hip" if sys.platform == "win32" else "rocm"
    # device -> (how to check for the driver, where to get it)
    hints = {
        "cuda": (
            "Try `nvidia-smi` on system to check.",
            " from https://www.nvidia.in/Download/index.aspx?lang=en-in for your system.",
        ),
        "vulkan": (
            "Try `vulkaninfo` on system to check.",
            " from https://vulkan.lunarg.com/sdk/home for your distribution.",
        ),
        "metal": (
            "Check if Bare metal is supported and enabled on your system.",
            ".",
        ),
        "rocm": (
            f"Try `{rocm_tool}info` on system to check.",
            " from https://rocm.docs.amd.com/en/latest/rocm.html for your system.",
        ),
    }

    if device not in hints:
        return f"{device} is not supported."

    debug_hint, solution = hints[device]
    return (
        f"Required drivers for {device} not found. {debug_hint} "
        f"Please install the required drivers{solution} "
        "For further assistance please reach out to the community on discord [https://discord.com/invite/RUqY2h2s9u]"
        " and/or file a bug at https://github.com/nod-ai/AMD-SHARK-Studio/issues"
    )
|
||||
@@ -1,154 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from amdshark.iree_utils._common import run_cmd, iree_device_map
|
||||
from amdshark.iree_utils.cpu_utils import get_cpu_count
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import platform
|
||||
|
||||
UNIT_TO_SECOND_MAP = {"us": 1e-6, "ms": 0.001, "s": 1}
|
||||
|
||||
|
||||
def _dtype_name_to_mlir(dtype_name: str) -> str:
    """Convert a dtype name such as ``float32`` into an MLIR element type."""
    if dtype_name == "bool":
        return "i1"
    if dtype_name.startswith("bfloat"):
        # The first-letter rule below would yield the invalid type "b16".
        return "bf" + dtype_name[len("bfloat"):]
    match = re.match(r"([a-zA-Z]+)([0-9]+)", dtype_name)
    if match is None:
        raise ValueError(
            f"Cannot derive an MLIR element type from dtype '{dtype_name}'"
        )
    # First letter of the kind plus the bit width: float32 -> f32, int64 -> i64.
    return match.group(1)[0] + match.group(2)


def tensor_to_type_str(input_tensors: tuple, mlir_dialect: str):
    """
    Input: A tuple of input tensors i.e tuple(torch.tensor)
    Output: list of string that represent mlir types (i.e 1x24xf64)

    `mlir_dialect` is kept for interface compatibility. The dtype name is
    now extracted the same way for every dialect, so dialects outside
    linalg/tosa/mhlo/tflite (e.g. "tm_tensor") no longer fail on an unbound
    variable. Rank-0 tensors produce the bare element type (e.g. "f32").

    Raises ValueError when a dtype has no recognizable bit width.
    # TODO: Support more than floats, and ints
    """
    list_of_type = []
    for input_tensor in input_tensors:
        dtype_string = str(input_tensor.dtype)
        # Some frameworks print the dtype name inside quotes, e.g.
        # "<dtype: 'float32'>"; use the quoted part when present.
        quoted = re.search(r"'([^'\"]*)'", dtype_string)
        if quoted is not None:
            dtype_string = quoted.group(1)
        dtype_string = dtype_string.replace("torch.", "")
        mlir_type_string = _dtype_name_to_mlir(dtype_string)
        dims = [str(dim) for dim in input_tensor.shape]
        list_of_type.append("x".join(dims + [mlir_type_string]))
    return list_of_type
|
||||
|
||||
|
||||
def build_benchmark_args(
    input_file: str,
    device: str,
    input_tensors: tuple,
    mlir_dialect: str,
    training=False,
):
    """
    Inputs: input_file leading to vmfb, input_tensor to function, target device,
    and whether it is training or not.
    Outputs: list of command-line tokens that run iree-benchmark-module on
    the target model.
    """
    bin_dir = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
    if platform.system() == "Windows":
        exe_name = "iree-benchmark-module.exe"
    else:
        exe_name = "iree-benchmark-module"

    # TODO: The function named can be passed as one of the args.
    # TODO: Replace name of train with actual train fn name.
    fn_name = "train" if training == True else "forward"

    benchmark_cl = [
        os.path.join(bin_dir, exe_name),
        f"--module={input_file}",
        f"--function={fn_name}",
        f"--device={iree_device_map(device)}",
    ]
    benchmark_cl.extend(
        f"--input={mlir_input}"
        for mlir_input in tensor_to_type_str(input_tensors, mlir_dialect)
    )
    if device == "cpu":
        num_cpus = get_cpu_count()
        if num_cpus is not None:
            benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
    benchmark_cl.append("--print_statistics=true")
    return benchmark_cl
|
||||
|
||||
|
||||
def build_benchmark_args_non_tensor_input(
    input_file: str,
    device: str,
    inputs: tuple,
    mlir_dialect: str,
    function_name: str,
):
    """
    Inputs: input_file leading to vmfb, target device, pre-formatted input
    strings, and the name of the function to benchmark (skipped when empty).
    `mlir_dialect` is accepted but not used.
    Outputs: list of command-line tokens that run iree-benchmark-module on
    the target model; off Windows an awk filter is appended as the last
    token.
    """
    on_windows = platform.system() == "Windows"
    bin_dir = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
    exe_name = (
        "iree-benchmark-module.exe" if on_windows else "iree-benchmark-module"
    )

    benchmark_cl = [os.path.join(bin_dir, exe_name), f"--module={input_file}"]
    # TODO: The function named can be passed as one of the args.
    if function_name:
        benchmark_cl.append(f"--function={function_name}")
    benchmark_cl.append(f"--device={iree_device_map(device)}")
    benchmark_cl.extend(f"--input={value}" for value in inputs)
    if not on_windows:
        benchmark_cl.append("| awk 'END{{print $2 $3}}'")
    return benchmark_cl
|
||||
|
||||
|
||||
def run_benchmark_module(benchmark_cl):
    """
    Run benchmark command, extract result and return iteration/seconds, host
    peak memory, and device peak memory.

    # TODO: Add an example of the benchmark command.
    Input: benchmark command, as a list of tokens whose first element is the
    path of the benchmark executable.
    Output: (iterations per second, host peak bytes, device peak bytes).

    Raises AttributeError when no "<number><unit>" timing is found in stdout
    or the peak-memory statistics are missing from stderr.
    """
    benchmark_path = benchmark_cl[0]
    assert os.path.exists(
        benchmark_path
    ), "Cannot find iree_benchmark_module, Please contact AMDSHARK maintainer on discord."
    bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))

    # `( *)` already matches zero spaces, so the former fallback pattern
    # without it could never succeed where this one failed.
    time_match = re.search(r"(\d+[.]*\d*)( *)([a-zA-Z]+)", bench_stdout)
    if time_match is None:
        # AttributeError is kept because that is what the old code raised.
        raise AttributeError(
            f"Could not find a benchmark time in output:\n{bench_stdout}"
        )
    elapsed = float(time_match.group(1))
    unit = time_match.group(3)
    # Honour the reported unit; unknown units fall back to milliseconds,
    # which is what every value used to be treated as.
    seconds_per_unit = UNIT_TO_SECOND_MAP.get(unit, UNIT_TO_SECOND_MAP["ms"])
    iter_per_second = 1.0 / (elapsed * seconds_per_unit)

    # Extract peak memory.
    host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
    host_peak_b = int(host_regex.search(bench_stderr).group(1))
    device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")
    device_peak_b = int(device_regex.search(bench_stderr).group(1))
    return iter_per_second, host_peak_b, device_peak_b
|
||||
@@ -1,704 +0,0 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import functools
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from amdshark.parser import amdshark_args
|
||||
|
||||
from .trace import DetailLogger
|
||||
from ._common import iree_device_map, iree_target_map
|
||||
from .cpu_utils import get_iree_cpu_rt_args
|
||||
from .benchmark_utils import *
|
||||
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=None):
    """Return the iree-compile flags for ``device``.

    Args:
        device: Device string; it is normalized with ``clean_device_info``.
        extra_args: Optional list of extra flags forwarded to the vulkan,
            metal, rocm and hip helpers. Defaults to a fresh empty list.

    Returns:
        List of compiler flags; empty for unrecognized devices.
    """
    # A fresh list per call: the former `extra_args=[]` default was one
    # shared object handed to every helper.
    extra_args = [] if extra_args is None else extra_args

    print("Configuring for device:" + device)
    device, device_num = clean_device_info(device)

    if "cpu" in device:
        from amdshark.iree_utils.cpu_utils import get_iree_cpu_args

        u_kernel_flag = ["--iree-llvmcpu-enable-ukernels"]
        stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]

        return (
            get_iree_cpu_args()
            + u_kernel_flag
            + stack_size_flag
        )
    if device == "cuda":
        from amdshark.iree_utils.gpu_utils import get_iree_gpu_args

        return get_iree_gpu_args()
    if device == "vulkan":
        from amdshark.iree_utils.vulkan_utils import get_iree_vulkan_args

        return get_iree_vulkan_args(
            device_num=device_num, extra_args=extra_args
        )
    if device == "metal":
        from amdshark.iree_utils.metal_utils import get_iree_metal_args

        return get_iree_metal_args(extra_args=extra_args)
    if device == "rocm":
        from amdshark.iree_utils.gpu_utils import get_iree_rocm_args

        return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
    if device == "hip":
        from amdshark.iree_utils.gpu_utils import get_iree_rocm_args

        return get_iree_rocm_args(
            device_num=device_num, extra_args=extra_args, hip_driver=True
        )
    return []
|
||||
|
||||
def get_iree_target_triple(device):
    """Return the value of the first device flag containing "triple".

    The value is the text after the flag's last ``=``. Returns ``""`` when
    no flag for ``device`` mentions a triple.
    """
    triple_flags = (
        flag for flag in get_iree_device_args(device) if "triple" in flag
    )
    first_flag = next(triple_flags, None)
    if first_flag is None:
        return ""
    return first_flag.split("=")[-1]
|
||||
|
||||
|
||||
def clean_device_info(raw_device):
    """Split a raw device string into ``(device, device_id)``.

    ``raw_device`` may be ``"name"``, ``"name://id"`` or
    ``"<label> => name://id"``. Ids of at most two characters are converted
    to ``int`` when numeric; longer ids, and short non-numeric ones, stay
    strings.

    Multiple devices are only supported for hip, rocm and vulkan (as of
    now): they default to id 0 when none is given, and every other device
    gets ``device_id`` None.
    """
    indexed_devices = ("hip", "rocm", "vulkan")

    if "=>" in raw_device:
        device = raw_device.split("=>")[1].strip()
    else:
        device = raw_device

    device_id = None
    if "://" in device:
        # partition splits on the first separator only, so an id that itself
        # contains "://" no longer breaks tuple unpacking.
        device, _, device_id = device.partition("://")
        if len(device_id) <= 2:
            try:
                device_id = int(device_id)
            except ValueError:
                # Keep a short non-numeric id as-is; an empty id means
                # "not specified".
                device_id = device_id or None

    if device not in indexed_devices:
        return device, None
    return device, 0 if device_id is None else device_id
|
||||
|
||||
|
||||
# Get the iree-compiler arguments given frontend.
|
||||
def get_iree_frontend_args(frontend):
    """Return the iree-compile flags for ``frontend``.

    Torch-family frontends get the host CPU-features flag, TensorFlow-family
    frontends additionally get the i64->i32 demotion flag, and unknown
    frontends get an empty list.
    """
    cpu_features_flag = "--iree-llvmcpu-target-cpu-features=host"
    if frontend in ("torch", "pytorch", "linalg", "tm_tensor"):
        return [cpu_features_flag]
    if frontend in ("tensorflow", "tf", "mhlo", "stablehlo"):
        return [cpu_features_flag, "--iree-input-demote-i64-to-i32"]
    # Frontend not found.
    return []
|
||||
|
||||
|
||||
# Common args to be used given any frontend or device.
|
||||
def get_iree_common_args(debug=False):
    """Return the iree-compile flags shared by every frontend and device.

    With ``debug`` equal to True, assertions are kept and verification is
    enabled; otherwise assertions are stripped and verification is disabled.
    """
    if debug == True:
        strip_assertions, verify = "false", "true"
    else:
        strip_assertions, verify = "true", "false"
    return [
        "--iree-util-zero-fill-elided-attrs",
        "--mlir-elide-elementsattrs-if-larger=10",
        f"--iree-opt-strip-assertions={strip_assertions}",
        f"--verify={verify}",
    ]
|
||||
|
||||
|
||||
# Args that are suitable only for certain models or groups of models.
|
||||
# amdshark_args are passed down from pytests to control which models compile with these flags,
|
||||
# but they can also be set in amdshark/parser.py
|
||||
def get_model_specific_args():
    """Collect the preprocessing-pipeline flags enabled on ``amdshark_args``.

    Each of the three options below contributes one
    ``--iree-preprocessing-pass-pipeline`` flag when it equals ``True``.

    Returns:
        A new list with zero to three flag strings, in the order the options
        are listed here.
    """
    pipeline_by_option = (
        (
            "enable_conv_transform",
            "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))",
        ),
        (
            "enable_img2col_transform",
            "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col))",
        ),
        (
            "use_winograd",
            "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-linalg-ext-convert-conv2d-to-winograd))",
        ),
    )
    return [
        pipeline_flag
        for option_name, pipeline_flag in pipeline_by_option
        if getattr(amdshark_args, option_name) == True  # noqa: E712
    ]
|
||||
|
||||
|
||||
def create_dispatch_dirs(bench_dir, device):
    """Sort the dispatch files of ``bench_dir`` into one sub-directory each.

    Every regular file directly inside ``bench_dir`` (except the protected
    ``ordered-dispatches.txt``) is moved into a fresh sub-directory named
    after the file with its extension removed. Afterwards each file found in
    the sibling directory ``temp_<basename of bench_dir>`` is moved into the
    dispatch sub-directory whose name it contains, and renamed to
    ``<dir>_benchmark.mlir``.

    Filesystem operations are done with ``os``/``shutil`` rather than shell
    strings, so paths containing spaces or shell metacharacters are safe.

    Args:
        bench_dir: ``/``-separated path of the dispatch directory.
        device: Unused; kept for interface compatibility.
    """
    import shutil

    protected_files = ["ordered-dispatches.txt"]
    bench_dir_path = bench_dir.split("/")
    bench_dir_path[-1] = "temp_" + bench_dir_path[-1]
    tmp_bench_dir = "/".join(bench_dir_path)

    for f_ in os.listdir(bench_dir):
        src = os.path.join(bench_dir, f_)
        if not os.path.isfile(src) or f_ in protected_files:
            continue
        dir_name = re.sub(r"\.\S*$", "", f_)
        # A dot-file gives an empty name (the target would be bench_dir
        # itself) and an extension-less file gives its own name; wiping the
        # target in either case would destroy data, so leave them alone.
        if not dir_name or dir_name == f_:
            continue
        target_dir = os.path.join(bench_dir, dir_name)
        if os.path.isdir(target_dir) and not os.path.islink(target_dir):
            shutil.rmtree(target_dir)
        elif os.path.lexists(target_dir):
            os.remove(target_dir)
        os.mkdir(target_dir)
        shutil.move(src, os.path.join(target_dir, f_))

    # The loop below never adds or removes top-level entries of bench_dir,
    # so the candidate list can be computed once. Only directories can
    # receive a benchmark file.
    dispatch_dirs = [
        d_
        for d_ in os.listdir(bench_dir)
        if os.path.isdir(os.path.join(bench_dir, d_))
    ]
    for f_ in os.listdir(tmp_bench_dir):
        src = os.path.join(tmp_bench_dir, f_)
        if not os.path.isfile(src):
            continue
        dir_name = ""
        for d_ in dispatch_dirs:
            # The name must be followed by a non-digit so that "x_1" does
            # not claim the benchmark of "x_10". The last match wins.
            if re.search(re.escape(d_) + r"(?=\D)", f_):
                dir_name = d_
        if dir_name != "":
            shutil.move(
                src,
                os.path.join(bench_dir, dir_name, f"{dir_name}_benchmark.mlir"),
            )
|
||||
|
||||
|
||||
def dump_isas(bench_dir):
    """Run ``amdllpc`` on every ``.spv`` file one level below ``bench_dir``.

    For each sub-directory of ``bench_dir`` and each ``.spv`` file in it,
    ``amdllpc -gfxip 11.0 <file> -v`` is executed and its standard output is
    written to ``isa.txt`` in the same sub-directory (overwritten per file).

    The tool is invoked with an argument list instead of a shell string so
    that paths with spaces or shell metacharacters are passed through intact.
    The exit status is ignored, as before.

    Args:
        bench_dir: Directory holding one sub-directory per dispatch.
    """
    import subprocess

    for d_ in os.listdir(bench_dir):
        dispatch_dir = os.path.join(bench_dir, d_)
        if not os.path.isdir(dispatch_dir):
            continue
        for f_ in os.listdir(dispatch_dir):
            if not f_.endswith(".spv"):
                continue
            spirv_path = os.path.join(dispatch_dir, f_)
            with open(os.path.join(dispatch_dir, "isa.txt"), "w") as isa_out:
                try:
                    subprocess.run(
                        ["amdllpc", "-gfxip", "11.0", spirv_path, "-v"],
                        stdout=isa_out,
                        check=False,
                    )
                except OSError as err:
                    # Best effort, like the shell version: a missing tool
                    # must not abort the remaining dispatches.
                    print(f"could not run amdllpc on {spirv_path}: {err}")
|
||||
|
||||
|
||||
def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
    """Compile and benchmark the selected dispatches below ``bench_dir``.

    For every selected dispatch sub-directory:

    * a ``*benchmark.mlir`` file is compiled, saved as
      ``<dir>_benchmark.vmfb``, benchmarked, and its result written to
      ``<dir>_data.txt`` (a ``<dir>_benchmark.sh`` with the benchmark command
      line is written too);
    * any other ``.mlir`` file is compiled in ``hal-executable`` mode and
      saved as ``<dir>_spirv.vmfb``.

    Finally ``ordered-dispatches.txt`` is written to ``bench_dir`` listing the
    benchmarked dispatches from slowest to fastest.

    Args:
        bench_dir: Directory holding one sub-directory per dispatch.
        device: Device name, forwarded to the IREE helpers.
        dispatch_benchmarks: ``"all"`` or whitespace-separated integer
            indices. A dispatch is selected when the decimal text of an index
            occurs anywhere in its directory name.

    Returns:
        ``None``. An invalid ``dispatch_benchmarks`` value prints an error and
        returns before anything is touched.
    """
    benchmark_runtimes = {}
    dispatch_list = []
    all_dispatches = dispatch_benchmarks.lower().strip() == "all"

    if not all_dispatches:
        # split() tolerates repeated/leading/trailing whitespace between
        # indices; an empty selection remains an error.
        try:
            dispatch_list = [
                int(dispatch_index)
                for dispatch_index in dispatch_benchmarks.split()
            ]
        except ValueError:
            dispatch_list = []
        if not dispatch_list:
            print("ERROR: Invalid dispatch benchmarks")
            return None

    for d_ in os.listdir(bench_dir):
        dispatch_dir = f"{bench_dir}/{d_}"
        if not os.path.isdir(dispatch_dir):
            continue
        in_dispatches = any(str(dispatch) in d_ for dispatch in dispatch_list)
        if not (all_dispatches or in_dispatches):
            continue

        for f_ in os.listdir(dispatch_dir):
            if "benchmark.mlir" in f_:
                with open(f"{dispatch_dir}/{f_}", "r") as dispatch_file:
                    module = dispatch_file.read()

                flatbuffer_blob = ireec.compile_str(
                    module, target_backends=[iree_target_map(device)]
                )

                vmfb_path = f"{dispatch_dir}/{d_}_benchmark.vmfb"
                with open(vmfb_path, "wb") as vmfb_file:
                    vmfb_file.write(flatbuffer_blob)

                # The loaded module is not used afterwards.
                # NOTE(review): presumably kept as a load sanity check.
                config = get_iree_runtime_config(device)
                ireert.VmModule.from_buffer(
                    config.vm_instance,
                    flatbuffer_blob,
                    warn_if_copy=False,
                )

                benchmark_cl = build_benchmark_args_non_tensor_input(
                    input_file=vmfb_path,
                    device=device,
                    inputs=(0,),
                    mlir_dialect="linalg",
                    function_name="",
                )

                with open(
                    f"{dispatch_dir}/{d_}_benchmark.sh", "w"
                ) as benchmark_bash:
                    benchmark_bash.write("#!/bin/bash\n")
                    benchmark_bash.write(" ".join(benchmark_cl))

                iter_per_second, _, _ = run_benchmark_module(benchmark_cl)
                # iterations/second -> milliseconds per iteration
                runtime_ms = 1 / (iter_per_second * 0.001)

                with open(
                    f"{dispatch_dir}/{d_}_data.txt", "w"
                ) as benchmark_file:
                    benchmark_file.write(f"DISPATCH: {d_}\n")
                    benchmark_file.write(str(iter_per_second) + "\n")
                    benchmark_file.write(
                        "AMDSHARK BENCHMARK RESULT: " + str(runtime_ms) + "\n"
                    )

                benchmark_runtimes[d_] = runtime_ms

            elif ".mlir" in f_ and "benchmark" not in f_:
                with open(f"{dispatch_dir}/{f_}", "r") as dispatch_file:
                    module = dispatch_file.read()

                module = re.sub(
                    "hal.executable private",
                    "hal.executable public",
                    module,
                )

                flatbuffer_blob = ireec.compile_str(
                    module,
                    target_backends=[iree_target_map(device)],
                    extra_args=["--compile-mode=hal-executable"],
                )

                with open(
                    f"{dispatch_dir}/{d_}_spirv.vmfb", "wb"
                ) as spirv_file:
                    spirv_file.write(flatbuffer_blob)

    # Ascending sort followed by a reversal (slowest first); the reversal is
    # kept as-is so that ties come out in the same order as before.
    ordered_dispatches = sorted(
        benchmark_runtimes.items(), key=lambda item: item[1]
    )[::-1]
    with open(f"{bench_dir}/ordered-dispatches.txt", "w") as f_:
        for dispatch in ordered_dispatches:
            f_.write(f"{dispatch[0]}: {dispatch[1]}ms\n")
|
||||
|
||||
|
||||
def compile_module_to_flatbuffer(
    module,
    device,
    frontend,
    model_config_path,
    extra_args,
    model_name="None",
    debug=False,
    compile_str=False,
    write_to=None,
):
    """Compile ``module`` with iree-compile and return or store the result.

    Args:
        module: MLIR text/bytes when ``compile_str`` is true, otherwise the
            path of an existing file.
        device: Device name; mapped to a target backend with
            ``iree_target_map``.
        frontend: Frontend name; selects frontend flags and the input type.
        model_config_path: Unused here.
        extra_args: Additional compiler flags supplied by the caller.
        model_name: Unused here.
        debug: Forwarded to ``get_iree_common_args``.
        compile_str: Choose ``compile_str`` over ``compile_file``.
        write_to: When not ``None``, the compiled bytes are written to this
            path and ``None`` is returned.

    Returns:
        The compiled flatbuffer, or ``None`` when ``write_to`` is given.
    """
    # Setup Compile arguments wrt to frontends. Order: frontend, device,
    # common, model specific, caller supplied, command line.
    args = [
        *get_iree_frontend_args(frontend),
        *get_iree_device_args(device, extra_args),
        *get_iree_common_args(debug=debug),
        *get_model_specific_args(),
        *extra_args,
        *amdshark_args.additional_compile_args,
    ]

    if frontend in ("stablehlo", "tosa"):
        input_type = frontend
    elif frontend in ("tflite", "tflite-tosa"):
        input_type = "tosa"
    elif frontend in ("tm_tensor",):
        input_type = ireec.InputType.TM_TENSOR
    elif frontend in ("torch", "pytorch"):
        input_type = "torch"
    else:
        # Includes "tensorflow"/"tf" and every unrecognised frontend.
        input_type = "auto"

    if compile_str:
        flatbuffer_blob = ireec.compile_str(
            module,
            target_backends=[iree_target_map(device)],
            extra_args=args,
            input_type=input_type,
        )
    else:
        assert os.path.isfile(module)
        flatbuffer_blob = ireec.compile_file(
            str(module),
            input_type=input_type,
            target_backends=[iree_target_map(device)],
            extra_args=args,
        )

    if write_to is None:
        return flatbuffer_blob

    with open(write_to, "wb") as out_file:
        out_file.write(flatbuffer_blob)
    return None
|
||||
|
||||
|
||||
def get_iree_module(
    flatbuffer_blob,
    device,
    device_idx=None,
    rt_flags: list = [],
    external_weight_file=None,
):
    """Load a compiled flatbuffer into an IREE context.

    Args:
        flatbuffer_blob: Compiled module bytes.
        device: Device name.
        device_idx: When not ``None``, the index into the driver's available
            devices to create the HAL device from; otherwise the cached
            config from ``get_iree_runtime_config`` is used.
        rt_flags: Runtime flags, each handed to ``ireert.flags.parse_flags``.
            The list is only read here.
        external_weight_file: Optional parameter file loaded into a
            ``ParameterIndex``.

    Returns:
        Tuple ``(module, config)``: the bound module looked up by the VM
        module's name, and the runtime config used.
    """
    if external_weight_file is not None:
        index = ireert.ParameterIndex()
        index.load(external_weight_file)
    # `parse_flags` is the function every other runtime-flag call site in
    # this code base uses.
    for flag in rt_flags:
        ireert.flags.parse_flags(flag)
    if device_idx is not None:
        device = iree_device_map(device)
        print("registering device id: ", device_idx)
        haldriver = ireert.get_driver(device)
        hal_device_id = haldriver.query_available_devices()[device_idx][
            "device_id"
        ]
        haldevice = haldriver.create_device(
            hal_device_id,
            allocators=amdshark_args.device_allocator,
        )
        config = ireert.Config(device=haldevice)
        # Remembered so later calls can recreate the same device.
        config.id = hal_device_id
    else:
        config = get_iree_runtime_config(device)
    vm_module = ireert.VmModule.from_buffer(
        config.vm_instance, flatbuffer_blob, warn_if_copy=False
    )
    modules = []
    if external_weight_file is not None:
        # NOTE(review): the provider is passed as-is, whereas
        # load_vmfb_using_mmap wraps it with create_io_parameters_module;
        # confirm which form SystemContext expects.
        modules.append(index.create_provider(scope="model"))
    ctx = ireert.SystemContext(vm_modules=modules, config=config)
    ctx.add_vm_module(vm_module)
    ModuleCompiled = getattr(ctx.modules, vm_module.name)
    return ModuleCompiled, config
|
||||
|
||||
|
||||
def load_vmfb_using_mmap(
    flatbuffer_blob_or_path,
    device: str,
    device_idx: int = None,
    rt_flags: list = None,
    external_weight_file: str = None,
):
    """Memory-map a compiled module and prepare it for execution.

    Two inputs are supported:

    1. a path (``str`` or ``Path``) containing ``.vmfb``: the file is mapped,
       a ``SystemContext`` is created and the bound module is returned;
    2. anything else is treated as compiled bytes: they are written to a
       temporary file (not deleted here), which is then mapped. In this case
       the raw ``VmModule`` is returned without creating a context.

    Args:
        flatbuffer_blob_or_path: Path of a ``.vmfb`` file or compiled bytes.
        device: Device name.
        device_idx: Optional index into the driver's available devices.
        rt_flags: Optional runtime flags. The list is copied, so neither the
            caller's list nor a shared default is modified.
        external_weight_file: Optional parameter file (path case only).

    Returns:
        Tuple ``(module, config, temp_file_to_unlink)``; the last element is
        the temporary file path in the bytes case and ``None`` otherwise.
    """
    # Work on a private copy: the CPU flags appended below must not leak
    # into the caller's list or accumulate across calls.
    rt_flags = list(rt_flags) if rt_flags else []

    print(f"Loading module {flatbuffer_blob_or_path}...")
    if "task" in device:
        print(
            f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
        )
        for flag in get_iree_cpu_rt_args():
            rt_flags.append(flag)
    for flag in rt_flags:
        print(flag)
        ireert.flags.parse_flags(flag)

    if "rocm" in device:
        device = "rocm"
    with DetailLogger(timeout=2.5) as dl:
        # First get configs.
        if device_idx is not None:
            dl.log(f"Mapping device id: {device_idx}")
            device = iree_device_map(device)
            haldriver = ireert.get_driver(device)
            dl.log(f"ireert.get_driver()")

            hal_device_id = haldriver.query_available_devices()[device_idx][
                "device_id"
            ]
            haldevice = haldriver.create_device(
                hal_device_id,
                allocators=amdshark_args.device_allocator,
            )
            dl.log(f"ireert.create_device()")
            config = ireert.Config(device=haldevice)
            config.id = hal_device_id
            dl.log(f"ireert.Config()")
        else:
            config = get_iree_runtime_config(device)
            dl.log("get_iree_runtime_config")
        # Checked again because `device` may have been remapped above.
        if "task" in device:
            print(
                f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
            )
            for flag in get_iree_cpu_rt_args():
                ireert.flags.parse_flags(flag)

        # Now load vmfb.
        # Two scenarios we have here :-
        #   1. We either have the vmfb already saved and therefore pass the
        #      path of it (invoking `load_module` from an AMDSharkInference
        #      obj),
        #   OR 2. We are compiling on the fly, therefore we have the
        #      flatbuffer blob to play with (invoking `compile` from an
        #      AMDSharkInference obj).
        temp_file_to_unlink = None
        if isinstance(flatbuffer_blob_or_path, Path):
            flatbuffer_blob_or_path = str(flatbuffer_blob_or_path)
        if (
            isinstance(flatbuffer_blob_or_path, str)
            and ".vmfb" in flatbuffer_blob_or_path
        ):
            mmaped_vmfb = ireert.VmModule.mmap(
                config.vm_instance, flatbuffer_blob_or_path
            )
            vm_modules = []
            if external_weight_file is not None:
                index = ireert.ParameterIndex()
                index.load(external_weight_file)
                param_module = ireert.create_io_parameters_module(
                    config.vm_instance, index.create_provider(scope="model")
                )
                vm_modules.append(param_module)
            vm_modules.append(mmaped_vmfb)
            vm_modules.append(
                ireert.create_hal_module(config.vm_instance, config.device)
            )
            dl.log(f"mmap {flatbuffer_blob_or_path}")
            if "vulkan" in device:
                # Vulkan pipeline creation consumes significant amount of time.
                print(
                    "\tCompiling Vulkan shaders. This may take a few minutes."
                )
            ctx = ireert.SystemContext(config=config, vm_modules=vm_modules)
            dl.log(f"ireert.SystemContext created")
            for flag in amdshark_args.additional_runtime_args:
                ireert.flags.parse_flags(flag)
            dl.log(f"module initialized")
            mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
        else:
            with tempfile.NamedTemporaryFile(delete=False) as tf:
                tf.write(flatbuffer_blob_or_path)
                tf.flush()
                vmfb_file_path = tf.name
            temp_file_to_unlink = vmfb_file_path
            # The VM instance lives on the config; there is no separate
            # `instance` name in this function.
            mmaped_vmfb = ireert.VmModule.mmap(
                config.vm_instance, vmfb_file_path
            )
            dl.log(f"mmap temp {vmfb_file_path}")
        return mmaped_vmfb, config, temp_file_to_unlink
|
||||
|
||||
|
||||
def get_iree_compiled_module(
    module,
    device: str,
    frontend: str = "torch",
    model_config_path: str = None,
    extra_args: list = [],
    rt_flags: list = [],
    device_idx: int = None,
    mmap: bool = False,
    debug: bool = False,
    compile_str: bool = False,
    external_weight_file: str = None,
    write_to: str = None,
):
    """Given a module returns the compiled .vmfb and configs.

    Args:
        module: MLIR text/bytes or file path, see ``compile_str``.
        device: Device name.
        frontend: Frontend name forwarded to the compiler helper.
        model_config_path: Forwarded to the compiler helper.
        extra_args: Extra compiler flags (only read here).
        rt_flags: Runtime flags; a copy is forwarded to the loaders.
        device_idx: Optional device index forwarded to the loaders.
        mmap: Load through ``load_vmfb_using_mmap`` instead of
            ``get_iree_module``.
        debug: Forwarded to the compiler helper.
        compile_str: Forwarded to the compiler helper.
        external_weight_file: Optional parameter file for the loaders.
        write_to: Optional path the compiled module is written to.

    Returns:
        Dict with keys ``"vmfb"``, ``"config"`` and ``"temp_file_to_unlink"``.
    """
    flatbuffer_blob = compile_module_to_flatbuffer(
        module=module,
        device=device,
        frontend=frontend,
        model_config_path=model_config_path,
        extra_args=extra_args,
        debug=debug,
        compile_str=compile_str,
        write_to=write_to,
    )
    temp_file_to_unlink = None
    # load_vmfb_using_mmap has appended to the list it receives; hand out a
    # copy so the caller's list and the shared default stay untouched.
    runtime_flags = list(rt_flags)
    # TODO: Currently mmap=True control flow path has been switched off for
    #       mmap. Got to find a cleaner way to unlink/delete the temporary
    #       file since we're setting delete=False when creating
    #       NamedTemporaryFile. That's why the name of the temporary file is
    #       kept in `temp_file_to_unlink`.
    if mmap:
        if write_to is not None:
            # The compiler returned None and wrote the file; map that file.
            flatbuffer_blob = write_to
        vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
            flatbuffer_blob,
            device,
            device_idx,
            runtime_flags,
            external_weight_file=external_weight_file,
        )
    else:
        if flatbuffer_blob is None and write_to is not None:
            # With `write_to` set the compiler helper returns None, which
            # cannot be loaded from a buffer; read the written file back.
            with open(write_to, "rb") as compiled_file:
                flatbuffer_blob = compiled_file.read()
        vmfb, config = get_iree_module(
            flatbuffer_blob,
            device,
            device_idx=device_idx,
            rt_flags=runtime_flags,
            external_weight_file=external_weight_file,
        )
    ret_params = {
        "vmfb": vmfb,
        "config": config,
        "temp_file_to_unlink": temp_file_to_unlink,
    }
    return ret_params
|
||||
|
||||
|
||||
def load_flatbuffer(
    flatbuffer_path: str,
    device: str,
    device_idx: int = None,
    mmap: bool = False,
    rt_flags: list = [],
):
    """Load a compiled module from ``flatbuffer_path``.

    Args:
        flatbuffer_path: Path of the compiled module file.
        device: Device name.
        device_idx: Optional device index forwarded to the loaders.
        mmap: Map the file with ``load_vmfb_using_mmap`` instead of reading
            it into memory for ``get_iree_module``.
        rt_flags: Runtime flags; a copy is forwarded to the loaders.

    Returns:
        Dict with keys ``"vmfb"``, ``"config"`` and ``"temp_file_to_unlink"``.
    """
    temp_file_to_unlink = None
    # load_vmfb_using_mmap has appended to the list it receives; hand out a
    # copy so the caller's list and the shared default stay untouched.
    runtime_flags = list(rt_flags)
    if mmap:
        vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
            flatbuffer_path, device, device_idx, runtime_flags
        )
    else:
        with open(os.path.join(flatbuffer_path), "rb") as f:
            flatbuffer_blob = f.read()
        vmfb, config = get_iree_module(
            flatbuffer_blob,
            device,
            device_idx=device_idx,
            rt_flags=runtime_flags,
        )
    ret_params = {
        "vmfb": vmfb,
        "config": config,
        "temp_file_to_unlink": temp_file_to_unlink,
    }
    return ret_params
|
||||
|
||||
|
||||
def export_iree_module_to_vmfb(
    module,
    device: str,
    directory: str,
    mlir_dialect: str = "linalg",
    model_config_path: str = None,
    module_name: str = None,
    extra_args: list = [],
    debug: bool = False,
    compile_str: bool = False,
):
    """Compile ``module`` and save it as ``<module_name>.vmfb``.

    Args:
        module: Module handed to ``compile_module_to_flatbuffer``.
        device: Device name.
        directory: Directory the ``.vmfb`` file is written to.
        mlir_dialect: Used as the frontend and in the default file name.
        model_config_path: Forwarded to the compiler helper.
        module_name: File name without extension; defaults to
            ``<mlir_dialect>_<device>`` with ``://`` replaced by ``-``.
        extra_args: Extra compiler flags.
        debug: Forwarded to the compiler helper.
        compile_str: Forwarded to the compiler helper.

    Returns:
        Path of the written file.
    """
    compiled_bytes = compile_module_to_flatbuffer(
        module=module,
        device=device,
        frontend=mlir_dialect,
        model_config_path=model_config_path,
        extra_args=extra_args,
        debug=debug,
        compile_str=compile_str,
    )
    if module_name is None:
        if "://" in device:
            device_name = "-".join(device.split("://"))
        else:
            device_name = device
        module_name = f"{mlir_dialect}_{device_name}"
    filename = os.path.join(directory, module_name + ".vmfb")
    with open(filename, "wb") as vmfb_out:
        vmfb_out.write(compiled_bytes)
    print(f"Saved vmfb in {filename}.")
    return filename
|
||||
|
||||
|
||||
def export_module_to_mlir_file(module, frontend, directory: str):
    """Write ``module`` as text to ``<directory>/model.mlir``.

    Args:
        module: ``bytes`` for the TensorFlow-style frontends, an object
            exposing ``operation.get_asm()`` for the torch frontends, and
            text for every other frontend.
        frontend: Frontend name deciding how ``module`` becomes text.
        directory: Existing directory to write into.

    Returns:
        Path of the written file.
    """
    # TODO: write proper documentation.
    if frontend in ("tensorflow", "tf", "mhlo", "stablehlo", "tflite"):
        mlir_text = module.decode("utf-8")
    elif frontend in ("pytorch", "torch"):
        mlir_text = module.operation.get_asm()
    else:
        mlir_text = module
    filename = os.path.join(directory, "model.mlir")
    with open(filename, "w") as mlir_out:
        mlir_out.write(mlir_text)
    print(f"Saved mlir in {filename}.")
    return filename
|
||||
|
||||
|
||||
def get_results(
    compiled_vm,
    function_name,
    input,
    config,
    frontend="torch",
    send_to_host=True,
    debug_timeout: float = 5.0,
    device: str = None,
):
    """Runs a .vmfb file given inputs and config and returns output.

    Args:
        compiled_vm: Mapping from function name to callable.
        function_name: Name of the function to invoke.
        input: Iterable of input arrays; each is moved to ``config.device``.
        config: Runtime config providing ``device`` (and optionally ``id``).
        frontend: Unused here.
        send_to_host: Convert results to host data before returning.
        debug_timeout: Timeout handed to ``DetailLogger``.
        device: Device name; ``"rocm"`` triggers an extra device creation.

    Returns:
        A list for tuple results, item pairs (as an object array when
        ``send_to_host``) for dict results, otherwise the result itself or
        its ``to_host()`` value.
    """
    with DetailLogger(debug_timeout) as dl:
        if device == "rocm" and hasattr(config, "id"):
            # The created device is never used below; it is kept bound for
            # the duration of the call exactly as before.
            # NOTE(review): presumably a workaround - confirm it is needed.
            rocm_driver = ireert.get_driver("rocm")
            haldevice = rocm_driver.create_device(
                config.id,
                allocators=amdshark_args.device_allocator,
            )

        device_inputs = []
        for input_array in input:
            dl.log(f"Load to device: {input_array.shape}")
            device_inputs.append(
                ireert.asdevicearray(config.device, input_array)
            )

        dl.log(f"Invoke function: {function_name}")
        result = compiled_vm[function_name](*device_inputs)
        dl.log(f"Invoke complete")

        if isinstance(result, tuple):
            if not send_to_host:
                return list(result)
            host_tensors = []
            for val in result:
                dl.log(f"Result to host: {val.shape}")
                host_tensors.append(np.asarray(val, val.dtype))
            return host_tensors

        if isinstance(result, dict):
            data = list(result.items())
            if send_to_host:
                return np.copy(np.array(data, dtype=object))
            return data

        if send_to_host and result is not None:
            dl.log("Result to host")
            return result.to_host()
        # Every path above returns, so nothing can follow this statement.
        return result
|
||||
|
||||
|
||||
@functools.cache
def get_iree_runtime_config(device):
    """Create (and cache per ``device``) an IREE runtime config.

    Args:
        device: Device name; mapped with ``iree_device_map`` before use.

    Returns:
        An ``ireert.Config`` built on a device created from the mapped URI.
    """
    device = iree_device_map(device)
    haldriver = ireert.get_driver(device)
    is_metal = "metal" in device
    if is_metal and amdshark_args.device_allocator == "caching":
        print(
            "[WARNING] metal devices can not have a `caching` allocator."
            "\nUsing default allocator `None`"
        )
    # Metal devices currently fail with caching allocators; no allocator is
    # passed for them until this gets fixed upstream.
    allocators = None if is_metal else amdshark_args.device_allocator
    haldevice = haldriver.create_device_by_uri(device, allocators=allocators)
    return ireert.Config(device=haldevice)
|
||||
@@ -1,65 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_cpu related functionalities go here.
|
||||
|
||||
import functools
|
||||
import subprocess
|
||||
import platform
|
||||
from amdshark.parser import amdshark_args
|
||||
|
||||
|
||||
def get_cpu_count():
    """Return the number of CPUs, or ``None`` when it cannot be determined."""
    import multiprocessing

    try:
        detected = multiprocessing.cpu_count()
    except NotImplementedError:
        return None
    else:
        return detected
|
||||
|
||||
|
||||
# Get the default cpu args.
|
||||
@functools.cache
def get_iree_cpu_args():
    """Return the default iree-compile CPU flags for the host (cached).

    The target triple is derived from ``platform.uname()``:

    * Darwin: ``<machine>-apple-darwin<release>``
    * Linux: ``<machine>-linux-gnu``
    * Windows: always ``x86_64-pc-windows-msvc``

    Returns:
        A one-element list holding the ``--iree-llvmcpu-target-triple`` flag.

    Raises:
        Exception: If the operating system is none of the above.
    """
    uname = platform.uname()
    os_name, proc_name = uname.system, uname.machine

    if os_name == "Darwin":
        kernel_version = uname.release
        target_triple = f"{proc_name}-apple-darwin{kernel_version}"
    elif os_name == "Linux":
        target_triple = f"{proc_name}-linux-gnu"
    elif os_name == "Windows":
        target_triple = "x86_64-pc-windows-msvc"
    else:
        # The message used to contain a stray "f" in front of the OS name.
        error_message = f"OS Type {os_name} not supported and triple can't be determined, open issue to dAMDSHARK team please :)"
        raise Exception(error_message)
    print(f"Target triple found:{target_triple}")
    return [
        f"--iree-llvmcpu-target-triple={target_triple}",
    ]
|
||||
|
||||
|
||||
# Get iree runtime flags for cpu
|
||||
@functools.cache
def get_iree_cpu_rt_args():
    """Return the IREE runtime flags for CPU execution (cached).

    The group count is ``amdshark_args.task_topology_max_group_count`` when
    that is set. Otherwise it is the CPU count, reduced by two on machines
    with more than eight CPUs. If the CPU count cannot be determined, ``1``
    is used.

    Returns:
        A one-element list holding ``--task_topology_max_group_count=<n>``.
    """
    default = get_cpu_count()
    if default is None:
        # get_cpu_count() reports an undeterminable count as None, which
        # cannot be compared with an int; fall back to a single group.
        default = 1
    elif default > 8:
        default -= 2
    cpu_count = (
        default
        if amdshark_args.task_topology_max_group_count is None
        else amdshark_args.task_topology_max_group_count
    )
    return [f"--task_topology_max_group_count={cpu_count}"]
|
||||
@@ -1,209 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_gpu related functionalities go here.
|
||||
|
||||
import functools
|
||||
import iree.runtime as ireert
|
||||
import ctypes
|
||||
import sys
|
||||
from subprocess import CalledProcessError
|
||||
from amdshark.parser import amdshark_args
|
||||
from amdshark.iree_utils._common import run_cmd
|
||||
|
||||
# TODO: refactor to rocm and cuda utils
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
@functools.cache
def get_iree_gpu_args():
    """Return the default CUDA compile flags for the local GPU (cached).

    As a side effect, function input validation is switched off on
    ``ireert.flags`` and ``--cuda_allow_inline_execution`` is parsed as a
    runtime flag.

    Returns:
        ``["--iree-hal-cuda-llvm-target-arch=<sm>"]`` when the detected
        architecture is one of the listed ones and
        ``amdshark_args.enable_tf32`` equals ``True``; otherwise ``[]``.
    """
    ireert.flags.FUNCTION_INPUT_VALIDATION = False
    ireert.flags.parse_flags("--cuda_allow_inline_execution")
    # TODO: Give the user_interface to pass the sm_arch.
    sm_arch = get_cuda_sm_cc()
    listed_archs = (
        "sm_70",
        "sm_72",
        "sm_75",
        "sm_80",
        "sm_84",
        "sm_86",
        "sm_89",
    )
    if sm_arch not in listed_archs:
        return []
    if amdshark_args.enable_tf32 == True:  # noqa: E712
        return [f"--iree-hal-cuda-llvm-target-arch={sm_arch}"]
    return []
|
||||
|
||||
|
||||
def check_rocm_device_arch_in_args(extra_args):
    """Return the arch given via ``iree-rocm-target-chip`` in ``extra_args``.

    Args:
        extra_args: Iterable of flag strings.

    Returns:
        The text between the first and second ``=`` of the first flag that
        contains ``iree-rocm-target-chip``, or ``None`` if there is no such
        flag.
    """
    matching_flags = (
        candidate
        for candidate in extra_args
        if "iree-rocm-target-chip" in candidate
    )
    first_match = next(matching_flags, None)
    if first_match is None:
        return None
    return first_match.split("=")[1]
|
||||
|
||||
|
||||
def get_rocm_device_arch(device_num=0, extra_args=[], hip_driver=False):
    """Determine the ROCm device architecture to compile for.

    Selection order:

    1. the arch given by the user with ``--iree-rocm-target-chip``;
    2. the arch reported by ``iree-run-module --dump_devices=<driver>`` for
       the device at index ``device_num``;
    3. the default arch ``gfx1100``.

    A device dump that cannot be parsed never raises; the function falls
    through to the default arch instead.

    Args:
        device_num: Index of the device in the dump; ``None`` means ``0``.
        extra_args: Compiler flags searched for a user-specified arch.
        hip_driver: Query the ``hip`` driver instead of ``rocm``.

    Returns:
        The architecture name as a string.
    """
    arch_in_flag = check_rocm_device_arch_in_args(extra_args)
    if arch_in_flag is not None:
        print(
            f"User Specified rocm target device arch from flag : {arch_in_flag} will be used"
        )
        return arch_in_flag

    arch_in_device_dump = None

    # get rocm arch from iree dump devices
    def get_devices_info_from_dump(dump, driver):
        """Parse the dump into ``(device, arch)`` pairs, in dump order."""
        from os import linesep

        dump_lines = dump.split(linesep)
        if driver == "hip":
            dump_clean = [s for s in dump_lines if "AMD" in s]
        else:
            dump_clean = [
                s
                for s in dump_lines
                if f"--device={driver}" in s or "gpu-arch-name:" in s
            ]
        arch_pairs = []
        # Lines are consumed two at a time; zip drops a trailing unpaired
        # line instead of indexing past the end.
        for device_line, arch_line in zip(dump_clean[0::2], dump_clean[1::2]):
            device_parts = device_line.split("=")
            arch_parts = arch_line.split(":")
            if len(device_parts) < 2 or len(arch_parts) < 2:
                # Malformed pair: stop so earlier indices stay meaningful.
                break
            arch_pairs.append(
                (device_parts[1].strip(), arch_parts[1].strip())
            )
        return arch_pairs

    dump_device_info = None
    driver = "hip" if hip_driver else "rocm"
    try:
        dump_device_info = run_cmd(
            "iree-run-module --dump_devices=" + driver, raise_err=True
        )
    except Exception:
        # Best effort: any failure of the external command leads to the
        # default arch below.
        print("could not execute `iree-run-module --dump_devices=" + driver + "`")

    if dump_device_info is not None:
        device_num = 0 if device_num is None else device_num
        device_arch_pairs = get_devices_info_from_dump(dump_device_info[0], driver)
        if len(device_arch_pairs) > device_num:  # can find arch in the list
            arch_in_device_dump = device_arch_pairs[device_num][1]

    if arch_in_device_dump is not None:
        print(f"Found ROCm device arch : {arch_in_device_dump}")
        return arch_in_device_dump

    default_rocm_arch = "gfx1100"
    print(
        "Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
        "\n or from `iree-run-module --dump_devices` command."
        f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
    )
    return default_rocm_arch
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
def get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=False):
    """Return the default ROCm compile flags.

    As a side effect, function input validation is switched off on
    ``ireert.flags``.

    Args:
        device_num: Device index forwarded to ``get_rocm_device_arch``.
        extra_args: Compiler flags already chosen by the caller.
        hip_driver: Forwarded to ``get_rocm_device_arch``.

    Returns:
        ``[]`` when ``extra_args`` already names a target chip, otherwise a
        one-element list with the ``--iree-rocm-target-chip`` flag.
    """
    ireert.flags.FUNCTION_INPUT_VALIDATION = False
    if check_rocm_device_arch_in_args(extra_args) is not None:
        # The user already picked the chip; add nothing.
        return []
    detected_arch = get_rocm_device_arch(
        device_num, extra_args, hip_driver=hip_driver
    )
    return [f"--iree-rocm-target-chip={detected_arch}"]
|
||||
|
||||
# Some constants taken from cuda.h
|
||||
# Return code of a successful CUDA driver API call.
CUDA_SUCCESS = 0
# Device attribute identifiers.
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16
CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39
CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
|
||||
|
||||
|
||||
@functools.cache
def get_cuda_sm_cc():
    """Query the CUDA driver for the compute capability of the first device.

    The driver library is loaded with ``ctypes`` and the result is cached.

    Returns:
        ``"sm_<major><minor>"`` for the first device. For compatibility with
        existing callers, the integer ``1`` is returned when a driver call
        fails, and ``None`` when the driver reports zero devices.

    Raises:
        OSError: If none of the known CUDA driver libraries can be loaded.
    """
    libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
    for libname in libnames:
        try:
            cuda = ctypes.CDLL(libname)
        except OSError:
            continue
        else:
            break
    else:
        raise OSError("could not load any of: " + " ".join(libnames))

    nGpus = ctypes.c_int()
    # The driver writes the device name into this buffer, so it has to be
    # mutable memory; a Python bytes object must not be used for that.
    name = ctypes.create_string_buffer(100)
    cc_major = ctypes.c_int()
    cc_minor = ctypes.c_int()
    device = ctypes.c_int()
    error_str = ctypes.c_char_p()

    def describe_error(code):
        """Return the driver's text for ``code`` (or a placeholder)."""
        cuda.cuGetErrorString(code, ctypes.byref(error_str))
        return (error_str.value or b"unknown error").decode()

    result = cuda.cuInit(0)
    if result != CUDA_SUCCESS:
        print(
            "cuInit failed with error code %d: %s"
            % (result, describe_error(result))
        )
        return 1
    result = cuda.cuDeviceGetCount(ctypes.byref(nGpus))
    if result != CUDA_SUCCESS:
        print(
            "cuDeviceGetCount failed with error code %d: %s"
            % (result, describe_error(result))
        )
        return 1
    print("Found %d device(s)." % nGpus.value)
    for i in range(nGpus.value):
        result = cuda.cuDeviceGet(ctypes.byref(device), i)
        if result != CUDA_SUCCESS:
            print(
                "cuDeviceGet failed with error code %d: %s"
                % (result, describe_error(result))
            )
            return 1
        print("Device: %d" % i)
        if cuda.cuDeviceGetName(name, len(name), device) == CUDA_SUCCESS:
            print(" Name: %s" % name.value.decode())
        if (
            cuda.cuDeviceComputeCapability(
                ctypes.byref(cc_major), ctypes.byref(cc_minor), device
            )
            == CUDA_SUCCESS
        ):
            print(
                " Compute Capability: %d.%d"
                % (cc_major.value, cc_minor.value)
            )
        # Only the first device is considered.
        sm = f"sm_{cc_major.value}{cc_minor.value}"
        return sm
|
||||
@@ -1,102 +0,0 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_vulkan related functionalities go here.
|
||||
|
||||
import functools
|
||||
|
||||
from amdshark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
from amdshark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
@functools.cache
def get_metal_device_name(device_num=0):
    """Return the name of a Metal device from the IREE device dump (cached).

    Args:
        device_num: Index into the list of Metal devices found in the dump.

    Returns:
        The third ``"\\n#"``-separated field of the selected device section.

    Raises:
        ValueError: If the dump lists no Metal device.
    """
    dump_sections = run_cmd("iree-run-module --dump_devices")[0].split("\n\n")
    metal_device_list = []
    for section in dump_sections:
        if "--device=metal" in section:
            metal_device_list.append(section.split("\n#")[2])

    if not metal_device_list:
        raise ValueError("No device name found in device dump!")
    if len(metal_device_list) > 1:
        print("Following devices found:")
        for position, device_name in enumerate(metal_device_list):
            print(f"{position}. {device_name}")
        print(f"Choosing device: {metal_device_list[device_num]}")
    return metal_device_list[device_num]
|
||||
|
||||
|
||||
def get_os_name():
    """Return ``"linux"``, ``"macos"`` or ``"windows"`` for the running OS.

    An unrecognised platform prints a notice and is reported as ``"linux"``.
    """
    if platform.startswith("linux"):
        return "linux"
    exact_names = {"darwin": "macos", "win32": "windows"}
    if platform in exact_names:
        return exact_names[platform]
    print("Cannot detect OS type, defaulting to linux.")
    return "linux"
|
||||
|
||||
|
||||
def get_metal_target_triple(device_name):
    """Return the Metal target platform for a device.

    Args:
        device_name (str): name of the hardware device; currently ignored.

    Returns:
        str: always ``"macos"``.
    """
    return "macos"
|
||||
|
||||
|
||||
def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
    """Return the ``-iree-metal-target-platform`` flag for a device.

    Args:
        device_name: Device name; when empty (``""``, ``[]`` or ``None``)
            the name is looked up with ``get_metal_device_name``.
        device_num: Device index used for that lookup.
        extra_args: Flags already chosen by the caller.

    Returns:
        The flag string, or ``None`` when ``extra_args`` already contains the
        flag or no target platform is known for the device.
    """
    for user_flag in extra_args:
        if "-iree-metal-target-platform=" in user_flag:
            print(f"Using target triple {user_flag.split('=')[1]}")
            return None

    name_missing = device_name == "" or device_name == [] or device_name is None
    if name_missing:
        metal_device = get_metal_device_name(device_num=device_num)
    else:
        metal_device = device_name

    triple = get_metal_target_triple(metal_device)
    if triple is None:
        print(
            "Optimized kernel for your target device is not added yet.\n"
            "Contact AMDSHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]\n"
            "or pull up an issue."
        )
        print(f"Target : {metal_device}")
        return None

    print(
        f"Found metal device {metal_device}. Using metal target platform {triple}"
    )
    return f"-iree-metal-target-platform={triple}"
|
||||
|
||||
|
||||
def get_iree_metal_args(device_num=0, extra_args=[]):
    """Return the Metal-specific compile flags.

    Args:
        device_num: Unused here.
        extra_args: Sized collection of flags supplied by the caller.

    Returns:
        A new list holding the items of ``extra_args``.
    """
    # Add any metal specific compilation flags here
    return [*extra_args] if len(extra_args) > 0 else []
|
||||
|
||||
|
||||
def set_iree_metal_runtime_flags(flags):
    """Hand each item of ``flags`` to ``ireert.flags.parse_flags``."""
    for runtime_flag in flags:
        ireert.flags.parse_flags(runtime_flag)
|
||||
@@ -1,76 +0,0 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
def _enable_detail_trace() -> bool:
|
||||
return os.getenv("AMDSHARK_DETAIL_TRACE", "0") == "1"
|
||||
|
||||
|
||||
class DetailLogger:
    """Context manager which can accumulate detailed log messages.

    Detailed log is only emitted if the operation takes a long time
    or errors.

    While the logger is "active", ``log()`` buffers messages. The buffer is
    printed (and buffering stops) when the ``with`` body raises or when it
    runs for longer than ``timeout`` seconds. When the
    AMDSHARK_DETAIL_TRACE environment variable is "1" the logger is never
    active and every message is printed immediately.
    """

    def __init__(self, timeout: float):
        """Create a logger whose watchdog fires after ``timeout`` seconds."""
        self._timeout = timeout
        self._messages: List[Tuple[float, str]] = []
        self._start_time = time.time()
        # Active == buffering mode; disabled up front when tracing is forced.
        self._active = not _enable_detail_trace()
        self._lock = threading.RLock()
        self._cond = threading.Condition(self._lock)
        self._thread = None

    def __enter__(self):
        """Start the watchdog thread and return the logger."""
        self._thread = threading.Thread(target=self._run)
        self._thread.start()
        return self

    def __exit__(self, type, value, traceback):
        """Dump the buffered messages on error, then stop the watchdog."""
        with self._lock:
            # Dump *before* deactivating: dump_on_error() is a no-op once
            # the logger is inactive, so the previous order never printed
            # the report for a failing body. The lock is re-entrant.
            if traceback:
                self.dump_on_error("exception")
            self._active = False
            self._cond.notify()

    def _run(self):
        """Watchdog body: dump the report if the operation outlives the timeout."""
        with self._lock:
            # wait_for() re-checks the predicate, so a notify() issued before
            # this thread started waiting is not lost and the thread exits
            # promptly instead of sleeping for the whole timeout.
            finished = self._cond.wait_for(
                lambda: not self._active, self._timeout
            )
            if not finished:
                self.dump_on_error(f"took longer than {self._timeout}s")

    def log(self, msg):
        """Buffer ``msg`` while active, otherwise print it right away."""
        with self._lock:
            timestamp = time.time()
            if self._active:
                self._messages.append((timestamp, msg))
            else:
                print(f"  +{(timestamp - self._start_time) * 1000}ms: {msg}")

    def dump_on_error(self, summary: str):
        """Print all buffered messages once and switch to immediate printing."""
        with self._lock:
            if self._active:
                print(f"::: Detailed report ({summary}):")
                for timestamp, msg in self._messages:
                    print(
                        f"  +{(timestamp - self._start_time) * 1000}ms: {msg}"
                    )
            self._active = False
|
||||
@@ -1,538 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict
|
||||
import functools
|
||||
|
||||
|
||||
@functools.cache
def get_vulkan_target_env(vulkan_target_triple):
    """Assemble the target environment attribute string for a triple flag.

    Args:
        vulkan_target_triple (str): text of the form
            ``<flag>=<arch>-<product>-<os>``; only the part after the first
            "=" is used and it must split into exactly three "-" fields.

    Returns:
        str: the ``<#spirv.vce<...>, vendor:device_type, ...>`` string built
        from the sibling lookup helpers.
    """
    triple_text = vulkan_target_triple.split("=")[1]
    arch, product, target_os = triple_text.split("-")
    triple = (arch, product, target_os)

    version = get_version(triple=triple)
    # TODO get revision
    revision = 120
    extensions = get_extensions(triple)
    vendor = get_vendor(triple)
    device_type = get_device_type(triple)
    capabilities = get_vulkan_target_capabilities(triple)

    return f"<#spirv.vce<{version}, r({revision}), {extensions}>, {vendor}:{device_type}, #spirv.resource_limits< {capabilities} >>"
|
||||
|
||||
|
||||
def get_vulkan_target_env_flag(vulkan_target_triple):
    """Wrap the target environment for a triple in ``--iree-vulkan-target-env=``."""
    return f"--iree-vulkan-target-env={get_vulkan_target_env(vulkan_target_triple)}"
|
||||
|
||||
|
||||
def get_version(triple):
    """Pick the version string ("v1.1" or "v1.3") for a target triple.

    "v1.1" is returned when the os or product field is "android30" or
    "android31", or when the arch is "unknown"; otherwise "v1.3".
    """
    arch, product, target_os = triple
    android_levels = ("android30", "android31")
    needs_v1_1 = (
        target_os in android_levels
        or product in android_levels
        or arch == "unknown"
    )
    return "v1.1" if needs_v1_1 else "v1.3"
|
||||
|
||||
|
||||
@functools.cache
def get_extensions(triple):
    """Return the bracketed, comma separated extension list for a triple.

    Args:
        triple: ``(arch, product, os)`` tuple.

    Returns:
        str: e.g. ``"[SPV_KHR_16bit_storage, ...]"``.
    """
    arch, _product, target_os = triple

    def bracketed(names):
        return "[" + ", ".join(names) + "]"

    if arch == "m1":
        return bracketed(
            [
                "SPV_KHR_16bit_storage",
                "SPV_KHR_8bit_storage",
                "SPV_KHR_shader_float16_int8",
                "SPV_KHR_storage_buffer_storage_class",
                "SPV_KHR_variable_pointers",
            ]
        )

    if arch == "valhall":
        return bracketed(
            [
                "SPV_KHR_16bit_storage",
                "SPV_KHR_8bit_storage",
                "SPV_KHR_shader_float16_int8",
                "SPV_KHR_spirv_1_4",
                "SPV_KHR_storage_buffer_storage_class",
                "SPV_KHR_variable_pointers",
            ]
        )

    if arch == "adreno":
        adreno_ext = [
            "SPV_KHR_16bit_storage",
            "SPV_KHR_shader_float16_int8",
            "SPV_KHR_spirv_1_4",
            "SPV_KHR_storage_buffer_storage_class",
            "SPV_KHR_variable_pointers",
        ]
        if target_os == "android31":
            adreno_ext.append("SPV_KHR_8bit_storage")
        return bracketed(adreno_ext)

    vendor = get_vendor(triple)
    if vendor == "SwiftShader":
        return bracketed(["SPV_KHR_storage_buffer_storage_class"])

    if arch == "unknown":
        return bracketed(
            [
                "SPV_KHR_storage_buffer_storage_class",
                "SPV_KHR_variable_pointers",
            ]
        )

    # Everything else gets the full desktop set.
    desktop_ext = [
        "SPV_KHR_16bit_storage",
        "SPV_KHR_8bit_storage",
        "SPV_KHR_shader_float16_int8",
        "SPV_KHR_spirv_1_4",
        "SPV_KHR_storage_buffer_storage_class",
        "SPV_KHR_variable_pointers",
        "VK_EXT_subgroup_size_control",
    ]
    if vendor == "NVIDIA" or arch == "rdna3":
        desktop_ext.append("SPV_KHR_cooperative_matrix")
    # NOTE(review): a vendor string never equals a list, so this branch is
    # dead and the extension below is never emitted. It probably was meant
    # to be a membership test; confirm the extension name is accepted by the
    # compiler before enabling it.
    if vendor == ["NVIDIA", "AMD", "Intel"]:
        desktop_ext.append("SPV_KHR_shader_integer_dot_product")
    return bracketed(desktop_ext)
|
||||
|
||||
|
||||
@functools.cache
def get_vendor(triple):
    """Look up the vendor name for an ``(arch, product, os)`` triple.

    Returns "Unknown" (after printing a notice) when the architecture is
    not recognised.
    """
    arch, product, _target_os = triple

    vendor_by_arch = {
        "unknown": "Unknown",
        "rdna1": "AMD",
        "rdna2": "AMD",
        "rdna3": "AMD",
        "rgcn3": "AMD",
        "rgcn4": "AMD",
        "rgcn5": "AMD",
        "valhall": "ARM",
        "m1": "Apple",
        "arc": "Intel",
        "UHD": "Intel",
        "turing": "NVIDIA",
        "ampere": "NVIDIA",
        "pascal": "NVIDIA",
        "adreno": "Qualcomm",
    }
    if arch in vendor_by_arch:
        return vendor_by_arch[arch]

    if arch == "cpu":
        # Only the swiftshader software implementation has a named vendor.
        return "SwiftShader" if product == "swiftshader" else "Unknown"

    print(f"Vendor for target triple - {triple} not found. Using unknown")
    return "Unknown"
|
||||
|
||||
|
||||
@functools.cache
def get_device_type(triple):
    """Classify an ``(arch, product, os)`` triple by device type.

    Returns:
        str: one of "Unknown", "CPU", "DiscreteGPU" or "IntegratedGPU".
        Unrecognised architectures print a notice and yield "Unknown".
    """
    arch, product, _ = triple
    if arch == "unknown":
        return "Unknown"
    if arch == "cpu":
        return "CPU"
    if arch in ("turing", "ampere", "arc", "pascal"):
        return "DiscreteGPU"
    # "rgcn4" is included to stay consistent with get_vendor() and
    # get_vulkan_target_capabilities(), which both already handle it.
    if arch in ("rdna1", "rdna2", "rdna3", "rgcn3", "rgcn4", "rgcn5"):
        if product == "ivega10":
            return "IntegratedGPU"
        return "DiscreteGPU"
    if arch in ("m1", "valhall", "adreno"):
        return "IntegratedGPU"
    print(f"Device type for target triple - {triple} not found. Using unknown")
    return "Unknown"
|
||||
|
||||
|
||||
# Collect all the capabilities for the device.
# TODO: make a dataclass for capabilities and init using vulkaninfo
@functools.cache
def get_vulkan_target_capabilities(triple):
    """Render the capability list of a target triple as one string.

    Args:
        triple: ``(arch, product, os)`` tuple.

    Returns:
        str: ``key = value`` pairs joined by ", ". Entries whose value is
        None or False are omitted and True is written as ``unit``.
    """
    arch, product, target_os = triple

    subgroup_feature = {
        "Basic": 1,
        "Vote": 2,
        "Arithmetic": 4,
        "Ballot": 8,
        "Shuffle": 16,
        "ShuffleRelative": 32,
        "Clustered": 64,
        "Quad": 128,
        "PartitionedNV": 256,
    }
    every_feature = (
        "Basic",
        "Vote",
        "Arithmetic",
        "Ballot",
        "Shuffle",
        "ShuffleRelative",
        "Clustered",
        "Quad",
    )
    no_clustered = (
        "Basic",
        "Vote",
        "Arithmetic",
        "Ballot",
        "Shuffle",
        "ShuffleRelative",
        "Quad",
    )
    # 8/16 bit storage access plus variable pointers, all enabled.
    storage_and_pointers = dict.fromkeys(
        (
            "storageBuffer16BitAccess",
            "storagePushConstant16",
            "uniformAndStorageBuffer16BitAccess",
            "storageBuffer8BitAccess",
            "storagePushConstant8",
            "uniformAndStorageBuffer8BitAccess",
            "variablePointers",
            "variablePointersStorageBuffer",
        ),
        True,
    )

    # Conservative defaults; the insertion order here fixes the output order.
    cap = OrderedDict(
        [
            ("max_compute_shared_memory_size", 16384),
            ("max_compute_workgroup_invocations", 128),
            ("max_compute_workgroup_size", [128, 128, 64]),
            ("subgroup_size", 32),
            ("subgroupFeatures", ["Basic"]),
            ("min_subgroup_size", None),
            ("max_subgroup_size", None),
            ("shaderFloat16", False),
            ("shaderFloat64", False),
            ("shaderInt8", False),
            ("shaderInt16", False),
            ("shaderInt64", False),
            ("storageBuffer16BitAccess", False),
            ("storagePushConstant16", False),
            ("uniformAndStorageBuffer16BitAccess", False),
            ("storageBuffer8BitAccess", False),
            ("storagePushConstant8", False),
            ("uniformAndStorageBuffer8BitAccess", False),
            ("variablePointers", False),
            ("variablePointersStorageBuffer", False),
            ("coopmatCases", None),
        ]
    )

    if arch in ["rdna1", "rdna2", "rdna3"]:
        cap.update(
            max_compute_shared_memory_size=65536,
            max_compute_workgroup_invocations=1024,
            max_compute_workgroup_size=[1024, 1024, 1024],
            subgroup_size=64,
            min_subgroup_size=32,
            max_subgroup_size=64,
            subgroupFeatures=list(every_feature),
            shaderFloat16=True,
            shaderFloat64=True,
            shaderInt8=True,
            shaderInt16=True,
            shaderInt64=True,
            shaderIntegerDotProduct=True,
        )
        cap.update(storage_and_pointers)
        if arch == "rdna3":
            # TODO: Get scope value
            cap["coopmatCases"] = [
                "m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>",
                "m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>",
            ]
        if product == "rx5700xt":
            cap["storagePushConstant16"] = False
            cap["storagePushConstant8"] = False

    elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
        cap.update(
            max_compute_shared_memory_size=65536,
            max_compute_workgroup_invocations=1024,
            max_compute_workgroup_size=[1024, 1024, 1024],
            subgroup_size=64,
            subgroupFeatures=list(every_feature),
            min_subgroup_size=64,
            max_subgroup_size=64,
        )
        if arch == "rgcn5":
            cap["shaderFloat16"] = True
            cap["shaderFloat64"] = True
        cap.update(
            storageBuffer16BitAccess=True,
            shaderInt8=True,
            shaderInt16=True,
            shaderInt64=True,
            shaderIntegerDotProduct=True,
            storagePushConstant16=False,
            uniformAndStorageBuffer16BitAccess=True,
            storageBuffer8BitAccess=True,
            storagePushConstant8=False,
            uniformAndStorageBuffer8BitAccess=True,
            variablePointers=True,
            variablePointersStorageBuffer=True,
        )

    elif arch == "m1":
        cap.update(
            max_compute_shared_memory_size=32768,
            max_compute_workgroup_invocations=1024,
            max_compute_workgroup_size=[1024, 1024, 1024],
            subgroup_size=32,
            subgroupFeatures=list(no_clustered),
            shaderFloat16=True,
            shaderFloat64=True,
            shaderInt8=True,
            shaderInt16=True,
            shaderInt64=True,
            shaderIntegerDotProduct=False,
        )
        cap.update(storage_and_pointers)

    elif arch == "valhall":
        valhall_features = [
            "Basic",
            "Vote",
            "Arithmetic",
            "Ballot",
            "Clustered",
            "Quad",
        ]
        if target_os == "android31":
            valhall_features += ["Shuffle", "ShuffleRelative"]
        cap.update(
            max_compute_shared_memory_size=32768,
            max_compute_workgroup_invocations=512,
            max_compute_workgroup_size=[512, 512, 512],
            subgroup_size=16,
            subgroupFeatures=valhall_features,
            shaderFloat16=True,
            shaderInt8=True,
            shaderInt16=True,
        )
        cap.update(storage_and_pointers)

    elif arch == "arc":
        cap.update(
            max_compute_shared_memory_size=32768,
            max_compute_workgroup_invocations=1024,
            max_compute_workgroup_size=[1024, 1024, 64],
            subgroup_size=32,
            subgroupFeatures=list(every_feature),
            shaderFloat16=True,
            shaderFloat64=False,
            shaderInt8=True,
            shaderInt16=True,
            shaderInt64=False,
            shaderIntegerDotProduct=True,
        )
        cap.update(storage_and_pointers)

    elif arch == "cpu":
        if product == "swiftshader":
            cap.update(
                max_compute_shared_memory_size=16384,
                subgroup_size=4,
                subgroupFeatures=[
                    "Basic",
                    "Vote",
                    "Arithmetic",
                    "Ballot",
                    "Shuffle",
                    "ShuffleRelative",
                ],
            )

    elif arch in ["pascal"]:
        cap.update(
            max_compute_shared_memory_size=49152,
            max_compute_workgroup_invocations=1536,
            max_compute_workgroup_size=[1536, 1024, 64],
            subgroup_size=32,
            min_subgroup_size=32,
            max_subgroup_size=32,
            subgroupFeatures=list(every_feature),
            shaderFloat16=False,
            shaderFloat64=True,
            shaderInt8=True,
            shaderInt16=True,
            shaderInt64=True,
            shaderIntegerDotProduct=True,
        )
        cap.update(storage_and_pointers)

    elif arch in ["ampere", "turing"]:
        cap.update(
            max_compute_shared_memory_size=49152,
            max_compute_workgroup_invocations=1024,
            max_compute_workgroup_size=[1024, 1024, 1024],
            subgroup_size=32,
            min_subgroup_size=32,
            max_subgroup_size=32,
            subgroupFeatures=list(every_feature),
            shaderFloat16=True,
            shaderFloat64=True,
            shaderInt8=True,
            shaderInt16=True,
            shaderInt64=True,
            shaderIntegerDotProduct=True,
        )
        cap.update(storage_and_pointers)
        cap["coopmatCases"] = [
            "mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
            "mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
            "mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
        ]

    elif arch == "adreno":
        cap.update(
            max_compute_shared_memory_size=32768,
            max_compute_workgroup_invocations=1024,
            max_compute_workgroup_size=[1024, 1024, 64],
            subgroup_size=64,
            subgroupFeatures=list(no_clustered),
            shaderFloat16=True,
            shaderInt8=True,
            shaderInt16=True,
            storageBuffer16BitAccess=True,
        )
        if target_os == "android31":
            cap["uniformAndStorageBuffer8BitAccess"] = True
        cap["variablePointers"] = True
        cap["variablePointersStorageBuffer"] = True

    elif arch == "unknown":
        cap.update(
            subgroup_size=64,
            variablePointers=False,
            variablePointersStorageBuffer=False,
        )
    else:
        print(
            f"Architecture {arch} not matched. Using default vulkan target device capability"
        )

    def bracket_join(items):
        return "[" + ", ".join(f"{item}" for item in items) + "]"

    rendered = []
    for key, value in cap.items():
        if value is None or value == False:
            continue
        if value is True:
            rendered.append(f"{key} = unit")
        elif isinstance(value, list):
            if key == "subgroupFeatures":
                # The feature names collapse into one bit mask.
                mask = int(sum(subgroup_feature[name] for name in value))
                rendered.append(f"subgroup_features = {mask}: i32")
            elif key == "max_compute_workgroup_size":
                rendered.append(
                    f"max_compute_workgroup_size = dense<{bracket_join(value)}>: vector<{len(value)}xi32>"
                )
            elif key == "coopmatCases":
                props = ", ".join(
                    f"#spirv.coop_matrix_props_khr<{case}>" for case in value
                )
                rendered.append(f"cooperative_matrix_properties_khr = [{props}]")
            else:
                rendered.append(f"{key} = {bracket_join(value)}")
        else:
            rendered.append(f"{key} = {value}")
    return ", ".join(rendered)
cap["subgroupFeatures"] = ["Basic"]
|
||||
cap["min_subgroup_size"] = None
|
||||
cap["max_subgroup_size"] = None
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = False
|
||||
cap["shaderInt8"] = False
|
||||
cap["shaderInt16"] = False
|
||||
cap["shaderInt64"] = False
|
||||
cap["storageBuffer16BitAccess"] = False
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = False
|
||||
cap["storageBuffer8BitAccess"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = False
|
||||
cap["variablePointers"] = False
|
||||
cap["variablePointersStorageBuffer"] = False
|
||||
cap["coopmatCases"] = None
|
||||
|
||||
if arch in ["rdna1", "rdna2", "rdna3"]:
|
||||
cap["max_compute_shared_memory_size"] = 65536
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroup_size"] = 64
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>",
|
||||
"m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
|
||||
elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
|
||||
cap["max_compute_shared_memory_size"] = 65536
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
cap["min_subgroup_size"] = 64
|
||||
cap["max_subgroup_size"] = 64
|
||||
|
||||
if arch == "rgcn5":
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = False
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "m1":
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = False
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "valhall":
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 512
|
||||
cap["max_compute_workgroup_size"] = [512, 512, 512]
|
||||
|
||||
cap["subgroup_size"] = 16
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
if os == "android31":
|
||||
cap["subgroupFeatures"].append("Shuffle")
|
||||
cap["subgroupFeatures"].append("ShuffleRelative")
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "arc":
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 64]
|
||||
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = False
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = False
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "cpu":
|
||||
if product == "swiftshader":
|
||||
cap["max_compute_shared_memory_size"] = 16384
|
||||
cap["subgroup_size"] = 4
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
]
|
||||
|
||||
elif arch in ["pascal"]:
|
||||
cap["max_compute_shared_memory_size"] = 49152
|
||||
cap["max_compute_workgroup_invocations"] = 1536
|
||||
cap["max_compute_workgroup_size"] = [1536, 1024, 64]
|
||||
|
||||
cap["subgroup_size"] = 32
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch in ["ampere", "turing"]:
|
||||
cap["max_compute_shared_memory_size"] = 49152
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroup_size"] = 32
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
]
|
||||
|
||||
elif arch == "adreno":
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 64]
|
||||
|
||||
cap["subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
if os == "android31":
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "unknown":
|
||||
cap["subgroup_size"] = 64
|
||||
cap["variablePointers"] = False
|
||||
cap["variablePointersStorageBuffer"] = False
|
||||
else:
|
||||
print(
|
||||
f"Architecture {arch} not matched. Using default vulkan target device capability"
|
||||
)
|
||||
|
||||
def get_comma_sep_str(ele_list):
|
||||
l = ""
|
||||
for ele in ele_list:
|
||||
l += f"{ele}, "
|
||||
l = f"[{l[:-2]}]"
|
||||
return l
|
||||
|
||||
res = ""
|
||||
for k, v in cap.items():
|
||||
if v is None or v == False:
|
||||
continue
|
||||
if isinstance(v, bool):
|
||||
res += f"{k} = {'unit' if v == True else None}, "
|
||||
elif isinstance(v, list):
|
||||
if k == "subgroupFeatures":
|
||||
res += f"subgroup_features = {get_subgroup_val(v)}: i32, "
|
||||
elif k == "max_compute_workgroup_size":
|
||||
res += f"max_compute_workgroup_size = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
|
||||
elif k == "coopmatCases":
|
||||
cmc = ""
|
||||
for case in v:
|
||||
cmc += f"#spirv.coop_matrix_props_khr<{case}>, "
|
||||
res += f"cooperative_matrix_properties_khr = [{cmc[:-2]}], "
|
||||
else:
|
||||
res += f"{k} = {get_comma_sep_str(v)}, "
|
||||
else:
|
||||
res += f"{k} = {v}, "
|
||||
res = res[:-2]
|
||||
return res
|
||||
@@ -1,221 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_vulkan related functionalities go here.
|
||||
|
||||
import functools
|
||||
from os import linesep
|
||||
from amdshark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
from amdshark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
from amdshark.parser import amdshark_args
|
||||
|
||||
|
||||
@functools.cache
def get_all_vulkan_devices():
    """Return the names of all devices reported by the IREE vulkan driver.

    Returns:
        list: the "name" entry of every device; empty when the driver cannot
        be created or queried.
    """
    from iree.runtime import get_driver

    try:
        driver = get_driver("vulkan")
        device_list_src = driver.query_available_devices()
    except Exception:
        # Best effort: a missing or broken vulkan driver just means
        # "no devices". Unlike a bare except, this does not swallow
        # KeyboardInterrupt or SystemExit.
        device_list_src = []

    return [d["name"] for d in device_list_src]
|
||||
|
||||
|
||||
@functools.cache
def get_vulkan_device_name(device_num=0):
    """Resolve a device selector to a vulkan device entry.

    Args:
        device_num: an int index into ``get_all_vulkan_devices()``; any other
            value is passed to ``iree.runtime.get_driver`` and the first
            device that driver reports is used.

    Returns:
        The selected device name (int selector) or the first entry returned
        by the driver's ``query_available_devices()`` (other selectors).

    Raises:
        ValueError: an int selector was given but no devices were found.
    """
    if not isinstance(device_num, int):
        from iree.runtime import get_driver

        driver = get_driver(device_num)
        first_device = driver.query_available_devices()[0]
        print(first_device)
        return first_device

    available = get_all_vulkan_devices()
    if len(available) == 0:
        raise ValueError("No device name found in VulkanInfo!")
    if len(available) > 1:
        # Several candidates: list them so the choice is visible to the user.
        print("Following devices found:")
        for index, name in enumerate(available):
            print(f"{index}. {name}")
        print(f"Choosing device: vulkan://{device_num}")
    return available[device_num]
|
||||
|
||||
|
||||
def get_os_name():
    """Map ``sys.platform`` to one of "linux", "macos" or "windows".

    Unrecognised platforms are reported on stdout and treated as linux.
    """
    if platform.startswith("linux"):
        return "linux"
    known_platforms = {"darwin": "macos", "win32": "windows"}
    if platform in known_platforms:
        return known_platforms[platform]
    print("Cannot detect OS type, defaulting to linux.")
    return "linux"
|
||||
|
||||
|
||||
@functools.cache
def get_vulkan_target_triple(device_name):
    """This method provides a target triple str for specified vulkan device.

    The device name is matched against an ordered list of substring rules;
    the first rule that matches wins.

    Args:
        device_name (str): name of the hardware device to be used with vulkan

    Returns:
        str or None: target triple or None if no match found for given name
    """
    # TODO: Replace this with a dict or something smarter.
    system_os = get_os_name()

    # Apple Targets: M1 and M2 both map to the same fixed triple.
    for chip in ("M1", "M2"):
        if all(x in device_name for x in ("Apple", chip)):
            return "m1-moltenvk-macos"

    # (quantifier, substrings, "<arch>-<product>") in priority order.
    rules = (
        # Nvidia Targets
        (all, ("RTX", "2080"), "turing-rtx2080"),
        (all, ("A100", "SXM4"), "ampere-a100"),
        (all, ("RTX", "3090"), "ampere-rtx3090"),
        (all, ("RTX", "3080"), "ampere-rtx3080"),
        (all, ("RTX", "3070"), "ampere-rtx3070"),
        (all, ("RTX", "3060"), "ampere-rtx3060"),
        (all, ("RTX", "3050"), "ampere-rtx3050"),
        # We use ampere until lovelace target triples are plumbed in.
        (all, ("RTX", "4090"), "ampere-rtx4090"),
        (all, ("RTX", "4080"), "ampere-rtx4080"),
        (all, ("RTX", "4070"), "ampere-rtx4070"),
        (all, ("RTX", "4000"), "turing-rtx4000"),
        (all, ("RTX", "5000"), "turing-rtx5000"),
        (all, ("RTX", "6000"), "turing-rtx6000"),
        (all, ("RTX", "8000"), "turing-rtx8000"),
        (all, ("TITAN", "RTX"), "turing-titanrtx"),
        (all, ("GTX", "1060"), "pascal-gtx1060"),
        (all, ("GTX", "1070"), "pascal-gtx1070"),
        (all, ("GTX", "1080"), "pascal-gtx1080"),
        # Amd Targets
        # Linux: Radeon RX 7900 XTX
        # Windows: AMD Radeon RX 7900 XTX
        (all, ("RX", "7800"), "rdna3-7800"),
        (all, ("RX", "7900"), "rdna3-7900"),
        (all, ("Radeon", "780M"), "rdna3-780m"),
        (all, ("AMD", "PRO", "W7900"), "rdna3-w7900"),
        (any, ("AMD", "Radeon"), "rdna2-unknown"),
        # Intel Targets
        (any, ("A770", "A750"), "arc-770"),
        (all, ("v620",), "rdna2-v620"),
        # Adreno Targets
        (all, ("Adreno", "740"), "adreno-a740"),
    )
    for quantifier, needles, prefix in rules:
        if quantifier(x in device_name for x in needles):
            return f"{prefix}-{system_os}"
    return None
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
    """Build the ``--iree-vulkan-target-triple`` compiler flag.

    Args:
        device_name: explicit device name; when empty or None the name is
            looked up via ``get_vulkan_device_name(device_num=device_num)``.
        device_num: device selector used for that lookup.
        extra_args: user supplied compiler flags.

    Returns:
        The flag string, or None when ``extra_args`` already carries a target
        triple flag or no triple is known for the device.
    """
    user_flag = next(
        (arg for arg in extra_args if "-iree-vulkan-target-triple=" in arg),
        None,
    )
    if user_flag is not None:
        # The user already chose a triple; report it and add nothing.
        print(f"Using target triple {user_flag.split('=')[1]}")
        return None

    name_missing = device_name == "" or device_name == [] or device_name is None
    vulkan_device = (
        get_vulkan_device_name(device_num=device_num)
        if name_missing
        else device_name
    )

    triple = get_vulkan_target_triple(vulkan_device)
    if triple is None:
        print(
            """Optimized kernel for your target device is not added yet.
Contact AMDSHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
        )
        print(f"Target : {vulkan_device}")
        return None

    print(
        f"Found vulkan device {vulkan_device}. Using target triple {triple}"
    )
    return f"--iree-vulkan-target-triple={triple}"
|
||||
|
||||
|
||||
def get_iree_vulkan_args(device_num=0, extra_args=[]):
    """Return the compiler flags for the vulkan backend.

    Args:
        device_num: device selector forwarded to ``get_vulkan_triple_flag``
            when no triple is supplied by the user.
        extra_args: user supplied compiler flags; the first one containing
            "-iree-vulkan-target-triple=" is used as the triple flag.

    Returns:
        list[str]: the two fixed flags followed by the target triple flag,
        if one could be determined.
    """
    # res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]

    res_vulkan_flag = [
        "--iree-stream-resource-max-allocation-size=3221225472",
        "--iree-flow-inline-constants-max-byte-length=0",
    ]

    vulkan_triple_flag = next(
        (arg for arg in extra_args if "-iree-vulkan-target-triple=" in arg),
        None,
    )
    if vulkan_triple_flag is not None:
        print(f"Using target triple {vulkan_triple_flag} from command line args")
    else:
        vulkan_triple_flag = get_vulkan_triple_flag(
            device_num=device_num, extra_args=extra_args
        )

    # get_vulkan_triple_flag() returns None for unrecognised devices; do not
    # leak that None into the list of string flags.
    if vulkan_triple_flag is not None:
        res_vulkan_flag.append(vulkan_triple_flag)

    return res_vulkan_flag
|
||||
|
||||
|
||||
@functools.cache
def get_iree_vulkan_runtime_flags():
    """Return the vulkan runtime flags derived from ``amdshark_args``.

    All three flags follow ``amdshark_args.vulkan_debug_utils``: validation
    layers and robust buffer access are "true"/"false", and the debug
    verbosity is "4"/"0".

    Returns:
        list[str]: three separate flag strings.
    """
    debug_enabled = amdshark_args.vulkan_debug_utils
    on_off = "true" if debug_enabled else "false"
    vulkan_runtime_flags = [
        f"--vulkan_validation_layers={on_off}",
        # The comma after this entry was missing, which silently glued the
        # verbosity and robust-buffer-access flags into a single string.
        f"--vulkan_debug_verbosity={'4' if debug_enabled else '0'}",
        f"--vulkan-robust-buffer-access={on_off}",
    ]
    return vulkan_runtime_flags
|
||||
|
||||
|
||||
def set_iree_vulkan_runtime_flags(flags):
    """Hand every entry of ``flags`` to ``ireert.flags.parse_flags``, in order."""
    for runtime_flag in flags:
        ireert.flags.parse_flags(runtime_flag)
|
||||
@@ -1,468 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
This function takes the model mlir file and the tuned config file as input,
|
||||
and outputs a new mlir file with lowering configs annotated on certain ops.
|
||||
There are two ways to utilize the function:
|
||||
1. Call model_annotation function within another python script
|
||||
from amdshark.model_annotation import model_annotation
|
||||
with create_context() as ctx:
|
||||
module = model_annotation(ctx, input_contents=..., config_path=..., search_op=...)
|
||||
2. Run model_annotation.py directly
|
||||
python model_annotation.py -model path_to_original_mlir -config_path path_to_config_file
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
|
||||
import iree.compiler._mlir_libs
|
||||
from iree.compiler import ir
|
||||
|
||||
|
||||
def model_annotation(
    ctx: ir.Context,
    *,
    input_contents: str,
    config_path: str,
    search_op: str,
    winograd: bool = False,
):
    """Parse a model and annotate matching ops with configs from a file.

    Args:
        ctx: MLIR context; not referenced directly in this function.
        input_contents: MLIR text, or a path to a file holding it.
        config_path: path to the config file; "" returns the parsed module
            untouched.
        search_op: op family handed to ``walk_children``.
        winograd: when True the config file is read as one JSON document and
            its "c,f" entry is used; otherwise ``load_model_configs`` reads it.

    Returns:
        The parsed (and possibly annotated) module.

    Raises:
        RuntimeError: the module no longer verifies after annotation.
    """
    if os.path.isfile(input_contents):
        with open(input_contents, "rb") as model_file:
            input_contents = model_file.read()
    module = ir.Module.parse(input_contents)

    # Nothing to annotate without a config file.
    if config_path == "":
        return module

    if not winograd:
        configs = load_model_configs(config_path)
    else:
        with open(config_path, "r") as config_file:
            configs = json.load(config_file)["c,f"]

    # The Python API does not expose a general walk() function, so we just
    # do it ourselves.
    walk_children(module.operation, configs, search_op, winograd)

    if module.operation.verify():
        return module
    raise RuntimeError("Modified program does not verify!")
|
||||
|
||||
|
||||
def load_model_configs(config_path: str):
    """Load tuned op configs from a JSON-lines file, keyed by shape string.

    Each non-blank line of the file is parsed as one JSON object. Objects
    without an "identifier" key, or with an identifier other than "matmul",
    "bmm", "generic" or "conv", are skipped. Previously an unknown identifier
    reused the shape of the preceding line (or raised NameError on the first
    line), and a blank line made `json.loads` fail.

    Args:
        config_path: Path to the config file.

    Returns:
        A dict mapping the "x"-joined shape key (see `shape_list_to_string`)
        to the full parsed JSON object of that line. Later lines overwrite
        earlier ones that produce the same key.
    """
    # Fields that make up the shape key, in key order, per identifier.
    shape_fields = {
        "matmul": ("m", "n", "k"),
        "bmm": ("b", "m", "n", "k"),
        "generic": ("b", "m", "n", "k"),
        "conv": (
            "n",
            "ih",
            "iw",
            "c",
            "kh",
            "kw",
            "f",
            "oh",
            "ow",
            "d",
            "s",
            "p",
        ),
    }

    config = {}
    with open(config_path, "r") as f:
        for line in f:
            if not line.strip():
                continue
            data = json.loads(line)

            identifier = data.get("identifier")
            if not isinstance(identifier, str) or identifier not in shape_fields:
                continue

            matrix_size = [data[field] for field in shape_fields[identifier]]
            if identifier == "generic":
                # Generic keys carry a leading 1, matching get_op_shape().
                matrix_size.insert(0, 1)
            config[shape_list_to_string(matrix_size)] = data
    return config
|
||||
|
||||
|
||||
def walk_children(
    op: ir.Operation, configs: List[Dict], search_op: str, winograd: bool
):
    """Recursively visit every op nested under `op` and annotate matches.

    Args:
        op: Operation whose regions/blocks are traversed.
        configs: Despite the annotation, this is used in two ways: as a dict
            keyed by shape string (output of `load_model_configs`) for tuned
            annotation, or as the winograd "c,f" list handed to
            `add_winograd_attribute`.
        search_op: One of "matmul", "bmm", "conv", "generic" or "all".
        winograd: When true, 2-D conv ops are offered to
            `add_winograd_attribute`.

    Raises:
        ValueError: If `search_op` is not one of the supported values.
    """
    tunable_ops = {
        "matmul": ["linalg.matmul", "mhlo.dot"],
        "bmm": ["linalg.batch_matmul", "mhlo.dot_general"],
        "conv": ["mhlo.convolution", "linalg.conv_2d_nhwc_hwcf"],
        "generic": ["linalg.generic"],
        "all": [
            "mhlo.dot",
            "mhlo.dot_general",
            "mhlo.convolution",
            "linalg.matmul",
            "linalg.batch_matmul",
            "linalg.conv_2d_nhwc_hwcf",
            "linalg.generic",
        ],
    }
    if search_op not in tunable_ops:
        raise ValueError(f"{search_op} op is not tunable.")
    op_names = tunable_ops[search_op]
    winograd_ops = ["linalg.conv_2d_nchw_fchw", "linalg.conv_2d_nhwc_hwcf"]

    # Tuned lookups only make sense for the shape-keyed dict. In winograd
    # mode `configs` is a list, on which `.keys()` used to raise.
    tuned_configs = configs if isinstance(configs, dict) else None

    for region in op.regions:
        for block in region.blocks:
            for child_op in block.operations:
                # TODO: This is dumb. Both Operation and OpView should expose
                # 'operation' and 'name' attributes.
                if isinstance(child_op, ir.OpView):
                    child_op = child_op.operation

                if winograd and child_op.name in winograd_ops:
                    add_winograd_attribute(child_op, configs)

                if child_op.name in op_names:
                    if child_op.name == "linalg.generic":
                        # Only generic ops with a contraction interface,
                        # basically einsum("mk,bkn->bmn"), are tunable.
                        # Cheap structural checks come first so that ops
                        # without results are skipped instead of failing.
                        if (
                            len(child_op.operands) != 3
                            or len(child_op.results) == 0
                        ):
                            continue
                        op_iterator = str(
                            child_op.attributes["iterator_types"]
                        )
                        if "reduction" not in op_iterator:
                            continue
                        op_result = str(child_op.results[0])
                        if (
                            "arith.addf" not in op_result
                            or "arith.mulf" not in op_result
                        ):
                            continue
                        if "arith.subf" in op_result:
                            continue

                    if tuned_configs is not None:
                        child_op_shape = get_op_shape(child_op, search_op)
                        entry = tuned_configs.get(child_op_shape)
                        if (
                            entry is not None
                            and entry["options"][0] is not None
                        ):
                            add_attributes(child_op, entry["options"][0])

                walk_children(child_op, configs, search_op, winograd)
|
||||
|
||||
|
||||
def get_op_shape(op: ir.Operation, search_op: str):
    """Build the "x"-joined shape key for `op` by parsing its printed form.

    The dimensions are extracted textually from the `tensor<...>` types in
    the printed operand types / result, and from `dense<...>` attributes for
    convolutions. If `op.name` does not belong to the family selected by
    `search_op`, the shape list stays empty and the returned key is the
    result of `shape_list_to_string([])`.

    Args:
        op: The operation to inspect.
        search_op: "matmul", "bmm", "conv", "generic" or "all".

    Returns:
        The shape key string produced by `shape_list_to_string`.
    """

    def dims(text, tensor_index):
        # Dimension tokens of the `tensor_index`-th "tensor<" occurrence.
        return text.split("tensor<")[tensor_index].split("x")

    def dense_value(attr_name):
        # Text between "dense<" and the following ">" of the attribute.
        return str(op.attributes[attr_name]).split("dense<")[1].split(">")[0]

    def selected(family):
        return search_op == family or search_op == "all"

    name = op.name
    shape_list = []

    if selected("generic") and name == "linalg.generic":
        lhs = dims(str(op.operands[0].type), 1)
        rhs = dims(str(op.operands[1].type), 1)
        # Key layout: [1, b, m, n, k]
        shape_list = [1, int(rhs[0]), int(lhs[0]), int(rhs[2]), int(rhs[1])]

    elif selected("matmul") and name in ("mhlo.dot", "linalg.matmul"):
        text = str(op.results[0])
        if name == "linalg.matmul":
            text = text.split("ins(")[1]
        lhs = dims(text, 1)
        rhs = dims(text, 2)
        # Key layout: [m, n, k]
        shape_list = [int(lhs[0]), int(rhs[1]), int(lhs[1])]

    elif selected("bmm") and name in (
        "mhlo.dot_general",
        "linalg.batch_matmul",
    ):
        text = str(op.results[0])
        if name == "linalg.batch_matmul":
            text = text.split("ins(")[1]
            first = 0
        else:
            # The mhlo form is read starting one token later.
            first = 1
        lhs = dims(text, 1)
        out = dims(text, 3)
        # Key layout: [b, m, n, k]
        shape_list = [
            int(lhs[first]),
            int(lhs[first + 1]),
            int(out[first + 2]),
            int(lhs[first + 2]),
        ]

    elif selected("conv") and name in (
        "mhlo.convolution",
        "linalg.conv_2d_nhwc_hwcf",
    ):
        text = str(op.results[0])
        if name == "mhlo.convolution":
            dilation = dense_value("rhs_dilation")
            stride = dense_value("window_strides")
            pad = dense_value("padding")
        else:
            text = text.split("ins(")[1]
            dilation = dense_value("dilations")
            stride = dense_value("strides")
            pad = 0
        image = dims(text, 1)
        kernel = dims(text, 2)
        out = dims(text, 3)
        # Key layout: [n, ih, iw, c, kh, kw, f, oh, ow, d, s, p]
        fields = [
            image[0],
            image[1],
            image[2],
            image[3],
            kernel[0],
            kernel[1],
            kernel[3],
            out[1],
            out[2],
            dilation,
            stride,
            pad,
        ]
        shape_list = [int(value) for value in fields]

    return shape_list_to_string(shape_list)
|
||||
|
||||
|
||||
def add_attributes(op: ir.Operation, config: Dict):
    """Attach a `compilation_info` attribute built from one tuned config.

    The pipeline family is chosen from `config["pipeline"]`:

    * contains "GPU": LLVMGPU matmul pipeline ("LLVMGPUMatmulSimt" when the
      value is exactly "GPU", otherwise "LLVMGPUMatmulTensorCore"); optional
      keys "pipeline_depth" and "split_k".
    * contains "SPIRV": the pipeline name is used as is; optional keys
      "vector_tile_sizes", "window_tile_sizes", "subgroup_size",
      "pipeline_depth" and "store_stage".
    * anything else: treated as an IREE CPU pipeline with an empty
      workgroup size.

    Args:
        op: Operation that receives the attribute(s).
        config: One config option dict (indexed by string keys throughout).
    """
    split_k = None
    pipeline_depth = None
    store_stage = None
    subgroup_size = None

    if "GPU" in config["pipeline"]:
        pipeline = (
            "LLVMGPUMatmulSimt"
            if config["pipeline"] == "GPU"
            else "LLVMGPUMatmulTensorCore"
        )
        tile_sizes = [config["work_group_tile_sizes"]]
        workgroup_size = config["work_group_sizes"]
        pipeline_depth = config.get("pipeline_depth")
        split_k = config.get("split_k")
    elif "SPIRV" in config["pipeline"]:
        pipeline = config["pipeline"]
        if pipeline == "SPIRVMatmulPromoteVectorize":
            # Single tiling level: workgroup tiles plus the last reduction
            # tile size.
            tile_sizes = [
                config["work_group_tile_sizes"]
                + [config["reduction_tile_sizes"][-1]],
            ]
        else:
            tile_sizes = [
                config["work_group_tile_sizes"],
                config["parallel_tile_sizes"],
                config["reduction_tile_sizes"],
            ]

        workgroup_size = config["work_group_sizes"]
        if "vector_tile_sizes" in config:
            tile_sizes += [config["vector_tile_sizes"]]
        if "window_tile_sizes" in config:
            tile_sizes += [config["window_tile_sizes"]]
        subgroup_size = config.get("subgroup_size")
        pipeline_depth = config.get("pipeline_depth")
        store_stage = config.get("store_stage")
    else:
        # For IREE CPU pipelines
        pipeline = config["pipeline"]
        tile_sizes = [
            config["work_group_tile_sizes"],
            config["parallel_tile_sizes"],
            config["reduction_tile_sizes"],
        ]
        workgroup_size = []

    # Add compilation info as an attribute. There is no Python binding for
    # CompilationInfo, so its string form is built and parsed instead.
    if pipeline_depth is not None:
        translation_info = f"{pipeline} pipeline_depth = {pipeline_depth}"
        # store_stage is only emitted together with pipeline_depth.
        if store_stage is not None:
            translation_info += f" store_stage = {store_stage}"
    else:
        translation_info = f"{pipeline}"

    compilation_info = (
        f"#iree_codegen.compilation_info<"
        f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
        f"translation_info = <{translation_info}>, "
        f"workgroup_size = {repr(workgroup_size)} "
    )

    if subgroup_size is not None:
        compilation_info += f", subgroup_size = {subgroup_size}>"
    else:
        compilation_info += ">"

    attr = ir.Attribute.parse(compilation_info)
    op.attributes["compilation_info"] = attr

    # Add other attributes if required.
    if split_k:
        add_attribute_by_name(op, "iree_flow_split_k", split_k)
|
||||
|
||||
|
||||
def add_winograd_attribute(op: ir.Operation, config: List):
    """Mark `op` with `iree_winograd_conv = 1` when it qualifies.

    The op qualifies when its parsed dilation and stride are both 1, its
    kernel is 3x3, and its `[c, f]` pair is contained in `config`. All values
    are parsed from the printed form of the op's first result and of its
    "dilations" / "strides" attributes.

    Args:
        op: A `linalg.conv_2d_nchw_fchw` op, or any other op whose filter is
            read in (kh, kw, c, f) order.
        config: Collection of `[c, f]` pairs.
    """

    def dense_int(attr_name):
        text = str(op.attributes[attr_name])
        return int(text.split("dense<")[1].split(">")[0])

    ins_text = str(op.results[0]).split("ins(")[1]
    dilation = dense_int("dilations")
    stride = dense_int("strides")

    # Second tensor type in the `ins(...)` list is the filter.
    filter_tokens = ins_text.split("tensor<")[2].split("x")
    filter_dims = [int(filter_tokens[i]) for i in range(4)]
    if op.name == "linalg.conv_2d_nchw_fchw":
        f, c, kh, kw = filter_dims
    else:
        kh, kw, c, f = filter_dims

    if dilation != 1 or stride != 1:
        return
    if kh != 3 or kw != 3:
        return
    if [c, f] not in config:
        return

    op.attributes["iree_winograd_conv"] = ir.IntegerAttr.get(
        ir.IntegerType.get_signless(64), 1
    )
|
||||
|
||||
|
||||
def add_attribute_by_name(op: ir.Operation, name: str, val: int):
    """Set attribute `name` on `op` to `val` as a signless 64-bit integer."""
    i64 = ir.IntegerType.get_signless(64)
    op.attributes[name] = ir.IntegerAttr.get(i64, val)
|
||||
|
||||
|
||||
def shape_list_to_string(input):
    """Join the items of `input` with "x", e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(map(str, input))
|
||||
|
||||
|
||||
def create_context() -> ir.Context:
    """Return a new MLIR context with unregistered dialects allowed."""
    ctx = ir.Context()
    ctx.allow_unregistered_dialects = True
    return ctx
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: annotate one MLIR file and save the result.
    import argparse
    from pathlib import Path

    def path_expand(s):
        """Expand `~` in `s` and resolve it to an absolute path."""
        return Path(s).expanduser().resolve()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-model",
        type=path_expand,
        default="model.mlir",
        help="Path to the input mlir file",
    )
    parser.add_argument(
        "-config_path",
        type=path_expand,
        default="best_configs.json",
        help="Path where stores the op config file",
    )
    parser.add_argument(
        "-output_path",
        type=path_expand,
        default="tuned_model.mlir",
        help="Path to save the annotated mlir file",
    )
    parser.add_argument(
        "-search_op",
        type=str,
        default="all",
        # Same set walk_children() accepts; anything else is rejected here
        # instead of failing later with a ValueError.
        choices=["matmul", "bmm", "conv", "generic", "all"],
        help="Op to be optimized. options are matmul, bmm, conv, generic, all.",
    )

    args = parser.parse_args()

    # Without this check a missing file would be handed to the MLIR parser
    # as if the path itself were the module text.
    if not args.model.is_file():
        parser.error(f"model file not found: {args.model}")

    with create_context() as ctx:
        module = model_annotation(
            ctx,
            input_contents=args.model,
            config_path=args.config_path,
            search_op=args.search_op,
        )
        mlir_str = str(module)
        with open(args.output_path, "w") as f:
            f.write(mlir_str)
        print(f"Saved mlir in {args.output_path}.")
|
||||
@@ -1,170 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
|
||||
|
||||
class SplitStrToListAction(argparse.Action):
    """argparse action that shell-splits its string value(s) into a list.

    Intended for options that carry a whole quoted command-line fragment,
    e.g. `--flag "--a=1 --b 'c d'"` is stored as `["--a=1", "--b", "c d"]`.
    """

    def __init__(self, option_strings, dest, *args, **kwargs):
        super(SplitStrToListAction, self).__init__(
            option_strings=option_strings, dest=dest, *args, **kwargs
        )

    def __call__(self, parser, namespace, values, option_string=None):
        """Split `values` with `shlex.split` and store the tokens on `namespace`.

        Previously a constant " " was split here, so the option always
        produced an empty list regardless of what the user passed.
        """
        del parser, option_string
        # With nargs=1 argparse passes a one-element list; accept a bare
        # string as well so the action works with the default nargs.
        if isinstance(values, str):
            values = [values]
        tokens = []
        for value in values or []:
            tokens.extend(shlex.split(value))
        setattr(namespace, self.dest, tokens)
|
||||
|
||||
|
||||
def _str_to_bool(value):
    """Convert a command-line string to a bool.

    `type=bool` treats every non-empty string, including "False", as True.
    This converter understands the usual spellings; the empty string stays
    False, as it was with `bool("")`.
    """
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Module-level parser: flags are parsed once at import time and exposed as
# `amdshark_args`; arguments it does not know are collected in `unknown`.
parser = argparse.ArgumentParser(description="AMDSHARK runner.")

parser.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="Device on which amdshark_runner runs. options are cpu, cuda, and vulkan",
)
parser.add_argument(
    "--additional_compile_args",
    default=list(),
    nargs=1,
    action=SplitStrToListAction,
    help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
)
parser.add_argument(
    "--additional_runtime_args",
    default=list(),
    nargs=1,
    action=SplitStrToListAction,
    help="Additional arguments to pass to the IREE runtime. These are appended as the last arguments.",
)
parser.add_argument(
    "--enable_tf32",
    # Was type=bool, which made `--enable_tf32 False` evaluate to True.
    type=_str_to_bool,
    default=False,
    help="Enables TF32 precision calculations on supported GPUs.",
)
parser.add_argument(
    "--model_config_path",
    help="Directory to where the tuned model config file is located.",
    default=None,
)

parser.add_argument(
    "--num_warmup_iterations",
    type=int,
    default=5,
    help="Run the model for the specified number of warmup iterations.",
)
parser.add_argument(
    "--num_iterations",
    type=int,
    default=100,
    help="Run the model for the specified number of iterations.",
)
parser.add_argument(
    "--onnx_bench",
    default=False,
    action="store_true",
    help="When enabled, pytest bench results will include ONNX benchmark results.",
)
parser.add_argument(
    "--amdshark_prefix",
    default=None,
    help="gs://amdshark_tank/<this_flag>/model_directories",
)
parser.add_argument(
    "--update_tank",
    default=True,
    # Was store_true with default=True, which could never be switched off.
    # `--update_tank` still works; `--no-update_tank` now disables it.
    action=argparse.BooleanOptionalAction,
    help="When enabled, AMDSHARK downloader will update local amdshark_tank if local hash is different from latest upstream hash.",
)
parser.add_argument(
    "--force_update_tank",
    default=False,
    action="store_true",
    help="When enabled, AMDSHARK downloader will force an update of local amdshark_tank artifacts for each request.",
)
parser.add_argument(
    "--local_tank_cache",
    default=None,
    help="Specify where to save downloaded amdshark_tank artifacts. If this is not set, the default is ~/.local/amdshark_tank/.",
)

parser.add_argument(
    "--dispatch_benchmarks",
    default=None,
    help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
)

parser.add_argument(
    "--dispatch_benchmarks_dir",
    default="temp_dispatch_benchmarks",
    help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)

parser.add_argument(
    "--enable_conv_transform",
    default=False,
    action="store_true",
    help="Enables the --iree-flow-enable-conv-nchw-to-nhwc-transform flag.",
)

parser.add_argument(
    "--enable_img2col_transform",
    default=False,
    action="store_true",
    help="Enables the --iree-flow-enable-conv-img2col-transform flag.",
)

parser.add_argument(
    "--use_winograd",
    default=False,
    action="store_true",
    help="Enables the --iree-flow-enable-conv-winograd-transform flag.",
)

parser.add_argument(
    "--device_allocator",
    type=str,
    nargs="*",
    default=["caching"],
    help="Specifies one or more HAL device allocator specs "
    "to augment the base device allocator",
    choices=["debug", "caching"],
)
parser.add_argument(
    "--task_topology_max_group_count",
    type=str,
    default=None,
    help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
)

parser.add_argument(
    "--vulkan_debug_utils",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Profiles vulkan device and collects the .rdc info.",
)

parser.add_argument(
    "--vulkan_validation_layers",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Flag for disabling vulkan validation layers when benchmarking.",
)

amdshark_args, unknown = parser.parse_known_args()
|
||||
@@ -1,315 +0,0 @@
|
||||
# Copyright 2022 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from iree.runtime import query_available_drivers, get_driver
|
||||
from amdshark.amdshark_downloader import download_model
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
import argparse
|
||||
from amdshark.iree_utils._common import _IREE_DEVICE_MAP
|
||||
import multiprocessing
|
||||
from amdshark.amdshark_runner import supported_dialects
|
||||
import logging
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
# Inverse of _IREE_DEVICE_MAP: the values of that map become the keys here.
IREE_TO_AMDSHARK_DRIVER_MAP = {
    iree_driver: amdshark_device
    for amdshark_device, iree_driver in _IREE_DEVICE_MAP.items()
}
|
||||
|
||||
|
||||
def stress_test_compiled_model(
    amdshark_module_path: str,
    function_name: str,
    device: str,
    inputs: List[np.ndarray],
    golden_out: List[np.ndarray],
    batch_size: int,
    max_iterations: int,
    max_duration_seconds: float,
    inference_timeout_seconds: float,
    tolerance_nulp: int,
    stress_test_index: int,
):
    """Repeatedly run a compiled module on one device and check its output.

    The first iteration is compared to the golden output within
    `tolerance_nulp` units in the last place; every later iteration must be
    exactly equal to the first one. The loop ends after `max_iterations`
    iterations or once `max_duration_seconds` have elapsed.

    Args:
        amdshark_module_path: Path of the compiled module to load.
        function_name: Function name passed to `AMDSharkInference`.
        device: Device the module runs on.
        inputs: Model inputs; each is repeated `batch_size` times on axis 0.
        golden_out: Reference output, repeated the same way.
        batch_size: Repeat count for inputs and reference output.
        max_iterations: Upper bound on the number of inference calls.
        max_duration_seconds: Time budget for the whole loop.
        inference_timeout_seconds: Timeout for a single inference call.
        tolerance_nulp: Allowed difference from the golden output.
        stress_test_index: Index used only in log messages.

    Raises:
        AssertionError: From `numpy.testing` when outputs do not match.
        concurrent.futures.TimeoutError: When one inference call exceeds
            `inference_timeout_seconds`.
    """
    logging.info(
        f"Running stress test {stress_test_index} on device {device}."
    )
    # All interactions with the module must run in a single thread.
    # Execution happens in a separate thread in order to be able
    # to wait with a timeout on the inference operation.
    module_executor = ThreadPoolExecutor(1)
    try:
        amdshark_module = module_executor.submit(
            AMDSharkInference,
            mlir_module=bytes(),
            function_name=function_name,
            device=device,
        ).result()
        module_executor.submit(
            amdshark_module.load_module, amdshark_module_path
        ).result()
        input_batches = [np.repeat(arr, batch_size, axis=0) for arr in inputs]
        golden_output_batches = np.repeat(golden_out, batch_size, axis=0)
        report_interval_seconds = 10
        # Monotonic clock: durations must not be affected by wall-clock
        # adjustments.
        start_time = time.monotonic()
        previous_report_time = start_time
        first_iteration_output = None
        for i in range(max_iterations):
            output = module_executor.submit(
                amdshark_module.forward, input_batches
            ).result(inference_timeout_seconds)
            if first_iteration_output is None:
                np.testing.assert_array_almost_equal_nulp(
                    golden_output_batches, output, nulp=tolerance_nulp
                )
                first_iteration_output = output
            else:
                np.testing.assert_array_equal(output, first_iteration_output)
            current_time = time.monotonic()
            if report_interval_seconds < current_time - previous_report_time:
                logging.info(
                    f"Stress test {stress_test_index} on device "
                    f"{device} at iteration {i+1}"
                )
                previous_report_time = current_time
            if max_duration_seconds < current_time - start_time:
                # Time budget used up; fall through to the "done" message
                # (this used to return without logging it).
                break
    finally:
        # wait=False so a hung inference cannot block the cleanup after a
        # timeout; an idle worker thread is released right away.
        module_executor.shutdown(wait=False)
    logging.info(f"Stress test {stress_test_index} on device {device} done.")
|
||||
|
||||
|
||||
def get_device_type(device_name: str):
    """Return the part of `device_name` before the first "://".

    A name without "://" is returned unchanged.
    """
    device_type, _, _ = device_name.partition("://")
    return device_type
|
||||
|
||||
|
||||
def get_device_types(device_names: List[str]) -> List[str]:
    """Return the device type of every name in `device_names`, in order.

    The parameter was annotated `str`, but it is iterated as a collection of
    device names (iterating a single string would go character by character).
    """
    return [get_device_type(device_name) for device_name in device_names]
|
||||
|
||||
|
||||
def query_devices(device_types: Optional[List[str]] = None) -> List[str]:
    """List the URIs of the available devices for the given device types.

    Args:
        device_types: Keys of `_IREE_DEVICE_MAP`. When None, every driver
            reported by `query_available_drivers()` that has an entry in
            `IREE_TO_AMDSHARK_DRIVER_MAP` is used.

    Returns:
        URIs of the form "<device_type>://<path>", where the path falls back
        to the device id when the reported path is empty.
    """
    if device_types is None:
        device_types = [
            IREE_TO_AMDSHARK_DRIVER_MAP[driver_name]
            for driver_name in query_available_drivers()
            if driver_name in IREE_TO_AMDSHARK_DRIVER_MAP
        ]

    uris = []
    for device_type in device_types:
        driver = get_driver(_IREE_DEVICE_MAP[device_type])
        for info in driver.query_available_devices():
            uri_path = info["path"]
            if uri_path == "":
                uri_path = str(info["device_id"])
            uris.append(f"{device_type}://{uri_path}")
    return uris
|
||||
|
||||
|
||||
def compile_stress_test_module(
    device_types: List[str], mlir_model: str, func_name: str, mlir_dialect: str
) -> List[str]:
    """Build and save the stress-test module once per device type.

    Returns:
        The values returned by `save_module()`, in the order of
        `device_types`.
    """

    def build_and_save(device_type):
        logging.info(
            f"Compiling stress test model for device type {device_type}."
        )
        module = AMDSharkInference(
            mlir_model,
            func_name,
            mlir_dialect=mlir_dialect,
            device=device_type,
        )
        return module.save_module()

    return [build_and_save(device_type) for device_type in device_types]
|
||||
|
||||
|
||||
def stress_test(
    model_name: str,
    dynamic_model: bool = False,
    device_types: Optional[List[str]] = None,
    device_names: Optional[List[str]] = None,
    batch_size: int = 1,
    max_iterations: int = 10**7,
    max_duration_seconds: float = 3600,
    inference_timeout_seconds: float = 60,
    mlir_dialect: str = "linalg",
    frontend: str = "torch",
    oversubscription_factor: int = 1,
    tolerance_nulp: int = 50000,
):
    """Download a model, compile it per device type and stress every device.

    Devices are the given `device_names` plus, when `device_names` is None
    or `device_types` is given, everything `query_devices(device_types)`
    reports. One worker process is started per device and oversubscription
    slot; each runs `stress_test_compiled_model`.

    Args:
        model_name: Name passed to `download_model`.
        dynamic_model: Forwarded to `download_model` as `dynamic`.
        device_types: Device types whose devices are added to the run.
        device_names: Explicit device URIs. The list is not modified.
        batch_size, max_iterations, max_duration_seconds,
        inference_timeout_seconds, tolerance_nulp: Forwarded to
            `stress_test_compiled_model`.
        mlir_dialect: Dialect passed to the compile step.
        frontend: Forwarded to `download_model`.
        oversubscription_factor: Concurrent runs per device.

    Raises:
        ValueError: If no device is available to run on.
    """
    logging.info(f"Downloading stress test model {model_name}.")
    mlir_model, func_name, inputs, golden_out = download_model(
        model_name=model_name, dynamic=dynamic_model, frontend=frontend
    )

    if device_names is None or device_types is not None:
        # Work on a copy so the caller's list is not extended in place.
        device_names = [] if device_names is None else list(device_names)
        with ProcessPoolExecutor() as executor:
            # query_devices needs to run in a separate process,
            # because it will interfere with other processes that are forked later.
            device_names.extend(
                executor.submit(query_devices, device_types).result()
            )

    if not device_names:
        # Fail early with a clear message; this used to surface only after
        # compilation, as a ValueError from multiprocessing.Pool(0).
        raise ValueError("No devices available to run the stress test on.")

    # Unique device types, in first-seen order.
    device_types_set = list(dict.fromkeys(get_device_types(device_names)))
    with ProcessPoolExecutor() as executor:
        # This needs to run in a subprocess because when compiling for CUDA,
        # some stuff gets initialized and cuInit will fail in a forked process
        # later. It should be just compiling, but alas.
        amdshark_module_paths_set = executor.submit(
            compile_stress_test_module,
            device_types_set,
            mlir_model,
            func_name,
            mlir_dialect,
        ).result()
    device_type_amdshark_module_path_map = dict(
        zip(device_types_set, amdshark_module_paths_set)
    )
    device_name_amdshark_module_path_map = {
        device_name: device_type_amdshark_module_path_map[
            get_device_type(device_name)
        ]
        for device_name in device_names
    }

    # This needs to run in a separate process, because it uses the driver cache
    # in IREE and a subsequent call to `iree.runtime.SystemContext.add_vm_module`
    # in a forked process will hang.
    with multiprocessing.Pool(
        len(device_name_amdshark_module_path_map) * oversubscription_factor
    ) as process_pool:
        process_pool.starmap(
            stress_test_compiled_model,
            [
                (
                    module_path,
                    func_name,
                    device_name,
                    inputs,
                    golden_out,
                    batch_size,
                    max_iterations,
                    max_duration_seconds,
                    inference_timeout_seconds,
                    tolerance_nulp,
                    stress_test_index,
                )
                for stress_test_index, (device_name, module_path) in enumerate(
                    list(device_name_amdshark_module_path_map.items())
                    * oversubscription_factor
                )
            ],
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the stress test.
    logging.basicConfig(encoding="utf-8", level=logging.INFO)

    cli = argparse.ArgumentParser(
        description="Downloads, compiles and runs a model from the tank to stress test the system."
    )

    # Model selection.
    cli.add_argument(
        "--model", type=str, help="Model name in the tank.", default="alexnet"
    )
    cli.add_argument(
        "--dynamic",
        help="Use dynamic version of the model.",
        action="store_true",
        default=False,
    )
    cli.add_argument(
        "--frontend", type=str, help="Frontend of the model.", default="torch"
    )
    cli.add_argument(
        "--mlir-dialect",
        type=str,
        help="MLIR dialect of the model.",
        default="linalg",
        choices=supported_dialects,
    )

    # Device selection.
    cli.add_argument(
        "--device-types",
        type=str,
        nargs="*",
        choices=_IREE_DEVICE_MAP.keys(),
        help="Runs the stress test on all devices with that type. "
        "If absent and no deveices are specified "
        "will run against all available devices.",
    )
    cli.add_argument(
        "--devices",
        type=str,
        nargs="*",
        help="List of devices to run the stress test on. "
        "If device-types is specified will run against the union of the two.",
    )

    # Load and limits.
    cli.add_argument(
        "--batch-size",
        type=int,
        help="Number of inputs to feed into the model",
        default=1,
    )
    cli.add_argument(
        "--oversubscription",
        type=int,
        help="Oversubscrption factor. Each device will execute the model simultaneously "
        "this many number of times.",
        default=1,
    )
    cli.add_argument(
        "--max-iterations",
        type=int,
        help="Maximum number of iterations to run the stress test per device.",
        default=10**7,
    )
    cli.add_argument(
        "--max-duration",
        type=float,
        help="Maximum number of seconds to run the stress test.",
        default=3600,
    )
    cli.add_argument(
        "--inference-timeout",
        type=float,
        help="Timeout in seconds for a single model inference operation.",
        default=60,
    )
    cli.add_argument(
        "--tolerance-nulp",
        type=int,
        help="The maximum number of unit in the last place for tolerance "
        "when verifing results with the golden reference output.",
        default=50000,
    )

    # Unknown arguments are ignored rather than rejected.
    options, _ = cli.parse_known_args()
    stress_test_kwargs = {
        "model_name": options.model,
        "dynamic_model": options.dynamic,
        "frontend": options.frontend,
        "mlir_dialect": options.mlir_dialect,
        "device_types": options.device_types,
        "device_names": options.devices,
        "batch_size": options.batch_size,
        "oversubscription_factor": options.oversubscription,
        "max_iterations": options.max_iterations,
        "max_duration_seconds": options.max_duration,
        "inference_timeout_seconds": options.inference_timeout,
        "tolerance_nulp": options.tolerance_nulp,
    }
    stress_test(**stress_test_kwargs)
|
||||
@@ -1,144 +0,0 @@
|
||||
# RUN: %PYTHON %s
|
||||
import numpy as np
|
||||
from amdshark.amdshark_importer import AMDSharkImporter
|
||||
import pytest
|
||||
from amdshark.parser import amdshark_args
|
||||
from amdshark.amdshark_inference import AMDSharkInference
|
||||
from amdshark.tflite_utils import TFLitePreprocessor
|
||||
import sys
|
||||
|
||||
# model_path = "https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"
|
||||
|
||||
|
||||
# Inputs modified to be useful albert inputs.
|
||||
def generate_inputs(input_details):
    """Build three input arrays from the first three `input_details` entries.

    Every entry's shape and dtype name is printed first. The returned list
    holds, in order: random integers in [0, 256) for entry 0, all ones for
    entry 1 and all zeros for entry 2, each with that entry's "shape" and
    "dtype".
    """
    for detail in input_details:
        print(f'{detail["shape"]!s} {detail["dtype"].__name__}')

    random_ints = np.random.randint(
        low=0,
        high=256,
        size=input_details[0]["shape"],
        dtype=input_details[0]["dtype"],
    )
    ones = np.ones(
        shape=input_details[1]["shape"], dtype=input_details[1]["dtype"]
    )
    zeros = np.zeros(
        shape=input_details[2]["shape"], dtype=input_details[2]["dtype"]
    )
    return [random_ints, ones, zeros]
|
||||
|
||||
|
||||
def compare_results(mlir_results, tflite_results, details):
    """Check result counts and shapes and print the max error per output.

    Args:
        mlir_results: Sequence of arrays produced by the compiled module.
        tflite_results: Sequence of reference arrays.
        details: One entry per output to compare; only its length is used.

    Raises:
        AssertionError: If the number of results or any compared shape
            differs.
    """
    print("Compare mlir_results VS tflite_results: ")
    assert len(mlir_results) == len(
        tflite_results
    ), "Number of results do not match"
    for i, _ in enumerate(details):
        # Compare in single precision regardless of the original dtypes.
        mlir_result = mlir_results[i].astype(np.single)
        tflite_result = tflite_results[i].astype(np.single)
        assert mlir_result.shape == tflite_result.shape, "shape doesnot match"
        max_error = np.max(np.abs(mlir_result - tflite_result))
        # print() does not interpolate %-style arguments; format explicitly.
        print("Max error (%d): %f" % (i, max_error))
|
||||
|
||||
|
||||
class AlbertTfliteModuleTester:
    """Imports the albert_lite_base TFLite model and compares it to TFLite."""

    def __init__(
        self,
        dynamic=False,
        device="cpu",
        save_mlir=False,
        save_vmfb=False,
    ):
        """Store the test configuration.

        Args:
            dynamic: Stored on the instance; not read by this class.
            device: Device string handed to AMDSharkInference.
            save_mlir: Copied to ``amdshark_args.save_mlir`` when running.
            save_vmfb: Copied to ``amdshark_args.save_vmfb`` when running.
        """
        self.dynamic = dynamic
        self.device = device
        self.save_mlir = save_mlir
        self.save_vmfb = save_vmfb

    def _compile_module(self, mlir_model, func_name):
        """Create and compile an AMDSharkInference for the imported model."""
        amdshark_module = AMDSharkInference(
            mlir_module=mlir_model,
            function_name=func_name,
            device=self.device,
            mlir_dialect="tflite",
        )
        amdshark_module.compile()
        return amdshark_module

    def create_and_check_module(self):
        """Import the model, run it twice and compare against TFLite.

        Case 1 uses the inputs generated by the preprocessor; case 2 uses the
        inputs produced by ``generate_inputs``.
        """
        amdshark_args.save_mlir = self.save_mlir
        amdshark_args.save_vmfb = self.save_vmfb
        tflite_preprocessor = TFLitePreprocessor(model_name="albert_lite_base")

        raw_model_file_path = tflite_preprocessor.get_raw_model_file()
        inputs = tflite_preprocessor.get_inputs()
        tflite_interpreter = tflite_preprocessor.get_interpreter()

        my_amdshark_importer = AMDSharkImporter(
            module=tflite_interpreter,
            inputs=inputs,
            frontend="tflite",
            raw_model_file=raw_model_file_path,
        )
        mlir_model, func_name = my_amdshark_importer.import_mlir()

        # Case1: Use amdshark_importer default generate inputs
        amdshark_module = self._compile_module(mlir_model, func_name)
        mlir_results = amdshark_module.forward(inputs)
        # Post-process results for comparison: cast each output to the dtype
        # TFLite reports for it.
        _, output_details = tflite_preprocessor.get_model_details()
        mlir_results = list(mlir_results)
        for i in range(len(output_details)):
            dtype = output_details[i]["dtype"]
            mlir_results[i] = mlir_results[i].astype(dtype)
        tflite_results = tflite_preprocessor.get_golden_output()
        compare_results(mlir_results, tflite_results, output_details)

        # Case2: Use manually set inputs
        input_details, output_details = tflite_preprocessor.get_model_details()
        inputs = generate_inputs(input_details)  # new inputs
        # get_golden_output() feeds the preprocessor's stored inputs to
        # TFLite, so register the new inputs first; otherwise the reference
        # would be computed from the case-1 inputs.
        tflite_preprocessor.setup_inputs(inputs)

        amdshark_module = self._compile_module(mlir_model, func_name)
        mlir_results = amdshark_module.forward(inputs)
        tflite_results = tflite_preprocessor.get_golden_output()
        compare_results(mlir_results, tflite_results, output_details)
|
||||
|
||||
|
||||
# A specific case can be run by commenting different cases. Runs all the test
|
||||
# across cpu, gpu and vulkan according to available drivers.
|
||||
# (dynamic, device) combinations exercised by the test below.
pytest_param = pytest.mark.parametrize(
    ("dynamic", "device"),
    [
        pytest.param(False, "cpu"),
        # TODO: Language models are failing for dynamic case..
        pytest.param(True, "cpu", marks=pytest.mark.skip),
    ],
)


@pytest_param
@pytest.mark.xfail(
    sys.platform == "darwin", reason="known macos tflite install issue"
)
def test_albert(dynamic, device):
    """Run the ALBERT TFLite import-and-compare flow for one configuration."""
    AlbertTfliteModuleTester(
        dynamic=dynamic, device=device
    ).create_and_check_module()


if __name__ == "__main__":
    # Direct invocation runs only the static CPU configuration.
    test_albert(False, "cpu")
|
||||
@@ -1,31 +0,0 @@
|
||||
# Copyright 2022 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import subprocess
|
||||
import sys
|
||||
import importlib.util
|
||||
|
||||
|
||||
def test_stress_test():
    """Run the stress-test script for a single CPU iteration of squeezenet1_0.

    The script is located through its import spec and executed with the
    current interpreter; a non-zero exit status fails the test.
    """
    command = [
        sys.executable,
        importlib.util.find_spec("amdshark.stress_test").origin,
        "--model=squeezenet1_0",
        "--devices",
        "cpu",
        "--max-iterations=1",
    ]
    subprocess.check_call(command)
|
||||
@@ -1,62 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
export_settings,
|
||||
load_settings,
|
||||
all_gradio_labels,
|
||||
)
|
||||
|
||||
|
||||
class TestExportSettings(unittest.TestCase):
    """Tests for the txt2img ``export_settings`` / ``load_settings`` helpers."""

    @patch("builtins.open", new_callable=mock_open)
    @patch("json.dump")
    def test_export_settings(self, mock_json_dump, mock_file):
        """Values are dumped to ./ui/settings.json keyed by gradio label."""
        test_values = ["value1", "value2", "value3"]
        expected_output = {
            "txt2img": dict(zip(all_gradio_labels, test_values))
        }

        export_settings(*test_values)

        mock_file.assert_called_once_with("./ui/settings.json", "w")
        mock_json_dump.assert_called_once_with(
            expected_output, mock_file(), indent=4
        )

    @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
    @patch(
        "builtins.open",
        new_callable=mock_open,
        read_data='{"txt2img": {"some_setting": "some_value"}}',
    )
    def test_load_settings_file_exists(self, mock_file, mock_json_load):
        """Stored model and VAE values come back at positions 0 and 1."""
        stored = {
            "txt2img_custom_model": "custom_model_value",
            "custom_vae": "custom_vae_value",
        }
        mock_json_load.return_value = {"txt2img": stored}

        settings = load_settings()
        self.assertEqual(settings[0], "custom_model_value")
        self.assertEqual(settings[1], "custom_vae_value")

    @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
    @patch("builtins.open", side_effect=FileNotFoundError)
    def test_load_settings_file_not_found(self, mock_file, mock_json_load):
        """A missing settings file yields the default LoRA weights entry."""
        settings = load_settings()

        self.assertEqual(settings[4], "None")

    @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
    @patch("builtins.open", new_callable=mock_open, read_data="{}")
    def test_load_settings_key_error(self, mock_file, mock_json_load):
        """An empty settings document yields the default LoRA weights entry."""
        mock_json_load.return_value = {}

        settings = load_settings()
        self.assertEqual(settings[4], "None")
|
||||
@@ -1,208 +0,0 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import os
|
||||
import csv
|
||||
import urllib.request
|
||||
|
||||
|
||||
class TFLiteModelUtil:
    """Thin wrapper around a ``tf.lite.Interpreter`` for one model file."""

    def __init__(self, raw_model_file):
        """Remember the model path (as a string); nothing is loaded yet."""
        self.raw_model_file = str(raw_model_file)
        self.tflite_interpreter = None
        self.input_details = None
        self.output_details = None
        self.inputs = []

    def setup_tflite_interpreter(self):
        """Create the interpreter, allocate tensors, return model details."""
        interpreter = tf.lite.Interpreter(model_path=self.raw_model_file)
        self.tflite_interpreter = interpreter
        interpreter.allocate_tensors()
        return self.get_model_details()

    def get_model_details(self):
        """Refresh and return ``(input_details, output_details)``."""
        print("Get tflite input output details")
        interpreter = self.tflite_interpreter
        self.input_details = interpreter.get_input_details()
        self.output_details = interpreter.get_output_details()
        return self.input_details, self.output_details

    def invoke_tflite(self, inputs):
        """Run the interpreter on ``inputs`` and return its outputs.

        Each input is bound to the tensor index recorded at the same position
        in ``input_details``. The result is a list of numpy arrays, one per
        entry in ``output_details``, cast to that entry's dtype.
        """
        self.inputs = inputs
        print("invoke_tflite")
        interpreter = self.tflite_interpreter
        for position, tensor in enumerate(self.inputs):
            interpreter.set_tensor(
                self.input_details[position]["index"], tensor
            )
        interpreter.invoke()

        # Collect every raw output first, then cast each to its declared dtype
        # so the results can be compared with the MLIR results.
        raw_outputs = [
            np.array(interpreter.get_tensor(detail["index"]))
            for detail in self.output_details
        ]
        return [
            raw.astype(detail["dtype"])
            for raw, detail in zip(raw_outputs, self.output_details)
        ]
|
||||
|
||||
|
||||
class TFLitePreprocessor:
    """Locates (or downloads) a TFLite model and prepares inputs for it."""

    def __init__(
        self,
        model_name,
        input_details=None,
        output_details=None,
        model_path=None,
    ):
        """Resolve the model file, set up the interpreter and default inputs.

        Args:
            model_name: Name used for the working sub-directory, the file
                names, and the lookup in ``tflite_model_list.csv``.
            input_details: Optional pre-computed TFLite input details.
            output_details: Optional pre-computed TFLite output details.
            model_path: Optional URL to download the model from.

        On a missing name/path or a failed model load, an error is printed
        and construction stops early, leaving the remaining setup undone.
        """
        self.model_name = model_name
        # used for tflite, optional for tf/pytorch
        self.input_details = input_details
        # used for tflite, optional for tf/pytorch
        self.output_details = output_details
        self.inputs = []
        self.model_path = model_path  # url to download the model
        self.raw_model_file = None  # local path of the raw model file
        self.mlir_file = None  # local path of the model's .mlir file
        self.mlir_model = None  # read of .mlir file
        # the raw tf/pytorch/tflite_output_tensor, not mlir_tensor
        self.output_tensor = None
        # could be tflite/tf/torch_interpreter in utils
        self.interpreter = None
        self.input_file = None
        self.output_file = None

        # create tmp model file directory
        if self.model_path is None and self.model_name is None:
            print(
                "Error. No model_path, No model name,Please input either one."
            )
            return

        print("Setting up for TMP_WORK_DIR")
        self.workdir = os.path.join(
            os.path.dirname(__file__), "./../gen_amdshark_tank"
        )
        os.makedirs(self.workdir, exist_ok=True)
        print(f"TMP_WORK_DIR = {self.workdir}")

        # compile and run tfhub tflite
        load_model_success = self.load_tflite_model()
        if not load_model_success:
            print("Error, load tflite model fail")
            return

        if (self.input_details is None) or (self.output_details is None):
            self.setup_interpreter()

        inputs = self.generate_inputs(self.input_details)  # device_inputs
        self.setup_inputs(inputs)

    def load_tflite_model(self):
        """Compute the model's local paths and download it when missing.

        Returns:
            True when the ``.tflite`` file already exists locally or was
            downloaded; False when no download URL could be determined.
        """
        # use model name get dir.
        tflite_model_name_dir = os.path.join(
            self.workdir, str(self.model_name)
        )

        os.makedirs(tflite_model_name_dir, exist_ok=True)
        print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")

        self.raw_model_file = "/".join(
            [tflite_model_name_dir, str(self.model_name) + "_tflite.tflite"]
        )
        self.mlir_file = "/".join(
            [tflite_model_name_dir, str(self.model_name) + "_tflite.mlir"]
        )
        self.input_file = "/".join([tflite_model_name_dir, "inputs"])
        self.output_file = "/".join([tflite_model_name_dir, "golden_out"])

        if os.path.exists(self.raw_model_file):
            print(
                "Local address for .tflite model file Exists: ",
                self.raw_model_file,
            )
        else:
            print("No local tflite file, Download tflite model")
            if self.model_path is None:
                # get model file from tflite_model_list.csv or download from gs://bucket
                print("No model_path, get from tflite_model_list.csv")
                tflite_model_list_path = os.path.join(
                    os.path.dirname(__file__),
                    "../tank/tflite/tflite_model_list.csv",
                )
                # Use a context manager so the CSV handle is closed once the
                # lookup finishes instead of being left to the GC.
                with open(tflite_model_list_path, newline="") as list_file:
                    for row in csv.reader(list_file):
                        if str(row[0]) == str(self.model_name):
                            self.model_path = row[1]
                            print("tflite_model_name", str(row[0]))
                            print("tflite_model_link", self.model_path)
            if self.model_path is None:
                print("Error, No model path find in tflite_model_list.csv")
                return False
            urllib.request.urlretrieve(self.model_path, self.raw_model_file)
        return True

    def setup_interpreter(self):
        """Create the TFLite helper and cache the model's I/O details."""
        self.interpreter = TFLiteModelUtil(self.raw_model_file)
        (
            self.input_details,
            self.output_details,
        ) = self.interpreter.setup_tflite_interpreter()

    def generate_inputs(self, input_details):
        """Store and return an all-ones array for every input detail."""
        self.inputs = []
        for tmp_input in input_details:
            print(
                "input_details shape:",
                str(tmp_input["shape"]),
                " type:",
                tmp_input["dtype"].__name__,
            )
            self.inputs.append(
                np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"])
            )
        return self.inputs

    def setup_inputs(self, inputs):
        """Replace the stored inputs used by ``get_golden_output``."""
        self.inputs = inputs

    def get_mlir_model(self):
        """Return the stored MLIR model (None unless set elsewhere)."""
        return self.mlir_model

    def get_mlir_file(self):
        """Return the local path computed for the model's ``.mlir`` file."""
        return self.mlir_file

    def get_inputs(self):
        """Return the currently stored inputs."""
        return self.inputs

    def get_golden_output(self):
        """Run TFLite on the stored inputs, save and return the outputs.

        The outputs are also written with ``np.savez`` to the ``golden_out``
        path inside the model's working directory.
        """
        self.output_tensor = self.interpreter.invoke_tflite(self.inputs)
        np.savez(self.output_file, *self.output_tensor)
        return self.output_tensor

    def get_model_details(self):
        """Return ``(input_details, output_details)`` as currently stored."""
        return self.input_details, self.output_details

    def get_raw_model_file(self):
        """Return the local path of the ``.tflite`` model file."""
        return self.raw_model_file

    def get_interpreter(self):
        """Return the interpreter helper (None until it has been set up)."""
        return self.interpreter
|
||||
@@ -1,220 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
import contextlib
|
||||
import re
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from torch_mlir.eager_mode.ir_building import build_mlir_module
|
||||
from torch_mlir.eager_mode.torch_mlir_dispatch import (
|
||||
UnsupportedByTorchMlirEagerMode,
|
||||
normalize_args_kwargs,
|
||||
check_get_aliased_arg,
|
||||
)
|
||||
from torch_mlir.eager_mode import EAGER_MODE_DEBUG
|
||||
from torch_mlir.eager_mode.torch_mlir_tensor import (
|
||||
TorchMLIRTensor,
|
||||
check_requires_grad,
|
||||
make_wrapper_subclass_from_torch_tensor,
|
||||
make_bare_wrapper_subclass,
|
||||
UNSUPPORTED_OPS,
|
||||
no_dispatch,
|
||||
)
|
||||
from torch_mlir.eager_mode import torch_mlir_tensor
|
||||
from amdshark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
|
||||
|
||||
|
||||
# Relative / absolute tolerances for the op-by-op numeric comparison.
rtol = 1e-04
atol = 1e-05

# One IREE eager backend, created for "cpu", is shared by this module and
# installed as the backend of the torch_mlir_tensor module as well.
backend = EagerModeIREELinalgOnTensorsBackend("cpu")
torch_mlir_tensor.backend = backend
|
||||
|
||||
|
||||
class TorchMLIRLockstepTensor(TorchMLIRTensor):
    """TorchMLIRTensor variant that verifies every op against PyTorch eager.

    This class overrides the dispatching for TorchMLIRTensor to allow for an
    op-by-op numerical comparison between PyTorch and the Torch-MLIR -> IREE
    backend compilation pipeline. This only supports the IREE backend and
    focuses on op-by-op level verification.

    TODO: Extend this to do a cumulative trace with summary statistics at the
    end. Possibly requires a wrapper environment to store full trace info.
    """

    def __new__(cls, elem, **kwargs):
        """Wrap ``elem`` and attach its device-side counterpart as ``.elem``.

        Python scalars (int, float, bool) are returned unchanged.

        Raises:
            ValueError: If ``elem`` is of an unsupported type.
        """
        if kwargs.get("constructing_from_device_tensor", False):
            tensor_meta_data = backend.get_torch_metadata(elem, kwargs)
            r = make_bare_wrapper_subclass(
                cls=cls,
                size=tensor_meta_data.size,
                strides=tensor_meta_data.strides,
                storage_offset=tensor_meta_data.storage_offset,
                dtype=tensor_meta_data.dtype,
                layout=tensor_meta_data.layout,
                device=tensor_meta_data.device,
                requires_grad=tensor_meta_data.requires_grad,
            )
            r.elem = elem
        elif isinstance(elem, torch.nn.Parameter):
            r = make_wrapper_subclass_from_torch_tensor(
                cls, elem.data, **kwargs
            )
            # This is a hack to handle non-contiguous data through IREE-backend
            nt = elem.detach().data.numpy()
            if not nt.flags["C_CONTIGUOUS"]:
                nt = np.ascontiguousarray(nt, dtype=nt.dtype)
            r.elem = backend.transfer_from_torch_to_device(
                torch.from_numpy(nt)
            )
        elif isinstance(elem, torch.Tensor):
            r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs)
            # Ditto TODO: Find a better way to handle this
            nt = elem.numpy()
            if not nt.flags["C_CONTIGUOUS"]:
                nt = np.ascontiguousarray(nt, dtype=nt.dtype)
            r.elem = backend.transfer_from_torch_to_device(
                torch.from_numpy(nt)
            )
        # This branch handles the case when a python scalar is passed to some op
        # or is returned from some aten op, such as _local_scalar_dense.
        elif isinstance(elem, (int, float, bool)):
            return elem
        else:
            raise ValueError(f"Unknown element type: {type(elem)}")
        return r

    def __repr__(self):
        if self.grad_fn:
            return f"TorchMLIRLockstepTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})"
        else:
            return f"TorchMLIRLockstepTensor({self.elem}, backend={backend.__class__.__name__})"

    @classmethod
    def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
        """Run ``func`` through the backend and check it against PyTorch.

        This does essentially the same dispatch as TorchMLIRTensor but
        operates as if debug mode is enabled. The numeric verification
        happens after the Torch-MLIR result is obtained, by comparing it
        against the result of running the same op through PyTorch eager.
        A mismatch only emits a warning. Any other failure on the backend
        path emits warnings and falls back to PyTorch eager.
        """
        # kwargs is **-expanded below; a missing mapping must become empty.
        if kwargs is None:
            kwargs = {}
        requires_grad = check_requires_grad(*args, **kwargs)
        # Pre-bind so the fallback path can inspect it even when the failure
        # happened before the op name was determined.
        op_name = ""
        try:
            with no_dispatch():
                if hasattr(func, "op_name"):
                    op_name = func.op_name
                elif hasattr(func, "__name__"):
                    # Handle builtin_function_or_method.
                    op_name = func.__name__
                else:
                    raise RuntimeError(f"op {func} has no name")

                if UNSUPPORTED_OPS.match(op_name):
                    raise UnsupportedByTorchMlirEagerMode(op_name)

                if not hasattr(func, "_schema"):
                    raise RuntimeError(f"op {func} has no schema.")

                normalized_kwargs = normalize_args_kwargs(func, args, kwargs)

                if "layout" in normalized_kwargs and normalized_kwargs[
                    "layout"
                ] not in {0, None}:
                    raise UnsupportedByTorchMlirEagerMode(
                        f"{normalized_kwargs['layout']} layout not supported."
                    )
                if "memory_format" in normalized_kwargs and normalized_kwargs[
                    "memory_format"
                ] not in {0, None}:
                    raise UnsupportedByTorchMlirEagerMode(
                        f"{normalized_kwargs['memory_format']} memory format not supported."
                    )
                eager_module = build_mlir_module(func, normalized_kwargs)
            device_tensor_args = [
                kwarg.elem
                for _, kwarg in normalized_kwargs.items()
                if isinstance(kwarg, cls)
            ]
            assert len(eager_module.body.operations[0].arguments) == len(
                device_tensor_args
            ), "Number of parameters and number of arguments differs."
            op_mlir_backend_callable = backend.compile(eager_module)
            out = op_mlir_backend_callable(*device_tensor_args)
            out = tree_map(
                lambda x: cls(
                    x,
                    requires_grad=requires_grad,
                    constructing_from_device_tensor=True,
                ),
                out,
            )

            # Numeric verification; Value for comparison comes from PyTorch eager
            with no_dispatch():
                unwrapped_args = tree_map(cls.unwrap, args)
                unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
                if "_reshape_alias" in op_name:
                    native_out = torch.ops.aten.view(
                        unwrapped_args[0], unwrapped_args[1]
                    )
                else:
                    native_out = func(*unwrapped_args, **unwrapped_kwargs)

            # NOTE(review): taking .elem directly assumes a single tensor
            # result; other result structures end up in the fallback below.
            native_out = tree_map(
                lambda x: cls(x, requires_grad=requires_grad), native_out
            ).elem
            tmp_out = out.elem

            try:
                np.testing.assert_allclose(
                    native_out.to_host(),
                    tmp_out.to_host(),
                    rtol=rtol,
                    atol=atol,
                )
            except Exception as e:
                shaped_args = [
                    arg.shape if torch.is_tensor(arg) else arg
                    for arg in unwrapped_args
                ]
                # Report keyword values (as shapes for tensors) by name;
                # iterating the mapping directly would only list its keys.
                shaped_kwargs = {
                    name: value.shape if torch.is_tensor(value) else value
                    for name, value in unwrapped_kwargs.items()
                }
                warnings.warn(
                    f"Lockstep accuracy verification failed with error: *{str(e)}*; "
                    f"Dispatched function name: *{str(func)}*; "
                    f"Dispatched function args: *{str(shaped_args)}*; "
                    f"Dispatched function kwargs: *{str(shaped_kwargs)}*; "
                )
        except Exception as e:
            warnings.warn(traceback.format_exc())
            if isinstance(e, UnsupportedByTorchMlirEagerMode):
                warnings.warn(
                    f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager."
                )
            else:
                warnings.warn(
                    f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
                    f"Running through PyTorch eager"
                )

            with no_dispatch():
                unwrapped_args = tree_map(cls.unwrap, args)
                unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
                if "_reshape_alias" in op_name:
                    out = torch.ops.aten.view(
                        unwrapped_args[0], unwrapped_args[1]
                    )
                else:
                    out = func(*unwrapped_args, **unwrapped_kwargs)

            out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)

        maybe_aliased_arg_name = check_get_aliased_arg(func)
        if maybe_aliased_arg_name is not None:
            warnings.warn(
                f"Found aliased arg, but didn't copy tensor contents. This could lead to incorrect results for E2E model execution but doesn't affect the validity of the lockstep op verification."
            )
            # TODO: Find a way to handle argument aliasing for IREE backend
            # backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem)

        return out
|
||||
@@ -1,90 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from torch_mlir.ir import StringAttr
|
||||
import torch_mlir
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
import tempfile
|
||||
from amdshark.parser import amdshark_args
|
||||
import io
|
||||
|
||||
# Maps the ``mlir_type`` name accepted by get_torch_mlir_module to the
# corresponding torch_mlir output type.
mlir_type_mapping_dict = dict(
    linalg=torch_mlir.OutputType.LINALG_ON_TENSORS,
    stablehlo=torch_mlir.OutputType.STABLEHLO,
    tosa=torch_mlir.OutputType.TOSA,
)
|
||||
|
||||
|
||||
def get_module_name_for_asm_dump(module):
    """Return a name for ``module`` to use in an assembly dump.

    The value of the ``torch.debug_module_name`` attribute is used when the
    module's operation carries one; otherwise a fixed placeholder is returned.
    The name is not guaranteed to be unique.
    """
    attributes = module.operation.attributes
    if "torch.debug_module_name" not in attributes:
        return "UnnammedModule"
    debug_name = module.operation.attributes["torch.debug_module_name"]
    return StringAttr(debug_name).value
|
||||
|
||||
|
||||
def run_on_refbackend(torch_module, inputs):
    """Compile ``torch_module`` with the reference backend and run it.

    Every element of ``inputs`` is converted with ``.numpy()``, but only the
    first converted input is passed to the loaded module's ``forward``.
    """
    ref = refbackend.RefBackendLinalgOnTensorsBackend()
    loaded = ref.load(ref.compile(torch_module))
    numpy_inputs = [tensor.numpy() for tensor in inputs]
    first_input = numpy_inputs[0]
    return loaded.forward(first_input)
|
||||
|
||||
|
||||
# Creates dynamic dims for all dims.
|
||||
# TODO: Pass user specified dynamic dims.
|
||||
def create_dynamic_placeholders(inputs):
    """Return a tuple of placeholders with every axis marked dynamic.

    One ``torch_mlir.TensorPlaceholder`` is created per input, modelled on
    that input, with all of its axes listed as dynamic.
    """
    return tuple(
        torch_mlir.TensorPlaceholder.like(
            tensor, dynamic_axes=list(range(len(tensor.shape)))
        )
        for tensor in inputs
    )
|
||||
|
||||
|
||||
def get_torch_mlir_module(
    module,
    input: tuple,
    dynamic: bool,
    jit_trace: bool,
    return_str: bool = False,
    mlir_type: str = "linalg",
):
    """Compile ``module`` with torch_mlir and return the resulting MLIR.

    ``mlir_type`` selects the output type through ``mlir_type_mapping_dict``.
    With ``dynamic`` set, the inputs are replaced by fully dynamic
    placeholders. The result is the textual assembly when ``return_str`` is
    true, otherwise the module's bytecode as ``bytes``.

    Note: this sets ``tempfile.tempdir`` to the current directory.
    """
    if dynamic:
        input = create_dynamic_placeholders(input)
    ignore_traced_shapes = False
    if jit_trace:
        ignore_traced_shapes = True

    tempfile.tempdir = "."

    output_type = mlir_type_mapping_dict[mlir_type]
    compiled = torch_mlir.compile(
        module,
        input,
        output_type=output_type,
        use_tracing=jit_trace,
        ignore_traced_shapes=ignore_traced_shapes,
    )

    if return_str:
        return compiled.operation.get_asm()

    buffer = io.BytesIO()
    compiled.operation.write_bytecode(buffer)
    return buffer.getvalue()
|
||||
@@ -1,48 +0,0 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from apps.amdshark_studio.studio_imports import pathex, datas, hiddenimports
|
||||
|
||||
# No extra binaries are bundled and the archive is not encrypted.
binaries = []
block_cipher = None

# Analyse the web UI entry point; pathex, datas and hiddenimports come from
# apps.amdshark_studio.studio_imports.
a = Analysis(
    ['web/index.py'],
    pathex=pathex,
    binaries=binaries,
    datas=datas,
    hiddenimports=hiddenimports,
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
    module_collection_mode={
        'gradio': 'py',  # Collect gradio package as source .py files
    },
)

pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

# Console executable named 'nodai_amdshark_studio' with UPX disabled.
exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.zipfiles,
    a.datas,
    [],
    name='nodai_amdshark_studio',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=False,
    upx_exclude=[],
    runtime_tmpdir=None,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)
|
||||
@@ -1,107 +0,0 @@
|
||||
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
|
||||
import os
|
||||
import PIL
|
||||
import numpy as np
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
)
|
||||
from datetime import datetime
|
||||
from PIL import Image
|
||||
from gradio.components.image_editor import (
|
||||
EditorValue,
|
||||
)
|
||||
|
||||
|
||||
class control_adapter:
    """Placeholder for the control-adapter exporters.

    Both exporters are stubs that return None.
    """

    def __init__(
        self,
        model: str,
    ):
        # The ``model`` argument is currently ignored.
        self.model = None

    # The exporters take no ``self``; declaring them static keeps
    # ``control_adapter.export_*`` working as before and also makes them
    # callable on instances.
    @staticmethod
    def export_control_adapter_model(model_keyword):
        """Stub exporter; returns None."""
        return None

    @staticmethod
    def export_xl_control_adapter_model(model_keyword):
        """Stub exporter; returns None."""
        return None
|
||||
|
||||
|
||||
class preprocessors:
    """Placeholder for the controlnet preprocessor exporter.

    The exporter is a stub that returns None.
    """

    def __init__(
        self,
        model: str,
    ):
        # The ``model`` argument is currently ignored.
        self.model = None

    # Takes no ``self``; declaring it static keeps class-level access
    # unchanged and also makes it callable on instances.
    @staticmethod
    def export_controlnet_model(model_keyword):
        """Stub exporter; returns None."""
        return None
|
||||
|
||||
|
||||
# Initializer lookup for control adapters, keyed by model family and then by
# hint type.
control_adapter_map = {
    "sd15": {
        hint: {"initializer": control_adapter.export_control_adapter_model}
        for hint in ("canny", "openpose", "scribble", "zoedepth")
    },
    "sdxl": {
        "canny": {
            "initializer": control_adapter.export_xl_control_adapter_model
        },
    },
}

# Initializer lookup for preprocessor models, keyed by hint type.
preprocessor_model_map = {
    hint: {"initializer": preprocessors.export_controlnet_model}
    for hint in ("canny", "openpose", "scribble", "zoedepth")
}
|
||||
|
||||
|
||||
class PreprocessorModel:
    """Stand-in preprocessor whose ``compile`` and ``run`` are not implemented."""

    def __init__(self, hf_model_id, device="cpu"):
        """Record the model identifier and the target device."""
        self.device = device
        self.model = hf_model_id

    def compile(self):
        """Print a not-implemented notice and return None."""
        print("compile not implemented for preprocessor.")

    def run(self, inputs):
        """Print a not-implemented notice and hand ``inputs`` back untouched."""
        print("run not implemented for preprocessor.")
        return inputs
|
||||
|
||||
|
||||
def cnet_preview(model, input_image):
|
||||
curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
|
||||
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
|
||||
if not os.path.exists(control_imgs_path):
|
||||
os.mkdir(control_imgs_path)
|
||||
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
|
||||
match model:
|
||||
case "canny":
|
||||
canny = PreprocessorModel("canny")
|
||||
result = canny(
|
||||
np.array(input_image),
|
||||
100,
|
||||
200,
|
||||
)
|
||||
Image.fromarray(result).save(fp=img_dest)
|
||||
return result, img_dest
|
||||
case "openpose":
|
||||
openpose = PreprocessorModel("openpose")
|
||||
result = openpose(np.array(input_image))
|
||||
Image.fromarray(result[0]).save(fp=img_dest)
|
||||
return result, img_dest
|
||||
case "zoedepth":
|
||||
zoedepth = PreprocessorModel("ZoeDepth")
|
||||
result = zoedepth(np.array(input_image))
|
||||
Image.fromarray(result).save(fp=img_dest)
|
||||
return result, img_dest
|
||||
case "scribble":
|
||||
input_image.save(fp=img_dest)
|
||||
return input_image, img_dest
|
||||
case _:
|
||||
return None, None
|
||||
@@ -1,125 +0,0 @@
|
||||
import importlib
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import warnings
|
||||
import json
|
||||
from threading import Thread
|
||||
|
||||
from apps.amdshark_studio.modules.timer import startup_timer
|
||||
|
||||
from apps.amdshark_studio.web.utils.tmp_configs import (
|
||||
config_tmp,
|
||||
clear_tmp_mlir,
|
||||
clear_tmp_imgs,
|
||||
amdshark_tmp,
|
||||
)
|
||||
|
||||
|
||||
def imports():
    """Import the heavy dependencies, recording each step on the startup timer.

    Also silences selected torch / torchvision warnings and initializes the
    web UI's shared globals module.
    """
    import torch  # noqa: F401

    startup_timer.record("import torch")
    ignored_warnings = (
        (DeprecationWarning, "torch"),
        (UserWarning, "torchvision"),
        (UserWarning, "torch"),
    )
    for category, module_name in ignored_warnings:
        warnings.filterwarnings(
            action="ignore", category=category, module=module_name
        )

    import gradio  # noqa: F401

    startup_timer.record("import gradio")

    import apps.amdshark_studio.web.utils.globals as global_obj

    global_obj._init()
    startup_timer.record("initialize globals")

    from apps.amdshark_studio.modules import img_processing  # noqa: F401

    startup_timer.record("other imports")
|
||||
|
||||
|
||||
def initialize():
    """Prepare the process before the web UI starts.

    Installs the SIGINT handler, configures and clears the temporary image
    location, creates the custom model folders, and finally imports gradio.
    """
    configure_sigint_handler()

    # The temporary directory must be configured (and stale images removed)
    # before gradio is imported; otherwise gradio ignores the setup.
    config_tmp()
    # clear_tmp_mlir()
    clear_tmp_imgs()

    from apps.amdshark_studio.web.utils import file_utils

    # Create custom models folders if they don't exist
    file_utils.create_model_folders()

    import gradio  # noqa: F401

    # initialize_rest(reload_script_modules=False)
|
||||
|
||||
|
||||
def initialize_rest(*, reload_script_modules=False):
    """
    Placeholder hook meant to run from initialize() and when reloading the webui.

    Currently a no-op: the body holds nothing but this docstring, and the
    call in initialize() is commented out. ``reload_script_modules`` is
    accepted (keyword-only) but not used yet.
    """
    # Keep this for adding reload options to the webUI.
|
||||
|
||||
|
||||
def dumpstacks():
    """Write the current call stack of every thread to ``stack_dump.log``.

    The log is created (or overwritten) inside the ``amdshark_tmp`` directory.
    Each thread gets a ``# Thread: name(id)`` header followed by one entry per
    stack frame.
    """
    import threading
    import traceback

    # Map thread identifiers to readable names for the section headers.
    thread_names = {}
    for thread in threading.enumerate():
        thread_names[thread.ident] = thread.name

    report = []
    for thread_id, top_frame in sys._current_frames().items():
        label = thread_names.get(thread_id, "")
        report.append(f"\n# Thread: {label}({thread_id})")
        for entry in traceback.extract_stack(top_frame):
            filename, lineno, name, line = entry
            report.append(f"""File: "{filename}", line {lineno}, in {name}""")
            if line:
                report.append(" " + line.strip())

    log_path = os.path.join(amdshark_tmp, "stack_dump.log")
    with open(log_path, "w") as log_file:
        log_file.write("\n".join(report))
|
||||
|
||||
|
||||
def setup_middleware(app):
    """Reset ``app``'s middleware stack and rebuild it with GZip and CORS added.

    :param app: application object exposing ``middleware_stack``,
        ``add_middleware`` and ``build_middleware_stack`` (presumably a
        FastAPI/Starlette app -- confirm against the caller)
    """
    from starlette.middleware.gzip import GZipMiddleware

    app.middleware_stack = (
        None  # reset current middleware to allow modifying user provided list
    )
    # Responses below ``minimum_size`` bytes are left uncompressed.
    app.add_middleware(GZipMiddleware, minimum_size=1000)
    configure_cors_middleware(app)
    app.build_middleware_stack()  # rebuild middleware stack on-the-fly
|
||||
|
||||
|
||||
def configure_cors_middleware(app):
    """Register CORS middleware on ``app`` using the command-line options.

    All methods and headers are allowed and credentials are enabled. Allowed
    origins are only set when ``cmd_opts.api_accept_origin`` is truthy, in
    which case it is split on commas.

    :param app: application object providing ``add_middleware``
    """
    from starlette.middleware.cors import CORSMiddleware
    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts

    cors_kwargs = dict(
        allow_methods=["*"],
        allow_headers=["*"],
        allow_credentials=True,
    )
    accepted_origins = cmd_opts.api_accept_origin
    if accepted_origins:
        cors_kwargs["allow_origins"] = accepted_origins.split(",")

    app.add_middleware(CORSMiddleware, **cors_kwargs)
|
||||
|
||||
|
||||
def configure_sigint_handler():
    """Install a SIGINT handler that dumps thread stacks and exits at once."""

    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(signum, stack_frame):
        print(f"Interrupted with signal {signum} in {stack_frame}")

        # Record where every thread was before the process goes away.
        dumpstacks()

        # os._exit terminates immediately, without normal interpreter cleanup.
        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)
|
||||
@@ -1,475 +0,0 @@
|
||||
from turbine_models.custom_models import stateless_llama
|
||||
from turbine_models.model_runner import vmfbRunner
|
||||
from turbine_models.gen_external_params.gen_external_params import gen_external_params
|
||||
import time
|
||||
from amdshark.iree_utils.compile_utils import compile_module_to_flatbuffer
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_resource_path,
|
||||
get_checkpoints_path,
|
||||
)
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.api.utils import parse_device
|
||||
from urllib.request import urlopen
|
||||
import iree.runtime as ireert
|
||||
from itertools import chain
|
||||
import gc
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
# Registry of supported chat models, keyed by Hugging Face model id.
# Each entry provides:
#   initializer   - callable used to export the model (see LanguageModel.__init__)
#   hf_model_name - Hugging Face model id (matches the key)
#   compile_flags - model-specific compiler flags appended in LanguageModel.compile
#   stop_token    - token id at which generation stops
#   max_tokens    - default generation length limit for the model
#   system_prompt - system prompt text for the model
llm_model_map = {
    "meta-llama/Llama-2-7b-chat-hf": {
        "initializer": stateless_llama.export_transformer_model,
        "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
        "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
        "stop_token": 2,
        "max_tokens": 4096,
        "system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
    },
    "Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
        "initializer": stateless_llama.export_transformer_model,
        "hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
        "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
        "stop_token": 2,
        "max_tokens": 4096,
        "system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
    },
    "TinyPixel/small-llama2": {
        "initializer": stateless_llama.export_transformer_model,
        "hf_model_name": "TinyPixel/small-llama2",
        # Unlike the 7B entries, const-expr hoisting is enabled for this model.
        "compile_flags": ["--iree-opt-const-expr-hoisting=True"],
        "stop_token": 2,
        "max_tokens": 1024,
        "system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
    },
}
|
||||
|
||||
# Markers that open and close a user instruction in the prompt format.
B_INST = "[INST]"
E_INST = "[/INST]"
# Sequence start / end markers.
B_SYS = "<s>"
E_SYS = "</s>"

# System prompt that is prepended to the first user turn of a conversation.
DEFAULT_CHAT_SYS_PROMPT = """<s>[INST] <<SYS>>
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <</SYS>>\n\n
"""


def append_user_prompt(history, input_prompt):
    """Return ``history`` extended with ``input_prompt`` wrapped in instruction markers.

    :param history: text accumulated so far
    :param input_prompt: user text to wrap as ``[INST] ... [/INST]``
    :return: the extended history
    """
    history += f"{B_INST} {input_prompt} {E_INST}"
    return history
|
||||
|
||||
|
||||
class LanguageModel:
    """Compile, load and run a chat LLM listed in ``llm_model_map``.

    Construction resolves file names for the exported IR, the compiled
    ``.vmfb`` and the optional external weights, then either loads an existing
    ``.vmfb`` or exports/compiles a new one. ``chat`` streams text from the
    compiled module; ``chat_hf`` streams text from the Hugging Face reference
    model for sanity checks.
    """

    # Sequence-step count above which the streaming KV cache is evicted.
    _KV_EVICT_THRESHOLD = 600

    def __init__(
        self,
        model_name,
        hf_auth_token=None,
        device=None,
        quantization="int4",
        precision="",
        external_weights=None,
        use_system_prompt=True,
        streaming_llm=False,
    ):
        """Resolve artifacts for ``model_name`` and load or build the runner.

        :param model_name: key into ``llm_model_map``
        :param hf_auth_token: Hugging Face token used for tokenizer/model access
        :param device: device string; text after ``=>`` is used and the part
            before ``://`` selects the backend. Must be a string: the default
            of ``None`` is not usable because the value is split and searched.
        :param quantization: quantization tag; the string ``"None"`` disables
            the file-name suffix
        :param precision: accepted for interface compatibility; the effective
            precision is always ``"f32"`` on cpu and ``"f16"`` otherwise
        :param external_weights: ``"safetensors"`` or ``"gguf"`` to use an
            external weight file; anything else disables external weights
        :param use_system_prompt: prepend the default system prompt to the
            first user turn
        :param streaming_llm: use the streaming state-update module
        """
        _, _, self.triple = parse_device(device)
        self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
        self.device = device.split("=>")[-1].strip()
        self.backend = self.device.split("://")[0]
        self.driver = self.backend
        if "cpu" in device:
            self.device = "cpu"
            self.backend = "llvm-cpu"
            self.driver = "local-task"

        print(f"Selected {self.backend} as IREE target backend.")
        self.precision = "f32" if "cpu" in device else "f16"
        self.quantization = quantization
        self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_")
        self.external_weight_file = None
        # TODO: find a programmatic solution for model arch spec instead of hardcoding llama2
        self.file_spec = "_".join(
            [
                self.safe_name,
                self.precision,
            ]
        )
        if self.quantization != "None":
            self.file_spec += "_" + self.quantization

        if external_weights in ["safetensors", "gguf"]:
            # Always define the attribute so both branches leave the same state.
            self.external_weights = external_weights
            self.external_weight_file = get_resource_path(
                os.path.join("..", self.file_spec + "." + external_weights)
            )
        else:
            self.external_weights = None
            self.external_weight_file = None

        if streaming_llm:
            # Add streaming suffix to file spec after setting external weights filename.
            self.file_spec += "_streaming"
        self.streaming_llm = streaming_llm

        self.tempfile_name = get_resource_path(
            os.path.join("..", f"{self.file_spec}.tempfile")
        )
        # TODO: Tag vmfb with target triple of device instead of HAL backend
        self.vmfb_name = str(
            get_resource_path(
                os.path.join("..", f"{self.file_spec}_{self.backend}.vmfb.tempfile")
            )
        )

        self.max_tokens = llm_model_map[model_name]["max_tokens"]
        # Cached once so the generation loops do not repeat the map lookup.
        self.stop_token = llm_model_map[self.hf_model_name]["stop_token"]
        self.iree_module_dict = None
        self.use_system_prompt = use_system_prompt
        self.global_iter = 0
        self.prev_token_len = 0
        self.first_input = True
        self.hf_auth_token = hf_auth_token
        if self.external_weight_file is not None:
            if not os.path.exists(self.external_weight_file):
                print(
                    f"External weight file {self.external_weight_file} does not exist. Generating..."
                )
                gen_external_params(
                    hf_model_name=self.hf_model_name,
                    quantization=self.quantization,
                    weight_path=self.external_weight_file,
                    hf_auth_token=hf_auth_token,
                    precision=self.precision,
                )
            else:
                print(
                    f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
                )
            self.external_weight_file = str(self.external_weight_file)

        if os.path.exists(self.vmfb_name) and (
            external_weights is None or os.path.exists(str(self.external_weight_file))
        ):
            # A compiled module (and its weights, if any) already exists.
            self._load_runner()
            self.tokenizer = self._load_tokenizer()
        elif not os.path.exists(self.tempfile_name):
            # Nothing cached: export the IR, persist it, then compile it.
            self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
                "initializer"
            ](
                self.hf_model_name,
                hf_auth_token,
                compile_to="torch",
                external_weights=external_weights,
                precision=self.precision,
                quantization=self.quantization,
                streaming_llm=self.streaming_llm,
                decomp_attn=True,
            )
            with open(self.tempfile_name, "w+") as f:
                f.write(self.torch_ir)
            # Drop the in-memory IR before compiling to keep peak memory down.
            del self.torch_ir
            gc.collect()
            self.compile()
        else:
            # Exported IR is cached on disk; only compilation is needed.
            self.tokenizer = self._load_tokenizer()
            self.compile()
        # Reserved for running HF torch model as reference.
        self.hf_mod = None

    def _load_tokenizer(self):
        """Return the slow (``use_fast=False``) tokenizer for the model."""
        return AutoTokenizer.from_pretrained(
            self.hf_model_name,
            use_fast=False,
            use_auth_token=self.hf_auth_token,
        )

    def _load_runner(self):
        """Create the vmfb runner and select the state-update module."""
        self.runner = vmfbRunner(
            device=self.driver,
            vmfb_path=self.vmfb_name,
            external_weight_path=self.external_weight_file,
        )
        if self.streaming_llm:
            self.model = self.runner.ctx.modules.streaming_state_update
        else:
            self.model = self.runner.ctx.modules.state_update

    def _evict_cache_if_needed(self):
        """In streaming mode, evict KV-cache space once the step count is high."""
        if (
            self.streaming_llm
            and self.model["get_seq_step"]() > self._KV_EVICT_THRESHOLD
        ):
            print("Evicting cache space!")
            self.model["evict_kvcache_space"]()

    def compile(self) -> None:
        """Compile the exported IR to ``self.vmfb_name`` and load the result."""
        # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
        # ONLY architecture/api-specific compile-time flags for each backend, if needed.
        # hf_model_id-specific global flags currently in model map.
        flags = []
        if "cpu" in self.backend:
            flags.extend(
                [
                    "--iree-global-opt-enable-quantized-matmul-reassociation",
                ]
            )
        elif self.backend == "vulkan":
            flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
        elif self.backend == "rocm":
            flags.extend(
                [
                    "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
                    "--iree-llvmgpu-enable-prefetch=true",
                    "--iree-opt-outer-dim-concat=true",
                    "--iree-flow-enable-aggressive-fusion",
                ]
            )
            if "gfx9" in self.triple:
                flags.extend(
                    [
                        f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
                        "--iree-codegen-llvmgpu-use-vector-distribution=true",
                    ]
                )
        flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
        # The compiled module is written to ``self.vmfb_name``; the returned
        # blob is not needed here.
        compile_module_to_flatbuffer(
            self.tempfile_name,
            device=self.device,
            frontend="auto",
            model_config_path=None,
            extra_args=flags,
            write_to=self.vmfb_name,
        )
        self._load_runner()

    def sanitize_prompt(self, prompt):
        """Flatten ``prompt`` to one line and wrap it in instruction markers.

        A list is flattened one level and its string items are joined with
        spaces. Newlines, tabs and carriage returns become spaces. On the
        first turn (``global_iter == 0``) with ``use_system_prompt`` set, the
        default system prompt is prepended.
        """
        if isinstance(prompt, list):
            prompt = list(chain.from_iterable(prompt))
            prompt = " ".join([x for x in prompt if isinstance(x, str)])
        prompt = prompt.replace("\n", " ")
        prompt = prompt.replace("\t", " ")
        prompt = prompt.replace("\r", " ")
        if self.use_system_prompt and self.global_iter == 0:
            prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt)
            return prompt
        else:
            return f"{B_INST} {prompt} {E_INST}"

    def chat(self, prompt):
        """Stream a reply to ``prompt`` from the compiled module.

        Yields ``(decoded_text_so_far, seconds_for_last_step)`` after every
        decode step. The generator's return value is the final
        ``(text, seconds)`` pair. Generation stops at the model's stop token
        or once ``self.max_tokens`` tokens have been produced.
        """
        prompt = self.sanitize_prompt(prompt)

        input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids

        def format_out(results):
            # Convert the device result to a plain Python token id.
            return int(torch.tensor(results.to_host()[0][0]))

        history = []
        total_time = 0.0
        # A single prefill followed by the decode loop. Re-running the prefill
        # after the length limit was reached (as a surrounding loop would)
        # only wastes work and appends spurious tokens.
        if self.max_tokens > 0:
            if self.streaming_llm:
                token_slice = max(self.prev_token_len - 1, 0)
                input_tensor = input_tensor[:, token_slice:]
            self._evict_cache_if_needed()
            token_len = input_tensor.shape[-1]
            device_inputs = [
                ireert.asdevicearray(self.runner.config.device, input_tensor)
            ]
            st_time = time.time()
            if self.first_input or not self.streaming_llm:
                token = self.model["run_initialize"](*device_inputs)
                self.first_input = False
            else:
                token = self.model["run_cached_initialize"](*device_inputs)
            total_time = time.time() - st_time
            token_len += 1

            token_id = format_out(token)
            history.append(token_id)
            while token_id != self.stop_token and len(history) < self.max_tokens:
                dec_time = time.time()
                self._evict_cache_if_needed()
                token = self.model["run_forward"](token)
                token_id = format_out(token)
                history.append(token_id)
                total_time = time.time() - dec_time
                yield self.tokenizer.decode(history), total_time

            self.prev_token_len = token_len + len(history)

        result_output = self.tokenizer.decode(history)
        self.global_iter += 1
        return result_output, total_time

    # Reference HF model function for sanity checks.
    def chat_hf(self, prompt):
        """Stream a reply to ``prompt`` from the Hugging Face reference model.

        Mirrors ``chat``: yields ``(decoded_text_so_far, seconds_for_last_step)``
        after every decode step and returns the final pair. The prompt is
        always prefilled, so repeated calls work, and every generated token is
        recorded exactly once.
        """
        if self.hf_mod is None:
            self.hf_mod = AutoModelForCausalLM.from_pretrained(
                self.hf_model_name,
                torch_dtype=torch.float,
                token=self.hf_auth_token,
            )
        prompt = self.sanitize_prompt(prompt)

        input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
        history = []
        total_time = 0.0
        if self.max_tokens > 0:
            token_len = input_tensor.shape[-1]
            st_time = time.time()
            result = self.hf_mod(input_tensor)
            token = torch.argmax(result.logits[:, -1, :], dim=1)
            total_time = time.time() - st_time
            token_len += 1
            pkv = result.past_key_values
            # Kept from the previous implementation, which cleared the flag
            # after its first prefill.
            self.first_input = False

            history.append(int(token))
            while int(token) != self.stop_token and len(history) < self.max_tokens:
                dec_time = time.time()
                result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
                token = torch.argmax(result.logits[:, -1, :], dim=1)
                pkv = result.past_key_values
                # Record the newly produced token, not the one just fed in.
                history.append(int(token))
                total_time = time.time() - dec_time
                yield self.tokenizer.decode(history), total_time

            self.prev_token_len = token_len + len(history)

        result_output = self.tokenizer.decode(history)
        self.global_iter += 1
        return result_output, total_time
|
||||
|
||||
|
||||
def get_mfma_spec_path(target_chip, save_dir):
    """Return the path of the MFMA attention/matmul spec, downloading it once.

    The spec is cached as ``attention_and_matmul_spec_mfma.mlir`` inside
    ``save_dir``. The cache is checked first, so no network access happens
    when the file is already present.

    :param target_chip: target architecture string; currently unused and kept
        for interface compatibility
    :param save_dir: existing directory in which the spec is cached
    :return: path to the cached spec file
    """
    spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
    if os.path.exists(spec_path):
        return spec_path

    url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
    # Close the HTTP response deterministically once the body has been read.
    with urlopen(url) as response:
        attn_spec = response.read().decode("utf-8")
    # Written with the same encoding the payload was decoded with.
    with open(spec_path, "w", encoding="utf-8") as f:
        f.write(attn_spec)
    return spec_path
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
    """Answer a chat-completion or legacy completion request.

    :param InputData: request payload. ``"messages"`` selects the
        chat-completion API, otherwise ``"prompt"`` is required. Optional
        keys: ``"model"`` (default ``"meta-llama/Llama-2-7b-chat-hf"``),
        ``"device"`` (default ``"cpu"``) and ``"max_tokens"`` (default 4096).
    :return: response dict with ``id``, ``object``, ``created`` and ``choices``
    :raises KeyError: if the requested model is not in ``llm_model_map``
    """
    from datetime import datetime as dt

    import apps.amdshark_studio.web.utils.globals as global_obj

    print(f"Input keys : {InputData.keys()}")

    # print(f"model : {InputData['model']}")

    # Without "messages" this is the legacy `completion` api.
    is_chat_completion_api = "messages" in InputData

    # For Debugging input data from API
    if is_chat_completion_api:
        print(f"message -> role : {InputData['messages'][0]['role']}")
        print(f"message -> content : {InputData['messages'][0]['content']}")
    else:
        print(f"prompt : {InputData['prompt']}")

    model_name = InputData.get("model", "meta-llama/Llama-2-7b-chat-hf")
    # Reject unknown models up front, before any pipeline work is done.
    if model_name not in llm_model_map:
        raise KeyError(model_name)
    device = InputData.get("device", "cpu")
    max_tokens = InputData.get("max_tokens", 4096)

    llm_model = global_obj.get_llm_obj()
    if not llm_model:
        print("\n[LOG] Initializing new pipeline...")
        global_obj.clear_cache()
        gc.collect()
        # Reduce the requested device string to its backend name.
        if "cuda" in device:
            device = "cuda"
        elif "vulkan" in device:
            device = "vulkan"
        elif "cpu" in device:
            device = "cpu"
        else:
            print("unrecognized device")
        llm_model = LanguageModel(
            model_name=model_name,
            hf_auth_token=cmd_opts.hf_auth_token,
            device=device,
            quantization=cmd_opts.quantization,
            external_weights="safetensors",
            use_system_prompt=True,
            streaming_llm=False,
        )
        global_obj.set_llm_obj(llm_model)
    # NOTE(review): an existing pipeline is reused even when ``model_name``
    # differs from the model it was built for -- confirm this is intended.

    llm_model.max_tokens = max_tokens
    # TODO: add role dict for different models
    if is_chat_completion_api:
        # TODO: add functionality for multiple messages
        prompt = append_user_prompt(
            InputData["messages"][0]["role"], InputData["messages"][0]["content"]
        )
    else:
        prompt = InputData["prompt"]
    print("prompt = ", prompt)

    # Drain the stream; each chunk is the full text so far, so the last one is
    # the complete answer. Stays empty if the model yields nothing.
    res_op = ""
    for res_op, _ in llm_model.chat(prompt):
        pass

    if is_chat_completion_api:
        choices = [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": res_op,
                },
                "finish_reason": "stop",  # or length
            }
        ]
    else:
        choices = [
            {
                "text": res_op,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",  # or length
            }
        ]
    # NOTE(review): "created" is the digits of this timestamp string, not a
    # Unix epoch value; kept as-is for existing clients.
    end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
    return {
        "id": end_time,
        "object": "chat.completion" if is_chat_completion_api else "text_completion",
        "created": int(end_time),
        "choices": choices,
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: build the model on CPU and stream one reply.
    demo_model = LanguageModel(
        "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
        hf_auth_token=None,
        device="cpu-task",
        external_weights="safetensors",
    )

    print("model loaded")
    for streamed in demo_model.chat("hi, what are you?"):
        print(streamed)
|
||||
@@ -1,505 +0,0 @@
|
||||
import gc
|
||||
import torch
|
||||
import gradio as gr
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
import copy
|
||||
import importlib.util
|
||||
import sys
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from turbine_models.custom_models.sd_inference.sd_pipeline import AMDSharkSDPipeline
|
||||
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
|
||||
AMDSharkSDXLPipeline,
|
||||
)
|
||||
|
||||
|
||||
from apps.amdshark_studio.api.controlnet import control_adapter_map
|
||||
from apps.amdshark_studio.api.utils import parse_device
|
||||
from apps.amdshark_studio.web.utils.state import status_label
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
safe_name,
|
||||
get_resource_path,
|
||||
get_checkpoints_path,
|
||||
)
|
||||
|
||||
from apps.amdshark_studio.modules.img_processing import (
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
from apps.amdshark_studio.modules.ckpt_processing import (
|
||||
preprocessCKPT,
|
||||
save_irpa,
|
||||
)
|
||||
|
||||
# Submodel slots for a (non-XL) Stable Diffusion pipeline, all unset.
# StableDiffusion.prepare_pipe deep-copies the selected map before filling it.
EMPTY_SD_MAP = {
    "clip": None,
    "scheduler": None,
    "unet": None,
    "vae_decode": None,
}

# Submodel slots for an SDXL pipeline, all unset.
EMPTY_SDXL_MAP = {
    "prompt_encoder": None,
    "scheduled_unet": None,
    "vae_decode": None,
    "pipeline": None,
    "full_pipeline": None,
}

# Per-component compiler flag slots, all unset; passed to the pipeline
# constructor as ``ireec_flags``.
EMPTY_FLAGS = {
    "clip": None,
    "unet": None,
    "vae": None,
    "pipeline": None,
}
|
||||
|
||||
|
||||
def load_script(source, module_name):
    """
    reads file source and loads it as a module

    The module is registered in sys.modules before it is executed. If
    execution fails, the registration is rolled back (a previously registered
    module of the same name is restored) and the error is re-raised.

    :param source: file to load
    :param module_name: name of module to register in sys.modules
    :return: loaded module
    :raises ImportError: if no import spec/loader can be built for source
    """

    spec = importlib.util.spec_from_file_location(module_name, source)
    # spec_from_file_location returns None when it cannot pick a loader
    # (for example an unrecognized file suffix); fail with a clear message.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {source!r} as module {module_name!r}")
    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialized module importable by name.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise

    return module
|
||||
|
||||
|
||||
class StableDiffusion:
    # This class is responsible for executing image generation and creating
    # /managing a set of compiled modules to run Stable Diffusion. The init
    # aims to be as general as possible, and the class will infer and compile
    # a list of necessary modules or a combined "pipeline module" for a
    # specified job based on the inference task.

    def __init__(
        self,
        base_model_id,
        height: int,
        width: int,
        batch_size: int,
        steps: int,
        scheduler: str,
        precision: str,
        device: str,
        target_triple: str = None,
        custom_vae: str = None,
        num_loras: int = 0,
        import_ir: bool = True,
        is_controlled: bool = False,
        external_weights: str = "safetensors",
    ):
        """Select the pipeline class for ``base_model_id`` and instantiate it.

        A model id containing ``.py`` is loaded as a custom script from the
        ``scripts`` checkpoint folder; an id containing ``xl`` selects the
        SDXL pipeline; anything else selects the SD pipeline. The pipeline and
        weights directories are derived from the configuration, and the
        weights directory is created if missing. ``import_ir`` is accepted but
        not used by this constructor.
        """
        self.precision = precision
        self.compiled_pipeline = False
        self.base_model_id = base_model_id
        self.custom_vae = custom_vae
        self.is_sdxl = "xl" in self.base_model_id.lower()
        self.is_custom = ".py" in self.base_model_id.lower()
        if self.is_custom:
            custom_module = load_script(
                os.path.join(get_checkpoints_path("scripts"), self.base_model_id),
                "custom_pipeline",
            )
            self.turbine_pipe = custom_module.StudioPipeline
            self.model_map = custom_module.MODEL_MAP
        elif self.is_sdxl:
            self.turbine_pipe = AMDSharkSDXLPipeline
            self.model_map = EMPTY_SDXL_MAP
        else:
            self.turbine_pipe = AMDSharkSDPipeline
            self.model_map = EMPTY_SD_MAP
        max_length = 64
        target_backend, self.rt_device, triple = parse_device(device, target_triple)
        # The pipeline id encodes everything that requires a distinct build.
        pipe_id_list = [
            safe_name(base_model_id),
            str(batch_size),
            str(max_length),
            f"{str(height)}x{str(width)}",
            precision,
            triple,
        ]
        if num_loras > 0:
            pipe_id_list.append(str(num_loras) + "lora")
        if is_controlled:
            pipe_id_list.append("controlled")
        if custom_vae:
            pipe_id_list.append(custom_vae)
        self.pipe_id = "_".join(pipe_id_list)
        self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
        self.weights_path = Path(
            os.path.join(
                get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)
            )
        )
        # Race-free and tolerant of missing parent directories, unlike an
        # exists() check followed by mkdir().
        os.makedirs(self.weights_path, exist_ok=True)

        decomp_attn = True
        attn_spec = None
        if triple in ["gfx940", "gfx942", "gfx90a"]:
            decomp_attn = False
            attn_spec = "mfma"
        elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
            decomp_attn = False
            attn_spec = "wmma"
            if triple in ["gfx1103", "gfx1150"]:
                # external weights have issues on igpu
                external_weights = None
        elif target_backend == "llvm-cpu":
            decomp_attn = False

        self.sd_pipe = self.turbine_pipe(
            hf_model_name=base_model_id,
            scheduler_id=scheduler,
            height=height,
            width=width,
            precision=precision,
            max_length=max_length,
            batch_size=batch_size,
            num_inference_steps=steps,
            device=target_backend,
            iree_target_triple=triple,
            # Pass a copy so the pipeline can never mutate the shared
            # module-level template.
            ireec_flags=dict(EMPTY_FLAGS),
            attn_spec=attn_spec,
            decomp_attn=decomp_attn,
            pipeline_dir=self.pipeline_dir,
            external_weights_dir=self.weights_path,
            external_weights=external_weights,
            custom_vae=custom_vae,
        )
        print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
        gc.collect()

    def prepare_pipe(
        self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
    ):
        """Resolve weights for every submodel and load the pipeline.

        :param custom_weights: checkpoint file name under the model's
            ``checkpoints`` folder, or a falsy value for default weights
        :param adapters: accepted but not used here
        :param embeddings: accepted but not used here
        :param is_img2img: accepted but not used here; ``self.is_img2img`` is
            always set to False
        :param compiled_pipeline: request the compiled pipeline; forced to
            False for non-SDXL models
        """
        print("\n[LOG] Preparing pipeline...")
        self.is_img2img = False
        # Independent copies so filling one map never touches the template.
        mlirs = copy.deepcopy(self.model_map)
        vmfbs = copy.deepcopy(self.model_map)
        weights = copy.deepcopy(self.model_map)
        if not self.is_sdxl:
            compiled_pipeline = False
        self.compiled_pipeline = compiled_pipeline

        if custom_weights:
            custom_weights = os.path.join(
                get_checkpoints_path("checkpoints"),
                safe_name(self.base_model_id.split("/")[-1]),
                custom_weights,
            )
            diffusers_weights_path = preprocessCKPT(custom_weights, self.precision)
            for key in weights:
                if key in ["scheduled_unet", "unet"]:
                    unet_weights_path = os.path.join(
                        diffusers_weights_path,
                        "unet",
                        "diffusion_pytorch_model.safetensors",
                    )
                    weights[key] = save_irpa(unet_weights_path, "unet.")

                elif key in ["clip", "prompt_encoder"]:
                    if not self.is_sdxl:
                        sd1_path = os.path.join(
                            diffusers_weights_path, "text_encoder", "model.safetensors"
                        )
                        weights[key] = save_irpa(sd1_path, "text_encoder_model.")
                    else:
                        # SDXL uses two text encoders, saved separately.
                        clip_1_path = os.path.join(
                            diffusers_weights_path, "text_encoder", "model.safetensors"
                        )
                        clip_2_path = os.path.join(
                            diffusers_weights_path,
                            "text_encoder_2",
                            "model.safetensors",
                        )
                        weights[key] = [
                            save_irpa(clip_1_path, "text_encoder_model_1."),
                            save_irpa(clip_2_path, "text_encoder_model_2."),
                        ]

                elif key in ["vae_decode"] and weights[key] is None:
                    vae_weights_path = os.path.join(
                        diffusers_weights_path,
                        "vae",
                        "diffusion_pytorch_model.safetensors",
                    )
                    weights[key] = save_irpa(vae_weights_path, "vae.")

        vmfbs, weights = self.sd_pipe.check_prepared(
            mlirs, vmfbs, weights, interactive=False
        )
        print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
        self.sd_pipe.load_pipeline(
            vmfbs, weights, self.rt_device, self.compiled_pipeline
        )
        print(
            "\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
        )

    def generate_images(
        self,
        prompt,
        negative_prompt,
        image,
        strength,
        guidance_scale,
        seed,
        ondemand,
        resample_type,
        control_mode,
        hints,
    ):
        """Generate images for ``prompt`` with the prepared pipeline.

        Only ``prompt``, ``negative_prompt``, ``guidance_scale`` and ``seed``
        are forwarded (with a fixed third positional argument of ``1`` and
        ``return_imgs=True``); the remaining parameters are accepted but not
        used here.

        :return: whatever the underlying pipeline's ``generate_images`` returns
        """
        img = self.sd_pipe.generate_images(
            prompt,
            negative_prompt,
            1,
            guidance_scale,
            seed,
            return_imgs=True,
        )
        return img
|
||||
|
||||
|
||||
def amdshark_sd_fn_dict_input(
    sd_kwargs: dict,
):
    """Normalize and validate ``sd_kwargs``, then delegate to ``amdshark_sd_fn``.

    ``sd_kwargs`` is normalized in place: ``None``/``[]`` become ``None``, the
    string ``"None"`` becomes ``""`` and ``"seed"`` is converted with ``int``.
    This is a generator: on a validation failure a gradio warning is raised
    and the generator finishes without yielding; otherwise everything yielded
    by ``amdshark_sd_fn`` is passed through and its return value is returned.
    """
    print("\n[LOG] Submitting Request...")

    for name, value in sd_kwargs.items():
        if value in [None, []]:
            value = sd_kwargs[name] = None
        if value in ["None"]:
            value = sd_kwargs[name] = ""
        if name == "seed":
            sd_kwargs[name] = int(value)

    # TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
    # The first failing check decides the warning; later checks are not evaluated.
    if not sd_kwargs["device"]:
        problem = "No device specified. Please specify a device."
    elif sd_kwargs["height"] not in (512, 1024):
        problem = "Height must be 512 or 1024. This is a temporary limitation."
    elif sd_kwargs["height"] != sd_kwargs["width"]:
        problem = "Height and width must be the same. This is a temporary limitation."
    elif (
        sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo"
        and sd_kwargs["steps"] > 10
    ):
        problem = "Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended."
    elif (
        sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo"
        and sd_kwargs["guidance_scale"] > 3
    ):
        problem = "sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
    elif (
        sd_kwargs["target_triple"] == ""
        and parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == ""
    ):
        problem = "Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
    else:
        problem = None

    if problem is not None:
        gr.Warning(problem)
        return None, ""

    return (yield from amdshark_sd_fn(**sd_kwargs))
|
||||
|
||||
|
||||
def amdshark_sd_fn(
    prompt,
    negative_prompt,
    sd_init_image: list,
    height: int,
    width: int,
    steps: int,
    strength: float,
    guidance_scale: float,
    seed: list,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    base_model_id: str,
    custom_weights: str,
    custom_vae: str,
    precision: str,
    device: str,
    target_triple: str,
    ondemand: bool,
    compiled_pipeline: bool,
    resample_type: str,
    controlnets: dict,
    embeddings: dict,
):
    """Generate images with the cached Stable Diffusion pipeline, batch by batch.

    This is a generator. After every batch it yields
    ``(generated_imgs, status)``, where ``generated_imgs`` holds every image
    produced so far and ``status`` is the value of ``status_label(...)``.
    Its final return value is ``(generated_imgs, "")``.

    The pipeline object stored in ``global_obj`` is rebuilt only when the
    pipeline kwargs differ from the cached ones, and ``prepare_pipe`` is
    re-run only when the preparation kwargs differ from the cached ones.
    """
    # Snapshot of the arguments exactly as received; it is handed to
    # save_output_img with every image. Must stay the first statement so that
    # no other local names leak into it.
    sd_kwargs = locals()
    if not isinstance(sd_init_image, list):
        sd_init_image = [sd_init_image]
    # An empty list means "no init image" rather than an IndexError.
    is_img2img = bool(sd_init_image) and sd_init_image[0] is not None

    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
    import apps.amdshark_studio.web.utils.globals as global_obj

    adapters = {}
    is_controlled = False
    control_mode = None
    hints = []
    import_ir = True
    num_loras = sum(1 for name in embeddings if embeddings[name])

    if "model" in controlnets:
        # The adapter table is keyed by base-model family.
        adapter_family = (
            "stabilityai/stable-diffusion-xl-1.0"
            if "xl" in base_model_id.lower()
            else "runwayml/stable-diffusion-v1-5"
        )
        for i, model in enumerate(controlnets["model"]):
            adapters[f"control_adapter_{model}"] = {
                "hf_id": control_adapter_map[adapter_family][model],
                "strength": controlnets["strength"][i],
            }
            if model is not None:
                is_controlled = True
        control_mode = controlnets["control_mode"]
        # Was `hints.append[i]`, which subscripts the bound method and raises
        # TypeError as soon as a hint is present.
        hints.extend(controlnets["hint"])

    submit_pipe_kwargs = {
        "base_model_id": base_model_id,
        "height": height,
        "width": width,
        "batch_size": batch_size,
        "precision": precision,
        "device": device,
        "target_triple": target_triple,
        "custom_vae": custom_vae,
        "num_loras": num_loras,
        "import_ir": import_ir,
        "is_controlled": is_controlled,
        "steps": steps,
        "scheduler": scheduler,
    }
    submit_prep_kwargs = {
        "custom_weights": custom_weights,
        "adapters": adapters,
        "embeddings": embeddings,
        "is_img2img": is_img2img,
        "compiled_pipeline": compiled_pipeline,
    }
    submit_run_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": sd_init_image,
        "strength": strength,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "ondemand": ondemand,
        "resample_type": resample_type,
        "control_mode": control_mode,
        "hints": hints,
    }

    if (
        not global_obj.get_sd_obj()
        or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
    ):
        print("\n[LOG] Initializing new pipeline...")
        global_obj.clear_cache()
        gc.collect()

        # Initializes the pipeline and retrieves IR based on all
        # parameters that are static in the turbine output format,
        # which is currently MLIR in the torch dialect.
        sd_pipe = StableDiffusion(
            **submit_pipe_kwargs,
        )
        global_obj.set_sd_obj(sd_pipe)
        global_obj.set_pipe_kwargs(submit_pipe_kwargs)

    if (
        not global_obj.get_prep_kwargs()
        or global_obj.get_prep_kwargs() != submit_prep_kwargs
    ):
        global_obj.set_prep_kwargs(submit_prep_kwargs)
        global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)

    generated_imgs = []
    for current_batch in range(batch_count):
        out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
        if not isinstance(out_imgs, list):
            out_imgs = [out_imgs]
        # TODO: restore per-batch timing output and cancellation support.
        for batch in range(batch_size):
            # NOTE(review): `seed` is the caller-supplied value, not the
            # incremented per-batch seed kept in submit_run_kwargs.
            save_output_img(
                out_imgs[batch],
                seed,
                sd_kwargs,
            )
        generated_imgs.extend(out_imgs)
        # TODO: make seed changes over batch counts more configurable.
        submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
        yield generated_imgs, status_label(
            "Stable Diffusion", current_batch + 1, batch_count, batch_size
        )
    return (generated_imgs, "")
|
||||
|
||||
|
||||
def unload_sd():
    """Clear the objects cached in ``global_obj`` and run a garbage collection."""
    print("Unloading models.")
    from apps.amdshark_studio.web.utils import globals as global_obj

    global_obj.clear_cache()
    gc.collect()
|
||||
|
||||
|
||||
def cancel_sd():
    """Placeholder cancel hook: only prints a message and returns ``None``."""
    print("Inject call to cancel longer API calls.")
|
||||
|
||||
|
||||
def view_json_file(file_path):
    """Return the complete text content of the file at *file_path*.

    The file is decoded as UTF-8 (the encoding JSON documents are exchanged
    in) instead of the platform's locale encoding, so non-ASCII content reads
    the same on every operating system.
    """
    with open(file_path, "r", encoding="utf-8") as json_file:
        return json_file.read()
|
||||
|
||||
|
||||
def safe_name(name):
    """Return *name* with slashes, backslashes and dots turned into underscores."""
    unsafe_to_underscore = str.maketrans({"/": "_", "\\": "_", ".": "_"})
    return name.translate(unsafe_to_underscore)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
    import apps.amdshark_studio.web.utils.globals as global_obj

    global_obj._init()

    # Start from the default config file found under cmd_opts.config_dir ...
    default_config_path = get_resource_path(
        os.path.join(cmd_opts.config_dir, "default_sd_config.json")
    )
    sd_kwargs = json.loads(view_json_file(default_config_path))

    # ... then let every command-line option with a matching key override it.
    for arg in vars(cmd_opts):
        if arg not in sd_kwargs:
            continue
        sd_kwargs[arg] = getattr(cmd_opts, arg)

    # Drain the generator, printing each intermediate result.
    for i in amdshark_sd_fn_dict_input(sd_kwargs):
        print(i)
|
||||
@@ -1,389 +0,0 @@
|
||||
import numpy as np
|
||||
import json
|
||||
from random import (
|
||||
randint,
|
||||
seed as seed_random,
|
||||
getstate as random_getstate,
|
||||
setstate as random_setstate,
|
||||
)
|
||||
|
||||
from pathlib import Path
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
# TODO: migrate these utils to studio
|
||||
from amdshark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
|
||||
|
||||
def get_available_devices():
    """Return display strings ("<device name> => <driver path>") for every
    device that can be enumerated.

    Drivers are probed in the order rocm, cpu-sync, cpu-task, vulkan, metal,
    cuda, hip. A driver that cannot be queried contributes no entries.
    """

    def get_devices_by_name(driver_name):
        """List the devices of one driver, or an empty list if it fails."""
        from amdshark.iree_utils._common import iree_device_map

        device_list = []
        try:
            driver_name = iree_device_map(driver_name)
            device_list_dict = get_all_devices(driver_name)
            print(f"{driver_name} devices are available.")
        except Exception:
            # Best effort: a missing or broken driver only means "no devices
            # here". `Exception` (not a bare except) lets Ctrl-C and
            # SystemExit propagate.
            print(f"{driver_name} devices are not available.")
        else:
            cpu_name = get_cpu_info()["brand_raw"]
            for i, device in enumerate(device_list_dict):
                device_name = (
                    cpu_name if device["name"] == "default" else device["name"]
                )
                if "local" in driver_name:
                    device_list.append(
                        f"{device_name} => {driver_name.replace('local', 'cpu')}"
                    )
                elif len(device_list_dict) == 1:
                    # for drivers with single devices
                    # let the default device be selected without any indexing
                    device_list.append(f"{device_name} => {driver_name}")
                else:
                    device_list.append(f"{device_name} => {driver_name}://{i}")
        return device_list

    set_iree_runtime_flags()

    available_devices = []
    for driver in ("rocm", "cpu-sync", "cpu-task"):
        available_devices.extend(get_devices_by_name(driver))

    from amdshark.iree_utils.vulkan_utils import (
        get_all_vulkan_devices,
    )

    vulkan_devices = [
        f"{device.strip()} => vulkan://{index}"
        for index, device in enumerate(get_all_vulkan_devices())
    ]
    if vulkan_devices:
        print("vulkan devices are available.")
    available_devices.extend(vulkan_devices)

    for driver in ("metal", "cuda", "hip"):
        available_devices.extend(get_devices_by_name(driver))

    # Replace the generic "AMD Radeon(TM) Graphics" label with the text found
    # after "w/" in another entry that mentions "M Graphics".
    for idx, device_str in enumerate(available_devices):
        if "AMD Radeon(TM) Graphics =>" in device_str:
            igpu_id_candidates = [
                x.split("w/")[-1].split("=>")[0]
                for x in available_devices
                if "M Graphics" in x
            ]
            for igpu_name in igpu_id_candidates:
                if igpu_name:
                    available_devices[idx] = device_str.replace(
                        "AMD Radeon(TM) Graphics", igpu_name
                    )
                    break
    return available_devices
|
||||
|
||||
|
||||
def set_init_device_flags():
    """Normalize ``cmd_opts.device`` and fill in a missing target triple.

    For vulkan and metal the device string is resolved to a (name, path) pair
    and, when no triple/platform was supplied, one is looked up from the
    device name. For cuda and cpu the device string is collapsed to the bare
    driver name. Any other device string is left untouched.
    """
    requested = cmd_opts.device

    if "vulkan" in requested:
        # set runtime flags for vulkan.
        set_iree_runtime_flags()

        # set triple flag to avoid multiple calls to get_vulkan_triple_flag
        device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
        if not cmd_opts.iree_vulkan_target_triple:
            detected = get_vulkan_target_triple(device_name)
            if detected is not None:
                cmd_opts.iree_vulkan_target_triple = detected
        print(
            f"Found device {device_name}. Using target triple "
            f"{cmd_opts.iree_vulkan_target_triple}."
        )
        return

    if "cuda" in requested:
        cmd_opts.device = "cuda"
        return

    if "metal" in requested:
        device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
        if not cmd_opts.iree_metal_target_platform:
            from amdshark.iree_utils.metal_utils import get_metal_target_triple

            detected = get_metal_target_triple(device_name)
            if detected is not None:
                # Only the last "-"-separated component is stored.
                cmd_opts.iree_metal_target_platform = detected.split("-")[-1]
        print(
            f"Found device {device_name}. Using target triple "
            f"{cmd_opts.iree_metal_target_platform}."
        )
        return

    if "cpu" in requested:
        cmd_opts.device = "cpu"
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
    """Assemble the vulkan runtime flags from ``cmd_opts`` and install them."""
    # TODO: This function should be device-agnostic and piped properly
    # to general runtime driver init.
    runtime_flags = get_iree_vulkan_runtime_flags()

    if cmd_opts.enable_rgp:
        runtime_flags += ["--enable_rgp=true", "--vulkan_debug_utils=true"]

    if cmd_opts.device_allocator_heap_key:
        allocator_flag = (
            "--device_allocator=caching:device_local="
            f"{cmd_opts.device_allocator_heap_key}"
        )
        runtime_flags += [allocator_flag]

    set_iree_vulkan_runtime_flags(flags=runtime_flags)
|
||||
|
||||
|
||||
def parse_device(device_str, target_override=""):
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
clean_device_info,
|
||||
get_iree_target_triple,
|
||||
iree_target_map,
|
||||
)
|
||||
|
||||
rt_driver, device_id = clean_device_info(device_str)
|
||||
target_backend = iree_target_map(rt_driver)
|
||||
if device_id:
|
||||
rt_device = f"{rt_driver}://{device_id}"
|
||||
else:
|
||||
rt_device = rt_driver
|
||||
|
||||
if target_override:
|
||||
return target_backend, rt_device, target_override
|
||||
match target_backend:
|
||||
case "vulkan-spirv":
|
||||
triple = get_iree_target_triple(device_str)
|
||||
return target_backend, rt_device, triple
|
||||
case "rocm":
|
||||
triple = get_rocm_target_chip(device_str)
|
||||
return target_backend, rt_device, triple
|
||||
case "llvm-cpu":
|
||||
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
|
||||
|
||||
|
||||
def get_rocm_target_chip(device_str):
    """Return the gfx target for the first known marker found in *device_str*.

    Markers are tried in table order and matched as plain substrings.

    Raises:
        AssertionError: if no marker occurs in *device_str*.
    """
    # TODO: Use a data file to map device_str to target chip.
    known_chips = (
        ("6700", "gfx1031"),
        ("6800", "gfx1030"),
        ("6900", "gfx1030"),
        ("7900", "gfx1100"),
        ("MI300X", "gfx942"),
        ("MI300A", "gfx940"),
        ("MI210", "gfx90a"),
        ("MI250", "gfx90a"),
        ("MI100", "gfx908"),
        ("MI50", "gfx906"),
        ("MI60", "gfx906"),
        ("780M", "gfx1103"),
    )
    for marker, chip in known_chips:
        if marker in device_str:
            return chip
    raise AssertionError(
        f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/AMD-SHARK-Studio/issues."
    )
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
    """Return the devices reported by the IREE driver *driver_name*.

    The list returned by ``query_available_devices`` is sorted in place by
    each entry's "path" value, matching the order of iree's --list_devices.
    """
    from iree.runtime import get_driver

    devices = get_driver(driver_name).query_available_devices()
    devices.sort(key=lambda entry: entry["path"])
    return devices
|
||||
|
||||
|
||||
def get_device_mapping(driver, key_combination=3):
    """This method ensures consistent device ordering when choosing
    specific devices for execution

    Args:
        driver (str): execution driver (vulkan, cuda, rocm, etc)
        key_combination (int, optional): choice for mapping value for
            device name.
            1 : path
            2 : name
            3 : (name, path)
            Defaults to 3.

    Returns:
        dict: map to possible device names user can input mapped to desired
            combination of name/path. Empty when the driver reports no
            devices. Values are None for any other key_combination.
    """
    from amdshark.iree_utils._common import iree_device_map

    driver = iree_device_map(driver)
    device_list = get_all_devices(driver)
    device_map = {}

    # With no devices there is no default to map; returning an empty mapping
    # lets callers report "not a valid device" instead of hitting an
    # IndexError on device_list[0].
    if not device_list:
        return device_map

    def get_output_value(dev_dict):
        """Build the mapping value selected by key_combination."""
        device_path = f"{driver}://{dev_dict['path']}"
        if key_combination == 1:
            return device_path
        if key_combination == 2:
            return dev_dict["name"]
        if key_combination == 3:
            return dev_dict["name"], device_path
        return None

    # mapping driver name to default device (driver://0)
    device_map[f"{driver}"] = get_output_value(device_list[0])
    for i, device in enumerate(device_list):
        value = get_output_value(device)
        # mapping with index
        device_map[f"{driver}://{i}"] = value
        # mapping with full path
        device_map[f"{driver}://{device['path']}"] = value
    return device_map
|
||||
|
||||
|
||||
def get_opt_flags(model, precision="fp16"):
    """Build the list of IREE compiler flags for *model* from ``cmd_opts``.

    ``precision`` is accepted but not read by this function.
    """
    flags = []

    vulkan_triple = cmd_opts.iree_vulkan_target_triple
    if len(vulkan_triple) > 0:
        flags.append(f"-iree-vulkan-target-triple={vulkan_triple}")

    if "rocm" in cmd_opts.device:
        from amdshark.iree_utils.gpu_utils import get_iree_rocm_args

        flags.extend(get_iree_rocm_args())

    # The explicit `== False` comparisons are kept on purpose: an option set
    # to None must not be treated as "disabled".
    if cmd_opts.iree_constant_folding == False:
        flags.extend(
            [
                "--iree-opt-const-expr-hoisting=False",
                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
            ]
        )
    if cmd_opts.data_tiling == False:
        flags.append("--iree-opt-data-tiling=False")

    if "vae" not in model:
        # Due to lack of support for multi-reduce, we always collapse reduction
        # dims before dispatch formation right now.
        flags.append("--iree-flow-collapse-reduction-dims")
    return flags
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
    """Gives the appropriate device data (supported name/path) for user
    selected execution device

    Args:
        device (str): user-selected device string; the part before "://" is
            taken as the driver.
        key_combination (int, optional): choice for mapping value for
            device name.
            1 : path
            2 : name
            3 : (name, path)
            Defaults to 3.

    Raises:
        ValueError: if *device* is not a key of the driver's device mapping.

    Returns:
        str / tuple: returns the mapping str or tuple of mapping str for
            the device depending on key_combination value
    """
    # The statements that used to follow the return (a nested copy of the
    # device enumeration from get_available_devices) could never execute and
    # have been removed.
    driver = device.split("://")[0]
    device_map = get_device_mapping(driver, key_combination)
    try:
        device_mapping = device_map[device]
    except KeyError:
        raise ValueError(f"Device '{device}' is not a valid device.")
    return device_mapping
|
||||
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the
|
||||
# supported range (including -1)
|
||||
def sanitize_seed(seed: int | str):
|
||||
seed = int(seed)
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
return seed
|
||||
|
||||
|
||||
# take a seed expression in an input format and convert it to
|
||||
# a list of integers, where possible
|
||||
def parse_seed_input(seed_input: str | list | int):
|
||||
if isinstance(seed_input, str):
|
||||
try:
|
||||
seed_input = json.loads(seed_input)
|
||||
except (ValueError, TypeError):
|
||||
seed_input = None
|
||||
|
||||
if isinstance(seed_input, int):
|
||||
return [seed_input]
|
||||
|
||||
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
|
||||
return seed_input
|
||||
|
||||
raise TypeError(
|
||||
"Seed input must be an integer or an array of integers in JSON format"
|
||||
)
|
||||
@@ -1,145 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import requests
|
||||
import torch
|
||||
import safetensors
|
||||
from iree.turbine.aot.params import (
|
||||
ParameterArchiveBuilder,
|
||||
)
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
create_vae_diffusers_config,
|
||||
convert_ldm_vae_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"):
    """Return (and create) the diffusers output directory for a weights file.

    The directory is ``<weights dir>/diffusers/<weights stem>_<precision>``,
    made absolute and returned as a POSIX-style string.
    """
    weights_file = Path(custom_weights)
    subdir = os.path.join("diffusers", weights_file.stem + f"_{precision}")
    target_dir = weights_file.parent.absolute() / subdir
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.as_posix()
|
||||
|
||||
|
||||
def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False):
    """Convert an original checkpoint into a diffusers directory, once.

    Returns the diffusers directory path. If that directory already has any
    entry, the conversion is skipped.
    """
    path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision)

    # Any entry in the directory counts as "already converted".
    if any(Path(path_to_diffusers).iterdir()):
        print("Checkpoint already loaded at : ", path_to_diffusers)
        return path_to_diffusers

    print(
        "Diffusers' checkpoint will be identified here : ",
        path_to_diffusers,
    )
    is_safetensors_file = custom_weights.lower().endswith(".safetensors")
    # EMA weights usually yield higher quality images for inference but
    # non-EMA weights have been yielding better results in our case.
    # TODO: Add an option `--ema` (`--no-ema`) for users to specify if
    # they want to go for EMA weight extraction or not.
    use_ema_weights = False
    print("Loading diffusers' pipeline from original stable diffusion checkpoint")
    if is_inpaint:
        input_channels = 9
    else:
        input_channels = 4
    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path_or_dict=custom_weights,
        extract_ema=use_ema_weights,
        from_safetensors=is_safetensors_file,
        num_in_channels=input_channels,
    )
    if precision == "fp16":
        pipe.to(dtype=torch.float16)
    pipe.save_pretrained(path_to_diffusers)
    del pipe
    print("Loading complete")
    return path_to_diffusers
|
||||
|
||||
|
||||
def save_irpa(weights_path, prepend_str):
    """Repack a safetensors file as an .irpa parameter archive.

    Every tensor is added under ``prepend_str + <original key>``.

    Returns:
        The archive path: *weights_path* with a trailing ``.safetensors``
        replaced by ``.irpa``, or with ``.irpa`` appended when it has a
        different suffix.
    """
    # Import the submodule explicitly: `import safetensors` alone does not
    # guarantee that `safetensors.torch` is available as an attribute.
    from safetensors.torch import load_file

    weights = load_file(weights_path)
    archive = ParameterArchiveBuilder()
    for key, tensor in weights.items():
        archive.add_tensor(prepend_str + key, tensor)

    # Only touch the final suffix. The previous str.replace rewrote every
    # ".safetensors" occurrence (directory names included) and returned the
    # input path unchanged when the suffix was missing, overwriting the input.
    root, suffix = os.path.splitext(weights_path)
    if suffix == ".safetensors":
        irpa_file = root + ".irpa"
    else:
        irpa_file = weights_path + ".irpa"
    archive.save(irpa_file)
    return irpa_file
|
||||
|
||||
|
||||
def convert_original_vae(vae_checkpoint):
    """Convert an original-format VAE state dict to the diffusers layout.

    Each key is prefixed with ``first_stage_model.`` and the result is
    converted using the v1 inference config downloaded from the CompVis
    repository.

    Raises:
        requests.RequestException: if the config cannot be downloaded
            (network error, timeout, or non-success HTTP status).
    """
    vae_state_dict = {
        "first_stage_model." + key: vae_checkpoint.get(key)
        for key in list(vae_checkpoint.keys())
    }

    config_url = (
        "https://raw.githubusercontent.com/CompVis/stable-diffusion/"
        "main/configs/stable-diffusion/v1-inference.yaml"
    )
    # Bound the wait and fail on HTTP errors, so an error page is never
    # parsed as if it were the YAML config.
    response = requests.get(config_url, timeout=60)
    response.raise_for_status()
    original_config = OmegaConf.load(BytesIO(response.content))
    vae_config = create_vae_diffusers_config(original_config, image_size=512)

    return convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
|
||||
|
||||
|
||||
def process_custom_pipe_weights(custom_weights):
    """Resolve custom weights to ``(weights_file, diffusers_target_dir)``.

    A civitai API URL is downloaded first (unless already present); any other
    value must be a path ending in .ckpt or .safetensors.

    Returns:
        ``(custom_weights_params, custom_weights_tgt)``, or ``None`` when
        *custom_weights* is empty.

    Raises:
        AssertionError: if a local path has an unsupported extension.
    """
    if not custom_weights:
        return None

    if custom_weights.startswith("https://civitai.com/api/"):
        # download the checkpoint from civitai if we don't already have it
        # and act as if we were given the local file originally
        custom_weights_params = get_civitai_checkpoint(custom_weights)
    else:
        # Raised explicitly (same exception type as the former `assert`) so
        # the check is not stripped when Python runs with -O.
        if not custom_weights.lower().endswith((".ckpt", ".safetensors")):
            raise AssertionError(
                "checkpoint files supported can be any of [.ckpt, .safetensors] type"
            )
        custom_weights_params = custom_weights

    custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights_params)
    return custom_weights_params, custom_weights_tgt
|
||||
|
||||
|
||||
def get_civitai_checkpoint(url: str):
    """Download a civitai model into the model directory unless already there.

    The file name comes from the response's Content-Disposition header and
    the file is stored under ``<cwd>/<cmd_opts.model_dir or "models">``.

    Returns:
        The POSIX-style path of the local model file.

    Raises:
        requests.HTTPError: on a non-success HTTP status.
        ValueError: if the header does not yield a usable file name.
    """
    with requests.get(url, allow_redirects=True, stream=True) as response:
        response.raise_for_status()

        # civitai api returns the filename in the content disposition
        quoted_name = re.findall(
            '"([^"]*)"', response.headers["Content-Disposition"]
        )[0]
        # The header is server-controlled: keep only the last path component
        # so it cannot point outside the model directory.
        base_filename = Path(quoted_name.replace("\\", "/")).name
        if base_filename in ("", ".", ".."):
            raise ValueError(
                f"civitai response has no usable file name: {quoted_name!r}"
            )
        destination_path = Path.cwd() / (cmd_opts.model_dir or "models") / base_filename

        if destination_path.is_file():
            # we already have this model downloaded
            print(f"civitai model already downloaded to {destination_path}")
        else:
            # we don't have this model downloaded yet
            print(f"downloading civitai model from {url} to {destination_path}")
            destination_path.parent.mkdir(parents=True, exist_ok=True)

            # The length header may be absent; show an open-ended bar then.
            size = int(response.headers.get("content-length", 0)) or None

            # Download next to the target and rename at the end, so an
            # interrupted transfer is never mistaken for a finished model.
            partial_path = destination_path.with_name(base_filename + ".part")
            with open(partial_path, "wb") as f, tqdm(
                total=size, unit="iB", unit_scale=True
            ) as progress_bar:
                for chunk in response.iter_content(chunk_size=65536):
                    f.write(chunk)
                    progress_bar.update(len(chunk))
            partial_path.replace(destination_path)

    return destination_path.as_posix()
|
||||
@@ -1,185 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import json
|
||||
import safetensors
|
||||
from dataclasses import dataclass
|
||||
from safetensors.torch import load_file
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_checkpoint_pathfile,
|
||||
get_path_stem,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAweight:
|
||||
up: torch.tensor
|
||||
down: torch.tensor
|
||||
mid: torch.tensor
|
||||
alpha: torch.float32 = 1.0
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
    """Merge the LoRA weights stored in *use_lora* directly into *model*.

    Only keys that start with *splitting_prefix* and end with "up.weight" are
    considered. For each one the matching layer is located by walking the
    underscore-separated name through the model, and ``up @ down`` scaled by
    ``alpha * lora_strength`` is added to that layer's weight in place.

    Returns:
        The same *model* object, modified in place.
    """
    # NOTE(review): the non-safetensors path goes through torch.load, which
    # presumably unpickles the file -- confirm it is only used on trusted files.
    if ".safetensors" in use_lora:
        state_dict = load_file(use_lora)
    else:
        state_dict = torch.load(use_lora)

    # gather the weights from the LoRA in a more convenient form, assumes
    # everything will have an up.weight.
    lora_by_layer: dict[str, LoRAweight] = {}
    for key in state_dict:
        if not (key.startswith(splitting_prefix) and key.endswith("up.weight")):
            continue

        stem = key.split("up.weight")[0]
        weight_key = (
            stem.removesuffix(".lora_")
            .removesuffix("_lora_")
            .removesuffix(".lora_linear_layer.")
        )
        if weight_key in lora_by_layer:
            continue

        up_weight = state_dict[f"{stem}up.weight"]
        down_weight = state_dict[f"{stem}down.weight"]
        mid_weight = state_dict.get(f"{stem}mid.weight", None)
        alpha_key = f"{weight_key}.alpha"
        if alpha_key in state_dict:
            alpha = state_dict[alpha_key] / state_dict[f"{stem}up.weight"].shape[1]
        else:
            alpha = 1.0
        lora_by_layer[weight_key] = LoRAweight(
            up_weight, down_weight, mid_weight, alpha
        )

    def _merge_pointwise(up, down):
        """Multiply two 1x1-kernel weights as matrices; return a 4-D result."""
        up_2d = up.squeeze(3).squeeze(2).to(torch.float32)
        down_2d = down.squeeze(3).squeeze(2).to(torch.float32)
        return torch.mm(up_2d, down_2d).unsqueeze(2).unsqueeze(3)

    # Directly update weight in model

    # Mostly adaptions of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
    # and similar code in https://github.com/huggingface/diffusers/issues/3064

    # TODO: handle mid weights (how do they even work?)
    for key, lora_weight in lora_by_layer.items():
        target_layer = model
        name_parts = key.split(".")[0].split(splitting_prefix)[-1].split("_")

        # find the target layer: attribute names may themselves contain
        # underscores, so on a failed lookup the next part is glued on and
        # the lookup retried.
        attr_name = name_parts.pop(0)
        while len(name_parts) > -1:
            try:
                target_layer = target_layer.__getattr__(attr_name)
                if len(name_parts) > 0:
                    attr_name = name_parts.pop(0)
                elif len(name_parts) == 0:
                    break
            except Exception:
                if len(attr_name) > 0:
                    attr_name += "_" + name_parts.pop(0)
                else:
                    attr_name = name_parts.pop(0)

        weight = target_layer.weight.data
        scale = lora_weight.alpha * lora_strength
        if len(weight.size()) == 2:
            if len(lora_weight.up.shape) == 4:
                change = _merge_pointwise(lora_weight.up, lora_weight.down)
            else:
                change = torch.mm(lora_weight.up, lora_weight.down)
        elif lora_weight.down.size()[2:4] == (1, 1):
            change = _merge_pointwise(lora_weight.up, lora_weight.down)
        else:
            change = torch.nn.functional.conv2d(
                lora_weight.down.permute(1, 0, 2, 3),
                lora_weight.up,
            ).permute(1, 0, 2, 3)

        target_layer.weight.data += change * scale

    return model
|
||||
|
||||
|
||||
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
    """Apply the LoRA identified by *use_lora* to *unet*.

    When *use_lora* contains none of ".bin", ".safetensors" or ".pt" it is
    treated as a Hugging Face id and passed to ``unet.load_attn_procs``.
    Otherwise the file is loaded through ``load_attn_procs`` and, if that
    fails, merged directly with ``processLoRA``.

    Returns:
        The unet with the LoRA applied.
    """
    # Precedence matches the original if/elif chain: .bin, .safetensors, .pt.
    extensions = (".bin", ".safetensors", ".pt")
    matched_extension = next(
        (extension for extension in extensions if extension in use_lora), None
    )
    if matched_extension is None:
        # We assume if it is a HF ID with standalone LoRA weights.
        unet.load_attn_procs(use_lora)
        return unet

    main_file_name = get_path_stem(use_lora) + matched_extension

    try:
        dir_name = os.path.dirname(use_lora)
        unet.load_attn_procs(dir_name, weight_name=main_file_name)
        return unet
    except Exception:
        # Deliberate fallback for files load_attn_procs cannot read.
        # `Exception` (not a bare except) lets Ctrl-C and SystemExit through.
        return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
|
||||
|
||||
|
||||
def update_lora_weight(model, use_lora, model_name, lora_strength=1.0):
    """Apply a LoRA to *model*, choosing the strategy from *model_name*.

    Names containing "unet" go through ``update_lora_weight_for_unet``; every
    other model is merged with ``processLoRA`` using the "lora_te_" prefix.

    Returns:
        The updated model, or ``None`` if the non-unet merge failed.
    """
    if "unet" in model_name:
        return update_lora_weight_for_unet(model, use_lora, lora_strength)
    try:
        return processLoRA(model, use_lora, "lora_te_", lora_strength)
    except Exception as err:
        # Still best effort (callers get None), but the failure is now
        # reported instead of silently swallowed, and Ctrl-C / SystemExit
        # are no longer caught.
        print(f"Failed to apply LoRA '{use_lora}' to {model_name}: {err}")
        return None
|
||||
|
||||
|
||||
def get_lora_metadata(lora_filename):
    """Summarize the training metadata embedded in a LoRA safetensors file.

    Returns:
        ``None`` when the file carries no metadata, otherwise a dict with
        "model" (hash, name and base version joined by spaces) and
        "frequencies" (``(tag, ratio)`` pairs, highest ratio first). A ratio
        is the tag's count divided by the dataset's image count.
    """
    # get the metadata from the file
    filename = get_checkpoint_pathfile(lora_filename, "lora")
    with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
        metadata = f.metadata()

    # guard clause for if there isn't any metadata
    if not metadata:
        return None

    # metadata is a dictionary of strings, the values of the keys we're
    # interested in are actually json, and need to be loaded as such
    tag_frequencies = json.loads(metadata.get("ss_tag_frequency", "{}"))
    dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", "{}"))

    # gather the tag frequency information for all the datasets trained
    all_frequencies = {}
    for dataset, tags in tag_frequencies.items():
        frequencies = sorted(tags.items(), reverse=True, key=lambda x: x[1])
        if not frequencies:
            # A dataset without tags contributes nothing.
            continue

        # Image count for this dataset: the number listed in its dataset_dir
        # entry, or the highest tag count if that doesn't exist. The lookup
        # key used to be the builtin `dir`, so the listed number was never
        # actually found.
        img_count = (
            dataset_dirs.get(dataset, {}).get("img_count") or frequencies[0][1]
        )
        if not img_count:
            # Nothing sensible to divide by.
            continue

        # add the dataset frequencies to the overall frequencies replacing the
        # frequency counts on the tags with a percentage/ratio
        all_frequencies.update(
            (tag, count / img_count) for tag, count in frequencies
        )

    trained_model_id = " ".join(
        [
            metadata.get("ss_sd_model_hash", ""),
            metadata.get("ss_sd_model_name", ""),
            metadata.get("ss_base_model_version", ""),
        ]
    ).strip()

    # return all frequencies in all datasets, highest ratio first
    return {
        "model": trained_model_id,
        "frequencies": sorted(
            all_frequencies.items(), reverse=True, key=lambda x: x[1]
        ),
    }
|
||||
@@ -1,202 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from csv import DictWriter
|
||||
from PIL import Image, PngImagePlugin
|
||||
from pathlib import Path
|
||||
from datetime import datetime as dt
|
||||
from base64 import decode
|
||||
|
||||
|
||||
resamplers = {
|
||||
"Lanczos": Image.Resampling.LANCZOS,
|
||||
"Nearest Neighbor": Image.Resampling.NEAREST,
|
||||
"Bilinear": Image.Resampling.BILINEAR,
|
||||
"Bicubic": Image.Resampling.BICUBIC,
|
||||
"Hamming": Image.Resampling.HAMMING,
|
||||
"Box": Image.Resampling.BOX,
|
||||
}
|
||||
|
||||
resampler_list = resamplers.keys()
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed, extra_info=None):
    """Save a generated image together with a CSV row and a JSON sidecar.

    The image is written into today's sub-directory of the generated-images
    folder as ``<HHMMSS>_<prompt slice>_<seed>.<ext>``. The same directory
    receives an ``imgs_details.csv`` row and a ``<name>.json`` file, both
    holding the contents of ``extra_info``.

    Args:
        output_img: Image object exposing a Pillow-style ``save`` method.
        img_seed: Seed used for this image; becomes part of the file name
            and of the PNG metadata.
        extra_info: Dict of generation parameters. The keys ``prompt``,
            ``base_model_id``, ``custom_weights``, ``custom_vae`` and
            ``embeddings`` are read unconditionally; ``negative_prompt``,
            ``steps``, ``scheduler``, ``guidance_scale``, ``width`` and
            ``height`` are read when PNG metadata is written.

    Raises:
        KeyError: If ``extra_info`` lacks one of the keys listed above.
    """
    from apps.amdshark_studio.web.utils.file_utils import (
        get_generated_imgs_path,
        get_generated_imgs_todays_subdir,
    )
    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts

    if extra_info is None:
        extra_info = {}
    generated_imgs_path = Path(
        get_generated_imgs_path(), get_generated_imgs_todays_subdir()
    )
    generated_imgs_path.mkdir(parents=True, exist_ok=True)
    csv_path = Path(generated_imgs_path, "imgs_details.csv")

    # Only filesystem-safe characters of the first 15 prompt characters are
    # kept in the file name.
    prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15])
    out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"

    img_model = extra_info["base_model_id"]
    if extra_info["custom_weights"] not in [None, "None"]:
        img_model = Path(os.path.basename(extra_info["custom_weights"])).stem

    img_vae = None
    if extra_info["custom_vae"]:
        img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem

    img_loras = None
    if extra_info["embeddings"]:
        # Collect whole names. ``list += str`` would extend the list with
        # single characters and produce "m, y, l, o, r, a" after the join.
        # NOTE(review): every entry is derived from cmd_opts.use_lora rather
        # than from the iterated embedding - confirm this is intended.
        img_lora = []
        for _ in extra_info["embeddings"]:
            img_lora.append(Path(os.path.basename(cmd_opts.use_lora)).stem)
        img_loras = ", ".join(img_lora)

    if cmd_opts.output_img_format == "jpg":
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
        output_img.save(out_img_path, quality=95, subsampling=0)
    else:
        # Every non-jpg format (including unsupported ones) is saved as PNG.
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
        pngInfo = PngImagePlugin.PngInfo()

        if cmd_opts.write_metadata_to_png:
            # NOTE: a hires-fix specific size (cmd_opts.hiresfix_width x
            # cmd_opts.hiresfix_height) used to be considered here; the
            # requested width/height is what gets recorded now.
            png_size_text = f"{extra_info['width']}x{extra_info['height']}"

            pngInfo.add_text(
                "parameters",
                f"{extra_info['prompt'][0]}"
                f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
                f"\nSteps: {extra_info['steps']},"
                f"Sampler: {extra_info['scheduler']}, "
                f"CFG scale: {extra_info['guidance_scale']}, "
                f"Seed: {img_seed},"
                f"Size: {png_size_text}, "
                f"Model: {img_model}, "
                f"VAE: {img_vae}, "
                f"LoRA: {img_loras}",
            )

        output_img.save(out_img_path, "PNG", pnginfo=pngInfo)

        if cmd_opts.output_img_format not in ["png", "jpg"]:
            print(
                f"[ERROR] Format {cmd_opts.output_img_format} is not "
                f"supported yet. Image saved as png instead."
                f"Supported formats: png / jpg"
            )

    # To be as low-impact as possible to the existing CSV format, we append
    # "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
    # importance for each data point. Something to consider.
    new_entry = dict(extra_info)

    csv_mode = "a" if os.path.isfile(csv_path) else "w"
    # newline="" lets the csv module control line endings; without it extra
    # blank rows appear on platforms that translate "\n" to "\r\n".
    with open(csv_path, csv_mode, encoding="utf-8", newline="") as csv_obj:
        dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
        if csv_mode == "w":
            dictwriter_obj.writeheader()
        dictwriter_obj.writerow(new_entry)

    json_path = Path(generated_imgs_path, f"{out_img_name}.json")
    with open(json_path, "w") as f:
        json.dump(new_entry, f, indent=4)
|
||||
|
||||
|
||||
# For stencil, the input image can be of any size, but we need to ensure that
|
||||
# it conforms with our model constraints :-
|
||||
# Both width and height should be in the range of [128, 768] and multiple of 8.
|
||||
# This utility function performs the transformation on the input image while
|
||||
# also maintaining the aspect ratio before sending it to the stencil pipeline.
|
||||
def resize_stencil(image: Image.Image, width, height, resampler_type=None):
    """Resize ``image`` to model-friendly dimensions derived from width/height.

    The requested ``width``/``height`` are scaled (keeping their aspect
    ratio) so that the smaller side is at least 128, then scaled down when
    the smaller side exceeds 768, and finally both are rounded down to a
    multiple of 8.

    Args:
        image: Pillow image to resize.
        width: Requested width in pixels.
        height: Requested height in pixels.
        resampler_type: Key of ``resamplers``; any other value falls back to
            "Nearest Neighbor".

    Returns:
        Tuple ``(resized_image, new_width, new_height)``.
    """
    aspect_ratio = width / height

    # Scale up when the smaller side is below the 128 px minimum.
    min_size = min(width, height)
    if min_size < 128:
        n_size = 128
        if width == min_size:
            width = n_size
            height = n_size / aspect_ratio
        else:
            height = n_size
            width = n_size * aspect_ratio
    width = int(width)
    height = int(height)

    # Scale down when even the smaller side exceeds 768 px.
    min_size = min(width, height)
    if min_size > 768:
        n_size = 768
        if width == min_size:
            height = n_size
            width = n_size * aspect_ratio
        else:
            width = n_size
            height = n_size / aspect_ratio
    width = int(width)
    height = int(height)

    # Round down to a multiple of 8. Computed unconditionally so the result
    # is always defined, whichever of the clamps above was applied.
    n_width = (width // 8) * 8
    n_height = (height // 8) * 8

    if resampler_type in resamplers:
        resampler = resamplers[resampler_type]
    else:
        resampler = resamplers["Nearest Neighbor"]

    # Pillow's keyword is ``resample``; ``resampler=`` raises TypeError.
    new_image = image.resize((n_width, n_height), resample=resampler)
    return new_image, n_width, n_height
|
||||
|
||||
|
||||
def process_sd_init_image(self, sd_init_image, resample_type):
    """Turn an init-image input into a normalized tensor for img2img.

    Args:
        self: Object providing ``width``, ``height``, ``dtype`` and a
            ``process_sd_init_image`` attribute used for recursion.
        sd_init_image: One of: a list of inputs, a file path string, a
            Pillow image, a mapping with an ``"image"`` entry, or a falsy
            value meaning "no image".
        resample_type: Key of ``resamplers``; unknown values fall back to
            Lanczos.

    Returns:
        Tuple ``(image, is_img2img)``. ``image`` is ``None`` when no usable
        input was given, a list of tensors for list input, and otherwise a
        tensor scaled to the range [-1, 1].
    """
    # A non-empty list is processed element-wise. An empty list falls
    # through to the "no image" case below.
    if isinstance(sd_init_image, list) and sd_init_image:
        images = [
            self.process_sd_init_image(img, resample_type)[0]
            for img in sd_init_image
        ]
        return images, True

    if isinstance(sd_init_image, str):
        if not os.path.isfile(sd_init_image):
            return None, False
        loaded = Image.open(sd_init_image, mode="r").convert("RGB")
        # The recursive call already returns the finished tensor; return it
        # directly instead of truth-testing/resizing a tensor afterwards.
        return self.process_sd_init_image(loaded, resample_type)

    if isinstance(sd_init_image, Image.Image):
        image = sd_init_image.convert("RGB")
    elif sd_init_image:
        image = sd_init_image["image"].convert("RGB")
    else:
        return None, False

    resample_filter = (
        resamplers[resample_type]
        if resample_type in resampler_list
        # Fallback to Lanczos
        else Image.Resampling.LANCZOS
    )
    image = image.resize((self.width, self.height), resample=resample_filter)

    # Add a leading axis, scale to [0, 1], reorder axes with
    # permute(0, 3, 1, 2), then map to [-1, 1].
    image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
    image_arr = image_arr / 255.0
    image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
    image_arr = 2 * (image_arr - 0.5)
    return image_arr, True
|
||||
@@ -1,37 +0,0 @@
|
||||
import sys
|
||||
|
||||
|
||||
class Logger:
    """File-like stdout replacement that diverts tagged lines to a log file.

    Lines containing ``filter`` are written to the log file; everything else
    is forwarded to the stream that was ``sys.stdout`` when the instance was
    created.
    """

    def __init__(self, filename, filter=None):
        """Open ``filename`` for writing and remember the current stdout.

        Args:
            filename: Path of the log file; truncated on open.
            filter: Substring marking a line as log output. ``None`` sends
                everything to the terminal.
        """
        self.terminal = sys.stdout
        self.log = open(filename, "w")
        self.filter = filter

    def write(self, message):
        """Route each line of ``message`` to the log file or the terminal.

        Every line is emitted exactly once, line endings included. Writing
        the whole message once per contained line would duplicate output.
        """
        for line in message.splitlines(keepends=True):
            if self.filter is not None and self.filter in line:
                self.log.write(line)
            else:
                self.terminal.write(line)

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        """Report that this stream is not an interactive terminal."""
        return False
|
||||
|
||||
|
||||
def logger_test(x):
    """Print one line carrying the "[LOG]" tag and one without; return ``x``."""
    for line in (
        "[LOG] This is a test",
        "This is another test, without the filter",
    ):
        print(line)
    return x
|
||||
|
||||
|
||||
def read_sd_logs():
    """Flush stdout, then return the full text of ``amdshark_tmp/sd.log``."""
    sys.stdout.flush()
    with open("amdshark_tmp/sd.log", "r") as log_handle:
        contents = log_handle.read()
    return contents
|
||||
|
||||
|
||||
# Module-level side effect: importing this module replaces the process-wide
# stdout with a Logger that sends lines containing "[LOG]" to
# amdshark_tmp/sd.log and everything else to the previous stdout.
# NOTE(review): the log file is opened for writing right here, so the
# amdshark_tmp directory presumably has to exist before import - confirm.
sys.stdout = Logger("amdshark_tmp/sd.log", filter="[LOG]")
|
||||
@@ -1,205 +0,0 @@
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
get_iree_compiled_module,
|
||||
load_vmfb_using_mmap,
|
||||
clean_device_info,
|
||||
get_iree_target_triple,
|
||||
)
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_checkpoints_path,
|
||||
get_resource_path,
|
||||
)
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import (
|
||||
cmd_opts,
|
||||
)
|
||||
from iree import runtime as ireert
|
||||
from pathlib import Path
|
||||
import gc
|
||||
import os
|
||||
|
||||
|
||||
class AMDSharkPipelineBase:
    """Lightweight base for managing an inference API class.

    It provides methods for:
    - compiling a set (model map) of torch IR modules
    - preparing weights for an inference job
    - loading weights for an inference job
    - utilities like benchmarks, tests
    """

    def __init__(
        self,
        model_map: dict,
        base_model_id: str,
        static_kwargs: dict,
        device: str,
        import_mlir: bool = True,
    ):
        """Store the pipeline configuration and make sure the tmp dir exists.

        Args:
            model_map: Per-submodel entries; an entry's ``"initializer"``
                callable and optional ``"ireec_flags"`` list are read here.
            base_model_id: Identifier of the base model.
            static_kwargs: Per-submodel keyword arguments; the ``"pipe"``
                entry supplies defaults merged into each submodel's kwargs.
            device: Device string, resolved through ``clean_device_info``
                and ``get_iree_target_triple``.
            import_mlir: Stored on the instance; not read in this class.
        """
        self.model_map = model_map
        self.pipe_map = {}
        self.static_kwargs = static_kwargs
        self.base_model_id = base_model_id
        self.triple = get_iree_target_triple(device)
        self.device, self.device_id = clean_device_info(device)
        self.import_mlir = import_mlir
        self.iree_module_dict = {}
        self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
        # makedirs(exist_ok=True) also creates missing parents and does not
        # fail when the directory appears between a check and the creation.
        os.makedirs(self.tmp_dir, exist_ok=True)
        self.tempfiles = {}
        self.pipe_vmfb_path = ""

    def get_compiled_map(self, pipe_id, submodel="None", init_kwargs=None) -> None:
        """Populate the pipeline with precompiled or freshly compiled modules.

        First checks whether .vmfb files are already present for the
        pipeline and records them; anything missing is imported as torch IR
        and compiled. Call this once a pipeline ID unique to the static torch
        IR parameters exists and the model map is populated.

        Args:
            pipe_id: Pipeline identifier; sanitized with ``safe_name`` and
                used as the artifact directory name.
            submodel: Model-map key to process, or the string ``"None"`` to
                process every key in the model map.
            init_kwargs: Initializer kwargs used when the submodel has no
                entry in ``static_kwargs``. Defaults to a fresh dict.
        """
        # A fresh dict per call: a shared ``{}`` default would accumulate
        # the "pipe" defaults merged in below across calls.
        if init_kwargs is None:
            init_kwargs = {}

        self.pipe_id = self.safe_name(pipe_id)
        self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
        self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
        if submodel == "None":
            print("\n[LOG] Gathering any pre-compiled artifacts....")
            for key in self.model_map:
                self.get_compiled_map(pipe_id, submodel=key)
        else:
            self.pipe_map[submodel] = {}
            self.get_precompiled(self.pipe_id, submodel)
            if submodel in self.iree_module_dict:
                return
            elif "vmfb_path" in self.pipe_map[submodel]:
                return
            elif submodel not in self.tempfiles:
                print(
                    f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..."
                )
                # NOTE: when the submodel has its own static kwargs, that
                # dict is used (and extended) in place, not copied.
                if submodel in self.static_kwargs:
                    init_kwargs = self.static_kwargs[submodel]
                for key in self.static_kwargs["pipe"]:
                    if key not in init_kwargs:
                        init_kwargs[key] = self.static_kwargs["pipe"][key]
                self.import_torch_ir(submodel, init_kwargs)
                self.get_compiled_map(pipe_id, submodel)
            else:
                # Copy the flags so the append below does not grow the list
                # stored in the model map on every compilation.
                ireec_flags = list(self.model_map[submodel].get("ireec_flags", []))

                weights_path = self.get_io_params(submodel)
                if weights_path:
                    ireec_flags.append("--iree-opt-const-eval=False")

                self.iree_module_dict[submodel] = get_iree_compiled_module(
                    self.tempfiles[submodel],
                    device=self.device,
                    frontend="torch",
                    mmap=True,
                    external_weight_file=weights_path,
                    extra_args=ireec_flags,
                    write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
                )
        return

    def get_io_params(self, submodel):
        """Return the external weights location for ``submodel``, or ``None``.

        ``external_weight_file`` (custom weights) takes precedence over
        ``external_weight_path`` (default weights for the HF model). ``None``
        means the torch IR is assumed to contain the weights.
        """
        if "external_weight_file" in self.static_kwargs[submodel]:
            # we are using custom weights
            weights_path = self.static_kwargs[submodel]["external_weight_file"]
        elif "external_weight_path" in self.static_kwargs[submodel]:
            # we are using the default weights for the HF model
            weights_path = self.static_kwargs[submodel]["external_weight_path"]
        else:
            # assume the torch IR contains the weights.
            weights_path = None
        return weights_path

    def get_precompiled(self, pipe_id, submodel="None"):
        """Record the path of an existing .vmfb for ``submodel``, if any.

        With ``submodel == "None"`` every key of the model map is checked.
        Only files directly inside ``self.pipe_vmfb_path`` are considered; a
        file matches when its name contains the submodel name.
        """
        if submodel == "None":
            for model in self.model_map:
                self.get_precompiled(pipe_id, model)
            # "None" is a sentinel, not a submodel: stop here instead of
            # looking up pipe_map["None"] below.
            return
        vmfbs = []
        for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path):
            vmfbs.extend(filenames)
            break  # top-level directory only
        for file in vmfbs:
            if submodel in file:
                self.pipe_map[submodel]["vmfb_path"] = os.path.join(
                    self.pipe_vmfb_path, file
                )
        return

    def import_torch_ir(self, submodel, kwargs):
        """Generate torch IR for ``submodel`` and write it to a temp file.

        The submodel's ``"initializer"`` is called with the flattened kwargs
        and ``compile_to="torch"``; the file path is stored in
        ``self.tempfiles``.
        """
        torch_ir = self.model_map[submodel]["initializer"](
            **self.safe_dict(kwargs), compile_to="torch"
        )
        if submodel == "clip":
            # clip.export_clip_model returns (torch_ir, tokenizer)
            torch_ir = torch_ir[0]

        self.tempfiles[submodel] = os.path.join(
            self.tmp_dir, f"{submodel}.torch.tempfile"
        )

        with open(self.tempfiles[submodel], "w+") as f:
            f.write(torch_ir)
        # Drop the in-memory IR right away.
        del torch_ir
        gc.collect()
        return

    def load_submodels(self, submodels: list):
        """Make each listed submodel ready for inference.

        Already loaded submodels are skipped, submodels with a recorded
        .vmfb are memory-mapped, and the rest go through
        ``get_compiled_map``.
        """
        for submodel in submodels:
            if submodel in self.iree_module_dict:
                print(f"\n[LOG] {submodel} is ready for inference.")
                continue
            if "vmfb_path" in self.pipe_map[submodel]:
                weights_path = self.get_io_params(submodel)
                self.iree_module_dict[submodel] = {}
                (
                    self.iree_module_dict[submodel]["vmfb"],
                    self.iree_module_dict[submodel]["config"],
                    self.iree_module_dict[submodel]["temp_file_to_unlink"],
                ) = load_vmfb_using_mmap(
                    self.pipe_map[submodel]["vmfb_path"],
                    self.device,
                    device_idx=0,
                    rt_flags=[],
                    external_weight_file=weights_path,
                )
            else:
                self.get_compiled_map(self.pipe_id, submodel)
        return

    def unload_submodels(self, submodels: list):
        """Drop the loaded modules of the listed submodels and collect garbage."""
        for submodel in submodels:
            if submodel in self.iree_module_dict:
                del self.iree_module_dict[submodel]
                gc.collect()
        return

    def run(self, submodel, inputs):
        """Call the ``"main"`` entry of a loaded submodel.

        Args:
            submodel: Key of a module present in ``self.iree_module_dict``.
            inputs: A single input or a list of inputs; each is converted
                with ``ireert.asdevicearray`` for the module's device.

        Returns:
            Whatever the module's ``"main"`` function returns.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]
        device = self.iree_module_dict[submodel]["config"].device
        device_inputs = [ireert.asdevicearray(device, item) for item in inputs]
        return self.iree_module_dict[submodel]["vmfb"]["main"](*device_inputs)

    def safe_name(self, name):
        """Return ``name`` with ``/``, ``-`` and backslashes replaced by ``_``."""
        return name.replace("/", "_").replace("-", "_").replace("\\", "_")

    def safe_dict(self, kwargs: dict):
        """Flatten nested dict values of ``kwargs`` into lists of their values.

        A nested dict that contains the key ``"pass_dict"`` is passed through
        unchanged, as is every non-dict value.
        """
        flat_args = {}
        for i in kwargs:
            if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
                flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
            else:
                flat_args[i] = kwargs[i]

        return flat_args
|
||||
@@ -1,376 +0,0 @@
|
||||
from typing import List, Optional, Union
|
||||
from iree import runtime as ireert
|
||||
import re
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# Tokenizer for the prompt attention syntax: escaped brackets/backslashes,
# opening brackets, an explicit ":<weight>)" terminator, closing brackets,
# runs of plain text, and a lone colon. Group 1 captures the weight.
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs:
    text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    # Stacks of indices into ``res`` marking where each still-open bracket
    # started.
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale the weight of every fragment emitted since ``start_position``.
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = m.group(1)

        if token.startswith("\\"):
            # Escaped character: keep it literally, without the backslash.
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([token, 1.0])

    # Brackets that were never closed still apply their multiplier.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
|
||||
|
||||
|
||||
def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
    r"""
    Tokenize every prompt and pair each token with its attention weight.

    Each prompt is split by ``parse_prompt_attention``; every fragment is
    tokenized with ``pipe.tokenizer`` (first and last id dropped) and the
    fragment weight is repeated once per token. Sequences longer than
    ``max_length`` are cut and a notice is printed. No padding, starting or
    ending token is included.

    Returns a ``(tokens, weights)`` pair of lists with one entry per prompt.
    """
    all_tokens = []
    all_weights = []
    was_truncated = False
    for text in prompt:
        ids = []
        id_weights = []
        for fragment, fragment_weight in parse_prompt_attention(text):
            # tokenize and discard the starting and the ending token
            fragment_ids = pipe.tokenizer(fragment).input_ids[1:-1]
            ids.extend(fragment_ids)
            # one copy of the weight per produced token
            id_weights.extend([fragment_weight] * len(fragment_ids))
            # no point in tokenizing past the truncation limit
            if len(ids) > max_length:
                break
        if len(ids) > max_length:
            was_truncated = True
            ids = ids[:max_length]
            id_weights = id_weights[:max_length]
        all_tokens.append(ids)
        all_weights.append(id_weights)
    if was_truncated:
        print(
            "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
        )
    return all_tokens, all_weights
|
||||
|
||||
|
||||
def pad_tokens_and_weights(
    tokens,
    weights,
    max_length,
    bos,
    eos,
    no_boseos_middle=True,
    chunk_length=77,
):
    r"""
    Pad token rows with ``bos``/``eos`` and weight rows with 1.0, in place.

    Every token row becomes ``bos`` + tokens + ``eos`` repeated up to
    ``max_length``. With ``no_boseos_middle`` the weights are padded the same
    way; otherwise a 1.0 is inserted around each slice of
    ``chunk_length - 2`` weights and the row is filled with 1.0 up to
    ``chunks * chunk_length``. Returns the same ``tokens`` and ``weights``
    lists.
    """
    payload_per_chunk = chunk_length - 2
    chunk_count = (max_length - 2) // payload_per_chunk
    if no_boseos_middle:
        target_weight_len = max_length
    else:
        target_weight_len = chunk_count * chunk_length

    for row, row_tokens in enumerate(tokens):
        row_weights = weights[row]
        tokens[row] = [bos] + row_tokens + [eos] * (max_length - 1 - len(row_tokens))

        if no_boseos_middle:
            tail = [1.0] * (max_length - 1 - len(row_weights))
            weights[row] = [1.0] + row_weights + tail
            continue

        if len(row_weights) == 0:
            weights[row] = [1.0] * target_weight_len
            continue

        padded = []
        for chunk in range(chunk_count):
            start = chunk * payload_per_chunk
            stop = min(len(row_weights), start + payload_per_chunk)
            padded.append(1.0)  # weight for starting token in this chunk
            padded.extend(row_weights[start:stop])
            padded.append(1.0)  # weight for ending token in this chunk
        padded.extend([1.0] * (target_weight_len - len(padded)))
        weights[row] = padded

    return tokens, weights
|
||||
|
||||
|
||||
def get_unweighted_text_embeddings(
    pipe,
    text_input,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    Run token ids through the "clip" submodel, chunk by chunk if necessary.

    When the length of tokens is a multiple of the capacity of the text
    encoder, it is split into chunks that are sent to the text encoder
    individually; the per-chunk results are joined along axis 1.

    Args:
        pipe: Object whose ``run("clip", ids)`` returns a sequence whose
            first element has a ``to_host()`` method yielding a NumPy array.
        text_input: 2-D tensor of token ids (indexed as ``[:, a:b]``).
        chunk_length: Encoder capacity including the start and end tokens.
        no_boseos_middle: When true, the start/end positions at the seams
            between chunks are dropped from the result.

    Returns:
        A ``torch.Tensor`` holding the embeddings.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[
                :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
            ].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            text_input_chunk[:, -1] = text_input[0, -1]

            text_embedding = pipe.run("clip", text_input_chunk)[0].to_host()

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        # Join the chunks along axis 1. The chunks can differ in length on
        # that axis, so they must not be stacked into one array first, and
        # joining along axis 0 would multiply the batch size instead.
        text_embeddings_np = np.concatenate(text_embeddings, axis=1)
        text_embeddings = torch.from_numpy(text_embeddings_np)
    else:
        text_embeddings = pipe.run("clip", text_input)[0]
        text_embeddings = torch.from_numpy(text_embeddings.to_host())
    return text_embeddings
|
||||
|
||||
|
||||
# This function deals with NoneType values occurring in tokens after padding.
# It switches out None with 49407, as truncating None values causes matrix
# dimension errors.
def filter_nonetype_tokens(tokens: List[List]):
    """Return a copy of ``tokens`` with every ``None`` replaced by 49407.

    Every row is processed, so batches with more than one prompt keep all of
    their rows.
    """
    return [[49407 if token is None else token for token in row] for row in tokens]
|
||||
|
||||
|
||||
def get_weighted_text_embeddings(
    pipe,
    prompt: List[str],
    uncond_prompt: List[str] = None,
    max_embeddings_multiples: Optional[int] = 8,
    no_boseos_middle: Optional[bool] = True,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
):
    """
    Build text embeddings for ``prompt`` (and optionally ``uncond_prompt``),
    applying the attention weights written in the prompt text.

    Args:
        pipe: Object providing ``tokenizer``, ``model_max_length`` and the
            ``run`` method used by ``get_unweighted_text_embeddings``.
        prompt: Prompts to embed.
        uncond_prompt: Optional unconditional prompts; when ``None`` the
            second return value is ``None``.
        max_embeddings_multiples: Upper bound on the number of
            ``model_max_length``-sized chunks a prompt may occupy.
        no_boseos_middle: Forwarded to the padding and embedding helpers.
        skip_parsing: Tokenize the prompts directly, without attention
            parsing; every weight is then 1.0 and no weighting is applied.
        skip_weighting: Skip the weighting step even when parsing was done.

    Returns:
        Tuple ``(text_embeddings, uncond_embeddings_or_None)``.
    """
    # Longest token sequence allowed, including the start and end tokens.
    max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(
            pipe, prompt, max_length - 2
        )
        if uncond_prompt is not None:
            uncond_tokens, uncond_weights = get_prompts_with_weights(
                pipe, uncond_prompt, max_length - 2
            )
    else:
        # Plain tokenization; the first and last id of each row are dropped.
        prompt_tokens = [
            token[1:-1]
            for token in pipe.tokenizer(
                prompt, max_length=max_length, truncation=True
            ).input_ids
        ]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1]
                for token in pipe.tokenizer(
                    uncond_prompt, max_length=max_length, truncation=True
                ).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(max_length, max([len(token) for token in uncond_tokens]))
    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)

    max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.model_max_length,
    )

    # FIXME: This is a hacky fix caused by tokenizer padding with None values
    prompt_tokens = filter_nonetype_tokens(prompt_tokens)

    # Token ids are kept on the CPU rather than on pipe.device.
    # prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.model_max_length,
        )

        # FIXME: This is a hacky fix caused by tokenizer padding with None values
        uncond_tokens = filter_nonetype_tokens(uncond_tokens)

        # uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu")

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.model_max_length,
        no_boseos_middle=no_boseos_middle,
    )
    # prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
    prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu")
    if uncond_prompt is not None:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.model_max_length,
            no_boseos_middle=no_boseos_middle,
        )
        # uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
        uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu")

    # assign weights to the prompts and normalize in the sense of mean
    # TODO: should we normalize by chunk or in a whole (current implementation)?
    if (not skip_parsing) and (not skip_weighting):
        # Scale each token's embedding by its weight, then rescale so the
        # mean over the last two axes equals the mean before weighting.
        previous_mean = (
            text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        )
        text_embeddings *= prompt_weights.unsqueeze(-1)
        current_mean = (
            text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        )
        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
        if uncond_prompt is not None:
            previous_mean = (
                uncond_embeddings.float()
                .mean(axis=[-2, -1])
                .to(uncond_embeddings.dtype)
            )
            uncond_embeddings *= uncond_weights.unsqueeze(-1)
            current_mean = (
                uncond_embeddings.float()
                .mean(axis=[-2, -1])
                .to(uncond_embeddings.dtype)
            )
            uncond_embeddings *= (
                (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
            )

    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings
    return text_embeddings, None
|
||||
@@ -1,118 +0,0 @@
|
||||
# from amdshark_turbine.turbine_models.schedulers import export_scheduler_model
|
||||
from diffusers import (
|
||||
LCMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDPMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
|
||||
|
||||
def get_schedulers(model_id):
    """Instantiate the enabled diffusers schedulers for ``model_id``.

    Each scheduler is loaded with ``from_pretrained`` from the model's
    "scheduler" subfolder. Returns a dict mapping scheduler name to
    scheduler instance.
    """
    # TODO: switch over to turbine and run all on GPU
    print(f"\n[LOG] Initializing schedulers from model id: {model_id}")

    # Currently disabled (all loaded the same way unless noted):
    #   DDPM, KDPM2Discrete, LMSDiscrete, DDIM, LCMScheduler,
    #   DPMSolverMultistep (algorithm_type="dpmsolver"),
    #   DPMSolverMultistep++ (algorithm_type="dpmsolver++"),
    #   DPMSolverMultistepKarras (use_karras_sigmas=True),
    #   DPMSolverMultistepKarras++ (algorithm_type="dpmsolver++",
    #   use_karras_sigmas=True), DEISMultistep, DPMSolverSinglestep,
    #   KDPM2AncestralDiscrete, HeunDiscrete.
    enabled = (
        ("PNDM", PNDMScheduler),
        ("EulerDiscrete", EulerDiscreteScheduler),
        ("EulerAncestralDiscrete", EulerAncestralDiscreteScheduler),
    )
    return {
        name: scheduler_cls.from_pretrained(model_id, subfolder="scheduler")
        for name, scheduler_cls in enabled
    }
|
||||
|
||||
|
||||
def export_scheduler_model(model):
    """Placeholder exporter: ignores ``model`` and returns ``("None", "None")``."""
    placeholder = "None"
    return placeholder, placeholder
|
||||
|
||||
|
||||
# Maps each enabled scheduler name to the result of export_scheduler_model
# for the corresponding scheduler class name. The commented-out entries are
# schedulers that are currently disabled.
scheduler_model_map = {
    "PNDM": export_scheduler_model("PNDMScheduler"),
    # "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
    "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
    "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
    # "LCM": export_scheduler_model("LCMScheduler"),
    # "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
    # "DDPM": export_scheduler_model("DDPMScheduler"),
    # "DDIM": export_scheduler_model("DDIMScheduler"),
    # "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
    # "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
    # "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
    # "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
    # "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
    # "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
}
|
||||
@@ -1,66 +0,0 @@
|
||||
import numpy as np
|
||||
import json
|
||||
from random import (
|
||||
randint,
|
||||
seed as seed_random,
|
||||
getstate as random_getstate,
|
||||
setstate as random_setstate,
|
||||
)
|
||||
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the
|
||||
# supported range (including -1)
|
||||
def sanitize_seed(seed: int | str):
|
||||
seed = int(seed)
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
return seed
|
||||
|
||||
|
||||
# take a seed expression in an input format and convert it to
|
||||
# a list of integers, where possible
|
||||
def parse_seed_input(seed_input: str | list | int):
|
||||
if isinstance(seed_input, str):
|
||||
try:
|
||||
seed_input = json.loads(seed_input)
|
||||
except (ValueError, TypeError):
|
||||
seed_input = None
|
||||
|
||||
if isinstance(seed_input, int):
|
||||
return [seed_input]
|
||||
|
||||
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
|
||||
return seed_input
|
||||
|
||||
raise TypeError(
|
||||
"Seed input must be an integer or an array of integers in JSON format"
|
||||
)
|
||||
|
||||
|
||||
# Generate a set of seeds from an input expression for batch_count batches,
|
||||
# optionally using that input as the rng seed for any randomly generated seeds.
|
||||
def batch_seeds(seed_input: str | list | int, batch_count: int, repeatable=False):
    """Build the list of seeds to use for ``batch_count`` batches.

    Args:
        seed_input: Seed expression accepted by ``parse_seed_input``.
        batch_count: Number of seeds wanted. The parsed list is sliced to this
            length, or padded with ``-1`` when it is shorter.
        repeatable: When true, the module-level random generator is seeded
            from the non-negative seeds collected so far before the remaining
            seeds are generated, and its previous state is restored afterwards.

    Returns:
        A list of seeds, each passed through ``sanitize_seed``.

    Raises:
        TypeError: Propagated from ``parse_seed_input`` for unsupported input.
    """
    # turn the input into a list if possible
    seeds = parse_seed_input(seed_input)

    # slice or pad the list to be of batch_count length
    seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))

    if not repeatable:
        # generate any seeds that are unspecified
        return [sanitize_seed(seed) for seed in seeds]

    # With no usable seed at all, fix the first one so there is something to
    # seed the rng with. `seeds` is empty when batch_count is 0; all() is
    # vacuously true for an empty list, so guard against indexing into it.
    if seeds and all(seed < 0 for seed in seeds):
        seeds[0] = sanitize_seed(seeds[0])

    # set seed for the rng based on what we have so far
    saved_random_state = random_getstate()
    try:
        seed_random(str([n for n in seeds if n > -1]))
        # generate any seeds that are unspecified
        return [sanitize_seed(seed) for seed in seeds]
    finally:
        # reset the rng back to normal, even if seed generation failed
        random_setstate(saved_random_state)
|
||||
@@ -1,791 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from apps.amdshark_studio.modules.img_processing import resampler_list
|
||||
|
||||
|
||||
def path_expand(s):
    """Return ``s`` as a ``Path`` with ``~`` expanded and the result resolved."""
    expanded = Path(s).expanduser()
    return expanded.resolve()
|
||||
|
||||
|
||||
def is_valid_file(arg):
    """Return ``arg`` unchanged when that path exists, otherwise ``None``."""
    return arg if os.path.exists(arg) else None
|
||||
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"-a",
|
||||
"--app",
|
||||
default="txt2img",
|
||||
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
|
||||
)
|
||||
p.add_argument(
|
||||
"-p",
|
||||
"--prompt",
|
||||
nargs="+",
|
||||
default=[
|
||||
"a photo taken of the front of a super-car drifting on a road near "
|
||||
"mountains at high speeds with smoke coming off the tires, front "
|
||||
"angle, front point of view, trees in the mountains of the "
|
||||
"background, ((sharp focus))"
|
||||
],
|
||||
help="Text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative_prompt",
|
||||
nargs="+",
|
||||
default=[
|
||||
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
|
||||
"blurry, ugly, blur, oversaturated, cropped"
|
||||
],
|
||||
help="Text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--sd_init_image",
|
||||
type=str,
|
||||
help="Path to the image input for img2img/inpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="The number of steps to do the sampling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=str,
|
||||
default=-1,
|
||||
help="The seed or list of seeds to use. -1 for a random one.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=range(1, 4),
|
||||
help="The number of inferences to be made in a single `batch_count`.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(128, 1025, 8),
|
||||
help="The height of the output image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(128, 1025, 8),
|
||||
help="The width of the output image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="The value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--noise_level",
|
||||
type=int,
|
||||
default=20,
|
||||
help="The value to be used for noise level of upscaler.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Max length of the tokenizer output, options are 64 and 77.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_embeddings_multiples",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The max multiple length of prompt embeddings compared to the max "
|
||||
"output length of text encoder.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="The strength of change applied on the given input image for " "img2img.",
|
||||
)
|
||||
|
||||
def _parse_bool_flag(value):
    """Convert a command-line token such as "true" or "0" into a bool.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# `type=bool` would turn every non-empty string, including "False", into
# True, so the value is parsed explicitly instead.
p.add_argument(
    "--use_hiresfix",
    type=_parse_bool_flag,
    default=False,
    help="Use Hires Fix to do higher resolution images, while trying to "
    "avoid the issues that come with it. This is accomplished by first "
    "generating an image using txt2img, then running it through img2img.",
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_height",
|
||||
type=int,
|
||||
default=768,
|
||||
choices=range(128, 769, 8),
|
||||
help="The height of the Hires Fix image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_width",
|
||||
type=int,
|
||||
default=768,
|
||||
choices=range(128, 769, 8),
|
||||
help="The width of the Hires Fix image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_strength",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="The denoising strength to apply for the Hires Fix.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--resample_type",
|
||||
type=str,
|
||||
default="Nearest Neighbor",
|
||||
choices=resampler_list,
|
||||
help="The resample type to use when resizing an image before being run "
|
||||
"through stable diffusion.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Stable Diffusion Training Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--lora_save_dir",
|
||||
type=str,
|
||||
default="models/lora/",
|
||||
help="Directory to save the lora fine tuned model.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--training_images_dir",
|
||||
type=str,
|
||||
default="models/lora/training_images/",
|
||||
help="Directory containing images that are an example of the prompt.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--training_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The number of steps to train.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Inpainting and Outpainting Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--mask_path",
|
||||
type=str,
|
||||
help="Path to the mask image input for inpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--inpaint_full_res",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If inpaint only masked area or whole picture.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--inpaint_full_res_padding",
|
||||
type=int,
|
||||
default=32,
|
||||
choices=range(0, 257, 4),
|
||||
help="Number of pixels for only masked padding.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--pixels",
|
||||
type=int,
|
||||
default=128,
|
||||
choices=range(8, 257, 8),
|
||||
help="Number of expended pixels for one direction for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--mask_blur",
|
||||
type=int,
|
||||
default=8,
|
||||
choices=range(0, 65),
|
||||
help="Number of blur pixels for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--left",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If extend left for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--right",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If extend right for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--up",
|
||||
"--top",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If extend top for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--down",
|
||||
"--bottom",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If extend bottom for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--noise_q",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Fall-off exponent for outpainting (lower=higher detail) "
|
||||
"(min=0.0, max=4.0).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--color_variation",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Color variation for outpainting (min=0.0, max=1.0).",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Model Config and Usage Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument("--device", type=str, default="vulkan", help="Device to run the model.")
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp16", help="Precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Imports the model from torch module to amdshark_module otherwise "
|
||||
"downloads the model from amdshark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_base_vae",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Do conversion from the VAE output to pixel space on cpu.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="DDIM",
|
||||
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
|
||||
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
|
||||
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
|
||||
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
|
||||
"HeunDiscrete].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_img_format",
|
||||
type=str,
|
||||
default="png",
|
||||
help="Specify the format in which output image is save. "
|
||||
"Supported options: jpg / png.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=os.path.join(os.getcwd(), "generated_imgs"),
|
||||
help="Directory path to save the output images and json.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--batch_count",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of batches to be generated with random seeds in " "single execution.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--repeatable_seeds",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="The seed of the first batch will be used as the rng seed to "
|
||||
"generate the subsequent seeds for subsequent batches in that run.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--custom_weights",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to a .safetensors or .ckpt file for SD pipeline weights.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--custom_vae",
|
||||
type=str,
|
||||
default="",
|
||||
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
|
||||
"needs to be plugged in.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--base_model_id",
|
||||
type=str,
|
||||
default="stabilityai/stable-diffusion-2-1-base",
|
||||
help="The repo-id of hugging face.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--low_cpu_mem_usage",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use the accelerate package to reduce cpu memory consumption.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--attention_slicing",
|
||||
type=str,
|
||||
default="none",
|
||||
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
|
||||
"or an integer).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_stencil",
|
||||
choices=["canny", "openpose", "scribble", "zoedepth"],
|
||||
help="Enable the stencil feature.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--control_mode",
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
default="Balanced",
|
||||
help="How Controlnet injection should be prioritized.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_lora",
|
||||
type=str,
|
||||
default="",
|
||||
help="Use standalone LoRA weight using a HF ID or a checkpoint " "file (~3 MB).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_quantize",
|
||||
type=str,
|
||||
default="none",
|
||||
help="Runs the quantized version of stable diffusion model. "
|
||||
"This is currently in experimental phase. "
|
||||
"Currently, only runs the stable-diffusion-2-1-base model in "
|
||||
"int8 quantization.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--lowvram",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Load and unload models for low VRAM.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hf_auth_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--external_weights",
|
||||
type=str,
|
||||
default=None,
|
||||
help="What type of externalized weights to use. Currently options are 'safetensors' and defaults to inlined weights.",
|
||||
)
|
||||
|
||||
# Adjacent string literals are concatenated verbatim, so each sentence of the
# help text carries its own trailing separator.
p.add_argument(
    "--device_allocator_heap_key",
    type=str,
    default="",
    help="Specify heap key for device caching allocator. "
    "Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count. "
    "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte).",
)
|
||||
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree_vulkan_target_triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_metal_target_platform",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for metal.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Misc. Debug and Optimization flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--use_compiled_scheduler",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use the default scheduler precompiled into the model if available.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--local_tank_cache",
|
||||
default="",
|
||||
help="Specify where to save downloaded amdshark_tank artifacts. "
|
||||
"If this is not set, the default is ~/.local/amdshark_tank/.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dump_isa",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled call amdllpc to get ISA dumps. " "Use with dispatch benchmarks.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks",
|
||||
default=None,
|
||||
help="Dispatches to return benchmark data on. "
|
||||
'Use "All" for all, and None for none.',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="temp_dispatch_benchmarks",
|
||||
help="Directory where you want to store dispatch data "
|
||||
'generated with "--dispatch_benchmarks".',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--enable_rgp",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for inserting debug frames between iterations " "for use with rgp.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hide_steps",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for hiding the details of iteration/sec for each step.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--warmup_count",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Flag setting warmup count for CLIP and VAE [>= 0].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--clear_all",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag to clear all mlir and vmfb from common locations. "
|
||||
"Recompiling will take several minutes.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_metadata_to_json",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for whether or not to save a generation information "
|
||||
"json file with the image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--write_metadata_to_png",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for whether or not to save generation information in "
|
||||
"PNG chunk text to generated images.",
|
||||
)
|
||||
|
||||
# --import_mlir defaults to True in this parser, so the help text must not
# describe "false" as its default.
p.add_argument(
    "--import_debug",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="If import_mlir is True, saves mlir via the debug option "
    "in amdshark importer. Does nothing if import_mlir is false.",
)
|
||||
|
||||
p.add_argument(
    "--compile_debug",
    default=False,
    action=argparse.BooleanOptionalAction,
    # The trailing space keeps "the" and "iree-compiler" from running together
    # when the adjacent literals are concatenated.
    help="Flag to toggle debug assert/verify flags for imported IR in the "
    "iree-compiler. Default to false.",
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_constant_folding",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Controls constant folding in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--data_tiling",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Controls data tiling in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--quantization",
|
||||
type=str,
|
||||
default="None",
|
||||
help="Quantization to be used for api-exposed model.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Web UI flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--webui",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="controls whether the webui is launched.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--progress_bar",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for removing the progress bar animation during " "image generation.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--tmp_dir",
|
||||
type=str,
|
||||
default=os.path.join(os.getcwd(), "amdshark_tmp"),
|
||||
help="Path to tmp directory",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--config_dir",
|
||||
type=str,
|
||||
default=os.path.join(os.getcwd(), "configs"),
|
||||
help="Path to config directory",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--model_dir",
|
||||
type=str,
|
||||
default=os.path.join(os.getcwd(), "models"),
|
||||
help="Path to directory where all .ckpts are stored in order to populate "
|
||||
"them in the web UI.",
|
||||
)
|
||||
|
||||
# TODO: replace API flag when these can be run together
|
||||
p.add_argument(
|
||||
"--ui",
|
||||
type=str,
|
||||
default="app" if os.name == "nt" else "web",
|
||||
help="One of: [api, app, web].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--share",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for generating a public URL.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=8080,
|
||||
help="Flag for setting server port.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--api",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling rest API.",
|
||||
)
|
||||
|
||||
p.add_argument(
    "--api_accept_origin",
    action="append",
    type=str,
    # The trailing space after "Cross Origin" keeps it from fusing with
    # "Resource Sharing" when the adjacent literals are concatenated.
    help="An origin to be accepted by the REST api for Cross Origin "
    "Resource Sharing (CORS). Use multiple times for multiple origins, "
    'or use --api_accept_origin="*" to accept all origins. If no origins '
    "are set no CORS headers will be returned by the api. Use, for "
    "instance, if you need to access the REST api from Javascript running "
    "in a web browser.",
)
|
||||
|
||||
p.add_argument(
|
||||
"--debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling debugging log in WebUI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for removing the output gallery tab, and avoid exposing "
|
||||
"images under --output_dir in the UI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--configs_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to .json config directory.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery_followlinks",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for whether the output gallery tab in the UI should "
|
||||
"follow symlinks when listing subdirectories under --output_dir.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--api_log",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Enables Compatibility API logging.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# SD model auto-annotation flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_output",
|
||||
type=path_expand,
|
||||
default="./",
|
||||
help="Directory to save the annotated mlir file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_model",
|
||||
type=str,
|
||||
default="unet",
|
||||
help="Options are unet and vae.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_annotation",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Save annotated mlir file.",
|
||||
)
|
||||
##############################################################################
|
||||
# SD model auto-tuner flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--tuned_config_dir",
|
||||
type=path_expand,
|
||||
default="./",
|
||||
help="Directory to save the tuned config file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--num_iters",
|
||||
type=int,
|
||||
default=400,
|
||||
help="Number of iterations for tuning.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--search_op",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Op to be optimized, options are matmul, bmm, conv and all.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# DocuChat Flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--run_docuchat_web",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Specifies whether the docuchat's web version is running or not.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# rocm Flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
    "--iree_rocm_target_chip",
    type=str,
    default="",
    # The help text previously listed `hipinfo` twice.
    help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
    "or `iree-run-module --dump_devices=rocm` to get desired arch name",
)
|
||||
|
||||
# Parse only the options registered above; unrecognised arguments are
# collected in `unknown` instead of causing an error.
cmd_opts, unknown = p.parse_known_args()
if cmd_opts.import_debug:
    # No --hf_model_id option is registered on this parser, so reading
    # cmd_opts.hf_model_id directly raised AttributeError whenever
    # --import_debug was given. Prefer it if something else has set it,
    # otherwise fall back to --base_model_id.
    _debug_model_id = getattr(cmd_opts, "hf_model_id", None) or cmd_opts.base_model_id
    os.environ["IREE_SAVE_TEMPS"] = os.path.join(
        os.getcwd(), _debug_model_id.replace("/", "_")
    )
|
||||
@@ -1,106 +0,0 @@
|
||||
import time
|
||||
import argparse
|
||||
|
||||
|
||||
class TimerSubcategory:
    """Context manager that times a nested category on a timer object.

    While the block is active the timer's ``base_category`` is extended with
    ``"<category>/"`` and its ``subcategory_level`` is raised by one; both are
    put back on exit.
    """

    def __init__(self, timer, category):
        """Remember the timer, the category name and the timer's current base."""
        self.timer = timer
        self.category = category
        self.start = None
        self.original_base_category = timer.base_category

    def __enter__(self):
        """Start timing and switch the timer into this subcategory."""
        self.start = time.time()
        owner = self.timer
        owner.base_category = self.original_base_category + self.category + "/"
        owner.subcategory_level += 1

        if owner.print_log:
            indent = " " * owner.subcategory_level
            print(f"{indent}{self.category}:")

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the time spent in the block and restore the timer's state."""
        duration = time.time() - self.start
        owner = self.timer
        owner.base_category = self.original_base_category
        owner.add_time_to_record(
            self.original_base_category + self.category,
            duration,
        )
        owner.subcategory_level -= 1
        # Logged silently: the category header was already printed on entry.
        owner.record(self.category, disable_log=True)
|
||||
|
||||
|
||||
class Timer:
    """Accumulates elapsed wall-clock time per named category.

    Attributes are plain instance fields: ``records`` maps a category name to
    its accumulated seconds, ``total`` is the sum of everything passed through
    ``record``, ``base_category`` is the prefix applied to recorded names, and
    ``subcategory_level`` controls the indentation of log lines.
    """

    def __init__(self, print_log=False):
        """Start the clock with empty records.

        Args:
            print_log: When true, ``record`` prints a line for each entry.
        """
        self.start = time.time()
        self.records = {}
        self.total = 0
        self.base_category = ""
        self.print_log = print_log
        self.subcategory_level = 0

    def elapsed(self):
        """Return the seconds since the last call and restart the clock."""
        end = time.time()
        res = end - self.start
        self.start = end
        return res

    def add_time_to_record(self, category, amount):
        """Add ``amount`` seconds to ``category``, creating it if needed."""
        if category not in self.records:
            self.records[category] = 0

        self.records[category] += amount

    def record(self, category, extra_time=0, disable_log=False):
        """Attribute the time since the last measurement to ``category``.

        Args:
            category: Name to record under; ``base_category`` is prepended.
            extra_time: Additional seconds to add on top of the elapsed time.
            disable_log: Suppress the log line even when ``print_log`` is set.
        """
        e = self.elapsed()

        self.add_time_to_record(self.base_category + category, e + extra_time)

        self.total += e + extra_time

        if self.print_log and not disable_log:
            print(
                f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s"
            )

    def subcategory(self, name):
        """Return a context manager that times ``name`` as a nested category."""
        # Restart the clock so time before the block is not attributed to it.
        self.elapsed()

        subcat = TimerSubcategory(self, name)
        return subcat

    def summary(self):
        """Return the total plus every top-level category of at least 0.1s.

        Categories whose name contains ``/`` (nested ones) are left out.
        """
        res = f"{self.total:.1f}s"

        additions = [
            (category, time_taken)
            for category, time_taken in self.records.items()
            if time_taken >= 0.1 and "/" not in category
        ]
        if not additions:
            return res

        res += " ("
        res += ", ".join(
            [f"{category}: {time_taken:.1f}s" for category, time_taken in additions]
        )
        res += ")"

        return res

    def dump(self):
        """Return the total and the per-category records as a dict."""
        return {"total": self.total, "records": self.records}

    def reset(self):
        """Clear all measurements and restart the clock.

        The ``print_log`` setting is carried over; calling ``__init__`` with
        no arguments would silently switch logging off.
        """
        self.__init__(print_log=self.print_log)
|
||||
|
||||
|
||||
# A separate parser with no --help of its own reads just the --log-startup
# switch; every other command-line argument is ignored here.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
    "--log-startup",
    action="store_true",
    help="print a detailed log of what's happening at startup",
)
args, _unparsed = parser.parse_known_args()

# Module-level timer whose logging follows the --log-startup switch.
startup_timer = Timer(print_log=args.log_startup)

startup_record = None
|
||||
@@ -1,68 +0,0 @@
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
|
||||
import sys
|
||||
|
||||
# Raise the interpreter recursion limit to five times its current value.
# NOTE(review): presumably needed while PyInstaller analyses the large
# packages listed below — confirm.
sys.setrecursionlimit(5 * sys.getrecursionlimit())

# python path for pyinstaller
pathex = ["."]

# datafiles for pyinstaller
datas = []

# Distribution metadata, in the order it is appended.
for _dist_name in (
    "torch",
    "tokenizers",
    "tqdm",
    "regex",
    "requests",
    "packaging",
    "filelock",
    "numpy",
    "importlib_metadata",
    "omegaconf",
    "safetensors",
    "Pillow",
    "sentencepiece",
    "pyyaml",
    "huggingface-hub",
    "gradio",
    "scipy",
):
    datas += copy_metadata(_dist_name)

# Package data files as (package, include_py_files) pairs, in the order they
# are appended.
for _package, _with_py_files in (
    ("torch", False),
    ("tokenizers", False),
    ("accelerate", False),
    ("diffusers", False),
    ("transformers", False),
    ("gradio", False),
    ("gradio_client", False),
    ("iree", True),
    ("amdshark", True),
    ("tqdm", False),
    ("tkinter", False),
    ("sentencepiece", False),
    ("jsonschema", False),
    ("jsonschema_specifications", False),
    ("cpuinfo", False),
    ("scipy", True),
):
    datas += collect_data_files(_package, include_py_files=_with_py_files)

# Web UI assets as (source glob, destination directory) pairs.
datas += [
    ("web/ui/css/*", "ui/css"),
    ("web/ui/js/*", "ui/js"),
    ("web/ui/logos/*", "logos"),
]


# hidden imports for pyinstaller
hiddenimports = ["amdshark", "apps"]

# Submodules whose dotted name contains "tests" are left out.
for _package in ("gradio", "diffusers"):
    hiddenimports += [
        module for module in collect_submodules(_package) if "tests" not in module
    ]

# transformers submodules are dropped when their name contains any of these.
blacklist = ["tests", "convert"]
hiddenimports += [
    module
    for module in collect_submodules("transformers")
    if not any(keyword in module for keyword in blacklist)
]

hiddenimports += [
    module for module in collect_submodules("iree") if "test" not in module
]
hiddenimports += ["iree._runtime"]
hiddenimports += [
    module for module in collect_submodules("scipy") if "test" not in module
]
|
||||
@@ -1,58 +0,0 @@
|
||||
# Copyright 2023 Nod Labs, Inc
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
import json
|
||||
import gc
|
||||
from apps.amdshark_studio.api.llm import LanguageModel, llm_chat_api
|
||||
from apps.amdshark_studio.api.sd import amdshark_sd_fn_dict_input, view_json_file
|
||||
from apps.amdshark_studio.web.utils.file_utils import get_resource_path
|
||||
|
||||
# class SDAPITest(unittest.TestCase):
|
||||
# def testSDSimple(self):
|
||||
# from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
# import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
# global_obj._init()
|
||||
|
||||
# sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
|
||||
# sd_kwargs = json.loads(sd_json)
|
||||
# for arg in vars(cmd_opts):
|
||||
# if arg in sd_kwargs:
|
||||
# sd_kwargs[arg] = getattr(cmd_opts, arg)
|
||||
# for i in amdshark_sd_fn_dict_input(sd_kwargs):
|
||||
# print(i)
|
||||
|
||||
|
||||
class LLMAPITest(unittest.TestCase):
    """Smoke test for the ``LanguageModel`` chat API using a tiny Llama-2 checkpoint."""

    def test01_LLMSmall(self):
        """Compare the second streamed message with a known-good string.

        The first streamed message is skipped. The model is released in a
        ``finally`` block so that a failing assertion does not leave it
        loaded.
        """
        lm = LanguageModel(
            "TinyPixel/small-llama2",
            hf_auth_token=None,
            device="cpu",
            precision="fp32",
            quantization="None",
            streaming_llm=True,
        )
        label = "Turkishoure Turkish"
        checked = False
        try:
            for index, (msg, _) in enumerate(lm.chat("hi, what are you?")):
                # skip first token output
                if index == 0:
                    continue
                self.assertEqual(
                    msg.strip(" "),
                    label,
                    f"LLM API failed to return correct response, expected '{label}', received {msg}",
                )
                checked = True
                break
            # Without this check a stream that ends after one message would
            # pass without ever comparing anything.
            self.assertTrue(
                checked, "LLM API produced fewer than two streamed messages"
            )
        finally:
            del lm
            gc.collect()


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
|
||||
@@ -1,41 +0,0 @@
|
||||
import torch
|
||||
from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
|
||||
class UnetModel(torch.nn.Module):
    """Wrap a pretrained ``UNet2DConditionModel`` and blend its two output halves.

    ``forward`` duplicates the input sample, runs the UNet once, splits the
    result into two chunks and combines them with ``guidance_scale``.
    """

    def __init__(self, hf_model_name):
        """Load the ``unet`` subfolder of the ``hf_model_name`` checkpoint."""
        super().__init__()
        self.unet = UNet2DConditionModel.from_pretrained(
            hf_model_name, subfolder="unet"
        )

    def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
        """Return ``uncond + guidance_scale * (text - uncond)``.

        The first chunk of the UNet output is treated as the unconditional
        prediction and the second as the text-conditioned one.
        """
        doubled = torch.cat(2 * [sample])
        outputs = self.unet.forward(
            doubled, timestep, encoder_hidden_states, return_dict=False
        )
        uncond, text = outputs[0].chunk(2)
        delta = text - uncond
        return uncond + guidance_scale * delta


if __name__ == "__main__":
    hf_model_name = "CompVis/stable-diffusion-v1-4"
    unet = UnetModel(hf_model_name)
    inputs = (torch.randn(1, 4, 64, 64), 1, torch.randn(2, 77, 768), 7.5)

    # Trace the wrapped model symbolically and print the resulting FX graph.
    tracer = make_fx(
        unet,
        decomposition_table={},
        tracing_mode="symbolic",
        _allow_non_fake_inputs=True,
        _allow_fake_constant=False,
    )
    fx_g = tracer(*inputs)

    print(fx_g)
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 347 KiB |
@@ -1,45 +0,0 @@
|
||||
import requests
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import json
|
||||
|
||||
|
||||
def llm_chat_test(verbose=False):
    """Send one chat-completion request to a local server and print the outcome.

    Exercises the chatbot REST API of AMDShark. Make sure AMDShark is running
    in API mode on 127.0.0.1:8080 before running this script.

    Args:
        verbose: When true, also print the reply text of a successful
            response.
    """
    # Define values here
    prompt = "What is the significance of the number 42?"

    url = "http://127.0.0.1:8080/v1/chat/completions"

    headers = {
        "User-Agent": "PythonTest",
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
    }

    data = {
        "model": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
        "messages": [
            {
                # NOTE(review): the role is sent as an empty string; confirm
                # the server does not expect "user" here.
                "role": "",
                "content": prompt,
            }
        ],
        "device": "vulkan://0",
        "max_tokens": 4096,
    }

    res = requests.post(url=url, json=data, headers=headers, timeout=1000)
    # Report the status before touching the body so a non-JSON error page
    # cannot hide it behind a decode error.
    print(f"[chat] response from server was : {res.status_code} {res.reason}")

    if res.status_code != 200:
        # Error bodies need not follow the completion schema; show them raw.
        print(f"\n{res.text}\n")
        return

    if verbose:
        res_dict = json.loads(res.content.decode("utf-8"))
        print(f"\n{res_dict['choices'][0]['message']['content']}\n")


if __name__ == "__main__":
    llm_chat_test(verbose=True)
|
||||
@@ -1,286 +0,0 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
import uvicorn
|
||||
import ipaddress
|
||||
import requests
|
||||
import threading
|
||||
import collections
|
||||
import gradio as gr
|
||||
from PIL import Image, PngImagePlugin
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
# from sdapi_v1 import amdshark_sd_api
|
||||
from apps.amdshark_studio.api.llm import llm_chat_api
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
    """Open an image from an http(s) URL, a ``data:image/`` URI or base64 text.

    Args:
        encoding: String holding either a URL starting with ``http://`` or
            ``https://``, a ``data:image/...`` URI, or bare base64 data.

    Returns:
        The object returned by ``Image.open`` for the fetched or decoded bytes.

    Raises:
        HTTPException: Status 500 with detail "Invalid image url" when the
            URL cannot be fetched or opened, or "Invalid encoded image" when
            the base64 payload cannot be extracted, decoded or opened.
    """
    if encoding.startswith(("http://", "https://")):
        # NOTE(review): this fetches a caller-supplied URL from the server
        # side; confirm that callers are trusted or restrict the targets.
        headers = {}
        try:
            # The request is inside the try so timeouts and connection
            # errors are reported like any other unusable URL.
            response = requests.get(encoding, timeout=30, headers=headers)
            return Image.open(BytesIO(response.content))
        except Exception as e:
            raise HTTPException(status_code=500, detail="Invalid image url") from e

    try:
        if encoding.startswith("data:image/"):
            # Keep only the payload after "<mime>;base64,". A URI without
            # those separators raises IndexError, which is reported below.
            encoding = encoding.split(";")[1].split(",")[1]
        return Image.open(BytesIO(base64.b64decode(encoding)))
    except Exception as e:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||
|
||||
|
||||
def encode_pil_to_base64(image):
    """Serialize ``image`` as PNG and return the base64-encoded bytes.

    Entries of ``image.info`` whose key and value are both strings are
    copied into the PNG as text chunks; other entries are ignored.
    """
    text_chunks = PngImagePlugin.PngInfo()
    has_text = False
    for key, value in image.info.items():
        if not (isinstance(key, str) and isinstance(value, str)):
            continue
        text_chunks.add_text(key, value)
        has_text = True

    with io.BytesIO() as buffer:
        image.save(
            buffer,
            format="PNG",
            pnginfo=text_chunks if has_text else None,
        )
        png_bytes = buffer.getvalue()

    return base64.b64encode(png_bytes)
|
||||
|
||||
|
||||
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
|
||||
class FIFOLock:
    """Mutual-exclusion lock that serves blocked threads in arrival order.

    ``_lock`` is the lock callers compete for, ``_inner_lock`` protects the
    bookkeeping, and ``_pending_threads`` holds one ``threading.Event`` per
    blocked thread, oldest first.

    On release the lock is handed directly to the oldest waiter and is never
    unlocked in between, so a thread that arrives later cannot take it ahead
    of threads that are already queued.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._inner_lock = threading.Lock()
        self._pending_threads = collections.deque()

    def acquire(self, blocking=True):
        """Acquire the lock.

        Args:
            blocking: When false, return immediately instead of waiting.

        Returns:
            True once the lock is held, or False when ``blocking`` is false
            and the lock is currently taken.
        """
        with self._inner_lock:
            if self._lock.acquire(False):
                return True
            if not blocking:
                return False

            # Queue up while still holding the bookkeeping lock so that the
            # queue position matches the arrival order.
            release_event = threading.Event()
            self._pending_threads.append(release_event)

        # release() sets the event while leaving ``_lock`` locked, which
        # transfers ownership to this thread.
        release_event.wait()
        return True

    def release(self):
        """Release the lock, handing it to the oldest waiter if there is one.

        Raises:
            RuntimeError: If the lock is not held and nobody is waiting
                (raised by the underlying ``threading.Lock``).
        """
        with self._inner_lock:
            if self._pending_threads:
                # Hand-off: keep ``_lock`` locked and wake the next owner.
                self._pending_threads.popleft().set()
            else:
                self._lock.release()

    __enter__ = acquire

    def __exit__(self, t, v, tb):
        self.release()
|
||||
|
||||
|
||||
def api_middleware(app: FastAPI):
    """Install request timing/logging and exception handling on ``app``.

    Registers two HTTP middlewares (``log_and_time`` and
    ``exception_handling``) and exception handlers for ``Exception`` and
    ``HTTPException``. When the ``WEBUI_RICH_EXCEPTIONS`` environment
    variable is set and ``rich`` can be imported, unexpected exceptions are
    also rendered with ``rich``.
    """
    rich_available = False
    try:
        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
            import anyio  # importing just so it can be placed on silent list
            import starlette  # importing just so it can be placed on silent list
            from rich.console import Console

            console = Console()
            rich_available = True
    except Exception:
        # Rich output is optional; on any failure fall back to plain print.
        pass

    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        """Add an ``X-Process-Time`` header and optionally log /sdapi calls."""
        # perf_counter is monotonic, so the duration cannot go negative when
        # the system clock is adjusted.
        started = time.perf_counter()
        res: Response = await call_next(req)
        duration = str(round(time.perf_counter() - started, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get("path", "err")
        if cmd_opts.api_log and endpoint.startswith("/sdapi"):
            # The scope may carry "client" with a None value, so fall back
            # for both a missing key and None.
            client = req.scope.get("client") or ("0:0.0.0", 0)
            print(
                "API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format(
                    t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                    code=res.status_code,
                    ver=req.scope.get("http_version", "0.0"),
                    cli=client[0],
                    prot=req.scope.get("scheme", "err"),
                    method=req.scope.get("method", "err"),
                    endpoint=endpoint,
                    duration=duration,
                )
            )
        return res

    def handle_exception(request: Request, e: Exception):
        """Return a JSON error for ``HTTPException``; log and re-raise others."""
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get("detail", ""),
            "body": vars(e).get("body", ""),
            "errors": str(e),
        }
        # do not print backtrace on known httpexceptions
        if not isinstance(e, HTTPException):
            message = f"API error: {request.method}: {request.url} {err}"
            print(message)
            if rich_available:
                console.print_exception(
                    show_locals=True,
                    max_frames=2,
                    extra_lines=1,
                    suppress=[anyio, starlette],
                    word_wrap=False,
                    width=min([console.width, 200]),
                )
            raise e
        return JSONResponse(
            status_code=vars(e).get("status_code", 500),
            content=jsonable_encoder(err),
        )

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)
|
||||
|
||||
|
||||
class ApiCompat:
    """Attach the chat-completion routes to a FastAPI app and serve it."""

    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Store ``app`` and ``queue_lock``, install middleware, add routes."""
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        api_middleware(self.app)

        # The /sdapi/v1/* routes (txt2img, img2img, upscaler, options,
        # samplers, checkpoints, training, scripts, ...) are not registered
        # here; only the chat endpoints below are active.

        # chat APIs needed for compatibility with multiple extensions using OpenAI API
        chat_paths = (
            "/v1/chat/completions",
            "/v1/completions",
            "/chat/completions",
            "/completions",
            "/v1/engines/codegen/completions",
        )
        for chat_path in chat_paths:
            self.add_api_route(chat_path, llm_chat_api, methods=["POST"])

        self.default_script_arg_txt2img = []
        self.default_script_arg_img2img = []

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Forward to ``self.app.add_api_route`` and return its result."""
        return self.app.add_api_route(path, endpoint, **kwargs)

    def launch(self, server_name, port, root_path):
        """Include ``self.router`` in the app and run it with uvicorn."""
        self.app.include_router(self.router)
        uvicorn.run(
            self.app,
            host=server_name,
            port=port,
            root_path=root_path,
        )
|
||||
|
||||
# def kill_studio(self):
|
||||
# restart.stop_program()
|
||||
|
||||
# def restart_studio(self):
|
||||
# if restart.is_restartable():
|
||||
# restart.restart_program()
|
||||
# return Response(status_code=501)
|
||||
|
||||
# def preprocess(self, args: dict):
|
||||
# try:
|
||||
# studio.state.begin(job="preprocess")
|
||||
# preprocess(**args)
|
||||
# studio.state.end()
|
||||
# return models.PreprocessResponse(info="preprocess complete")
|
||||
# except:
|
||||
# studio.state.end()
|
||||
|
||||
# def stop_studio(request):
|
||||
# studio.state.server_command = "stop"
|
||||
# return Response("Stopping.")
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,222 +0,0 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
|
||||
freeze_support()
|
||||
from PIL import Image
|
||||
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
import logging
|
||||
import apps.amdshark_studio.api.initializers as initialize
|
||||
|
||||
|
||||
from apps.amdshark_studio.modules import timer
|
||||
|
||||
startup_timer = timer.startup_timer
|
||||
startup_timer.record("launcher")
|
||||
|
||||
initialize.imports()
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
# import before IREE to avoid MLIR library issues
|
||||
import torch_mlir
|
||||
|
||||
|
||||
def create_api(app):
    """Return an ``ApiCompat`` for ``app`` backed by a new ``FIFOLock``."""
    from apps.amdshark_studio.web.api.compat import ApiCompat, FIFOLock

    return ApiCompat(app, FIFOLock())
|
||||
|
||||
|
||||
def api_only():
    """Initialize the studio and serve only the REST API, without the web UI.

    The server listens on 0.0.0.0 at ``cmd_opts.server_port``.
    """
    from fastapi import FastAPI
    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts

    initialize.initialize()

    fastapi_app = FastAPI()
    initialize.setup_middleware(fastapi_app)
    server = create_api(fastapi_app)

    print(f"Startup time: {startup_timer.summary()}.")
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "port": cmd_opts.server_port,
        "root_path": "",
    }
    server.launch(**launch_kwargs)
|
||||
|
||||
|
||||
def launch_webui(address):
    """Open ``address`` in a pywebview window sized relative to the screen.

    Args:
        address: URL to load in the window.
    """
    from tkinter import Tk
    import webview

    # Tk is used only to read the screen size; destroy the root afterwards
    # so it does not linger next to the webview window.
    window = Tk()
    try:
        # get screen width and height of display and make it more reasonably
        # sized as we aren't making it full-screen or maximized
        width = int(window.winfo_screenwidth() * 0.81)
        height = int(window.winfo_screenheight() * 0.91)
    finally:
        window.destroy()

    webview.create_window(
        "AMDSHARK AI Studio",
        url=address,
        width=width,
        height=height,
        text_select=True,
    )
    webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
|
||||
|
||||
def webui():
    """Build the Gradio interface with its three tabs and launch it.

    The interface has "Stable Diffusion", "Output Gallery" and "Chat Bot"
    tabs and is served on 0.0.0.0 at ``cmd_opts.server_port``.
    """
    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
    from apps.amdshark_studio.web.ui.utils import (
        amdicon_loc,
        amdlogo_loc,
    )

    # NOTE(review): read but not used below; API-in-UI mode is not wired up.
    launch_api = cmd_opts.api
    initialize.initialize()

    from ui.chat import chat_element
    from ui.sd import sd_element
    from ui.outputgallery import outputgallery_element

    # required to do multiprocessing in a pyinstaller freeze
    freeze_support()

    import gradio as gr

    def resource_path(relative_path):
        """Get absolute path to resource, works for dev and for PyInstaller"""
        base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(base_path, relative_path)

    dark_theme = resource_path("ui/css/sd_dark_theme.css")
    gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")

    def register_button_click(button, selectedid, inputs, outputs):
        """On click, forward the first item's name and select tab ``selectedid``."""
        button.click(
            lambda x: (
                x[0]["name"] if len(x) != 0 else None,
                # gr.update replaces the per-component ``Tabs.update``
                # classmethod, which newer Gradio releases no longer provide.
                gr.update(selected=selectedid),
            ),
            inputs,
            outputs,
        )

    def register_outputgallery_button(button, selectedid, inputs, outputs):
        """On click, pass the input through and select tab ``selectedid``."""
        button.click(
            lambda x: (
                x,
                gr.update(selected=selectedid),
            ),
            inputs,
            outputs,
        )

    with gr.Blocks(
        css=dark_theme,
        js=gradio_workarounds,
        analytics_enabled=False,
        title="AMDShark Studio 2.0 Beta",
    ) as studio_web:
        amd_logo = Image.open(amdlogo_loc)
        gr.Image(
            value=amd_logo,
            show_label=False,
            interactive=False,
            elem_id="tab_bar_logo",
            show_download_button=False,
        )
        with gr.Tabs() as tabs:
            # NOTE: If adding, removing, or re-ordering tabs, make sure that they
            # have a unique id that doesn't clash with any of the other tabs,
            # and that the order in the code here is the order they should
            # appear in the ui, as the id value doesn't determine the order.

            # Where possible, avoid changing the id of any tab that is the
            # destination of one of the 'send to' buttons. If you do have to change
            # that id, make sure you update the relevant register_button_click calls
            # further down with the new id.
            with gr.TabItem(label="Stable Diffusion", id=0):
                sd_element.render()
            with gr.TabItem(label="Output Gallery", id=1):
                outputgallery_element.render()
            with gr.TabItem(label="Chat Bot", id=2):
                chat_element.render()

    studio_web.queue()

    studio_web.launch(
        share=cmd_opts.share,
        inbrowser=True,
        server_name="0.0.0.0",
        server_port=cmd_opts.server_port,
        favicon_path=amdicon_loc,
    )


if __name__ == "__main__":
    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts

    # Compared with ``== False`` on purpose: only an explicit false value
    # selects API-only mode.
    if cmd_opts.webui == False:
        api_only()
    else:
        webui()
|
||||
@@ -1,239 +0,0 @@
|
||||
import gradio as gr
|
||||
import time
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import sys
|
||||
from apps.amdshark_studio.api.llm import (
|
||||
llm_model_map,
|
||||
LanguageModel,
|
||||
)
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
# Begin/end markers used when formatting chat text.
B_SYS, E_SYS = "<s>", "</s>"
|
||||
|
||||
|
||||
def user(message, history):
    """Append the user's message to the conversation as a new, unanswered turn.

    Args:
        message: Text the user entered.
        history: Existing list of ``[user_text, bot_text]`` pairs; ``None``
            or an empty value is treated as no history.

    Returns:
        A tuple of an empty string and a new history list ending in
        ``[message, ""]``. The list passed in is not modified.
    """
    return "", (history or []) + [[message, ""]]
|
||||
|
||||
|
||||
def append_bot_prompt(history, input_prompt):
    """Add ``input_prompt`` followed by two ``E_SYS`` markers to ``history``.

    The addition uses ``+=``, so a mutable ``history`` is extended in place.
    NOTE(review): with a list ``history`` this appends the text character by
    character; confirm that callers pass a string.
    """
    history += f"{input_prompt} {E_SYS} {E_SYS}"
    return history
|
||||
|
||||
|
||||
language_model = None
|
||||
|
||||
|
||||
def get_default_config():
    """Return ``False``; this module provides no default configuration."""
    return False
|
||||
|
||||
|
||||
# model_vmfb_key = ""
|
||||
|
||||
|
||||
def chat_fn(
    prompt_prefix,
    history,
    model,
    device,
    precision,
    download_vmfb,
    config_file,
    streaming_llm,
    cli=False,
):
    """Stream the reply for the last turn in ``history``.

    This is a generator. It creates the module-level ``language_model`` on
    first use and then yields ``(history, status_text)`` after every message
    produced by ``language_model.chat(history)``.

    Args:
        prompt_prefix: Passed to ``LanguageModel`` as ``use_system_prompt``.
            The value "Clear" combined with ``streaming_llm`` resets the
            cached model instead of chatting.
        history: List of ``[user_text, bot_text]`` pairs; the last pair's
            second element is overwritten with progress and reply text.
        model: Model identifier passed to ``LanguageModel``.
        device: Device passed to ``LanguageModel``.
        precision: Precision passed to ``LanguageModel``.
        download_vmfb: Not read by this function.
        config_file: Not read by this function.
        streaming_llm: Passed to ``LanguageModel``; also enables "Clear".
        cli: Not read by this function.

    Yields:
        Tuples of the updated history and a timing string.
    """
    global language_model
    if streaming_llm and prompt_prefix == "Clear":
        language_model = None
        # Inside a generator this return value is not delivered to a caller
        # that simply iterates; the call just ends without yielding.
        return "Clearing history...", ""
    if language_model is None:
        history[-1][-1] = "Getting the model ready..."
        yield history, ""
        language_model = LanguageModel(
            model,
            device=device,
            precision=precision,
            external_weights="safetensors",
            use_system_prompt=prompt_prefix,
            streaming_llm=streaming_llm,
            hf_auth_token=cmd_opts.hf_auth_token,
        )
        history[-1][-1] = "Getting the model ready... Done"
        yield history, ""
    history[-1][-1] = ""
    token_count = 0
    total_time = 0.0
    prefill_time = 0
    is_first = True
    for text, exec_time in language_model.chat(history):
        history[-1][-1] = f"{text}{E_SYS}"
        if is_first:
            # The first message's time is reported as prefill time and is
            # kept out of the decode rate.
            prefill_time = exec_time
            is_first = False
            yield history, f"Prefill: {prefill_time:.2f}"
        else:
            total_time += exec_time
            token_count += 1
            # Guard the division instead of seeding total_time with a fake
            # 1 ms, which lowered every reported rate.
            tokens_per_sec = token_count / total_time if total_time > 0 else 0.0
            yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
|
||||
|
||||
|
||||
def view_json_file(file_obj):
    """Return the text content of the file at ``file_obj.name``.

    Args:
        file_obj: Object whose ``name`` attribute is the path to read.

    Returns:
        The whole file decoded as UTF-8.
    """
    # Decode as UTF-8 explicitly instead of relying on the platform default.
    with open(file_obj.name, "r", encoding="utf-8") as fopen:
        return fopen.read()
|
||||
|
||||
|
||||
with gr.Blocks(title="Chat") as chat_element:
|
||||
with gr.Row():
|
||||
model_choices = list(llm_model_map.keys())
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
# Drop "sync" devices first and only then apply the fallback, so a device
# list made up solely of "sync" entries still leaves one selectable device.
supported_devices = [
    x for x in global_obj.get_device_list() if "sync" not in x
]
enabled = True
if not supported_devices:
    supported_devices = ["cpu-task"]
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0],
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp32",
|
||||
choices=[
|
||||
# "int4",
|
||||
# "int8",
|
||||
# "fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from AMDShark tank if available",
|
||||
value=False,
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
streaming_llm = gr.Checkbox(
|
||||
label="Run in streaming mode (requires recompilation)",
|
||||
value=True,
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
prompt_prefix = gr.Checkbox(
|
||||
label="Add System Prompt",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
chatbot = gr.Chatbot(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
interactive=enabled,
|
||||
container=False,
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(label="Upload sharding configuration", visible=False)
|
||||
json_view_button = gr.Button("View as JSON", visible=False)
|
||||
json_view = gr.JSON(visible=False)
|
||||
json_view_button.click(
|
||||
fn=view_json_file, inputs=[config_file], outputs=[json_view]
|
||||
)
|
||||
submit_event = msg.submit(
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
streaming_llm,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
streaming_llm,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
fn=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[submit_event, submit_click_event],
|
||||
queue=False,
|
||||
)
|
||||
clear.click(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
clear,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
streaming_llm,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
).then(lambda: None, None, [chatbot], queue=False)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user