Compare commits

..

1 Commits

Author SHA1 Message Date
Ean Garvey
288602ad1d Add PIL hidden imports to sd spec. 2023-06-16 11:04:22 -05:00
221 changed files with 19795 additions and 10733 deletions

View File

@@ -2,4 +2,4 @@
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
exclude = lit.cfg.py

View File

@@ -50,13 +50,12 @@ jobs:
shell: powershell
run: |
./setup_venv.ps1
python process_skipfiles.py
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip install -e .
pip freeze -l
pyinstaller .\apps\shark_studio\shark_studio.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets
@@ -75,3 +74,80 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
linux-build:
runs-on: a100
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
backend: [IREE, SHARK]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Setup pip cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
- name: Build and validate the IREE package
if: ${{ matrix.backend == 'IREE' }}
continue-on-error: true
run: |
cd $GITHUB_WORKSPACE
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
if !(grep -Fxq " failed" pytest_results.txt)
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
rm -rf ./wheelhouse/nodai*
- name: Build and validate the SHARK Runtime package
if: ${{ matrix.backend == 'SHARK' }}
run: |
cd $GITHUB_WORKSPACE
./setup_venv.sh
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt

163
.github/workflows/test-models.yml vendored Normal file
View File

@@ -0,0 +1,163 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Validate Models on Shark Runtime
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
suite: [cpu,cuda,vulkan]
python-version: ["3.11"]
include:
- os: ubuntu-latest
suite: lint
exclude:
- os: ubuntu-latest
suite: vulkan
- os: ubuntu-latest
suite: cuda
- os: ubuntu-latest
suite: cpu
- os: MacStudio
suite: cuda
- os: MacStudio
suite: cpu
- os: icelake
suite: vulkan
- os: icelake
suite: cuda
- os: a100
suite: cpu
- os: 7950x
suite: cpu
- os: 7950x
suite: cuda
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'
- name: Set Environment Variables
if: matrix.os != '7950x'
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
run: |
# See https://github.com/actions/setup-python/issues/433
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
#cache: 'pip'
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- uses: actions/checkout@v2
if: matrix.os == '7950x'
- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check .
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
- name: Validate Models on NVIDIA GPU
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest -k vulkan -s --ci
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
python build_tools/stable_diffusion_testing.py --device=vulkan

View File

@@ -1,85 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Validate Shark Studio
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [nodai-ubuntu-builder-large]
suite: [cpu] #,cuda,vulkan]
python-version: ["3.11"]
include:
- os: nodai-ubuntu-builder-large
suite: lint
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set Environment Variables
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
run: |
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check apps/shark_studio
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
python${{ matrix.python-version }} -m venv shark.venv
source shark.venv/bin/activate
pip install -r requirements.txt --no-cache-dir
pip install -e .
# Disabled due to hang when exporting test llama2
# python apps/shark_studio/tests/api_test.py

28
.gitignore vendored
View File

@@ -2,8 +2,6 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb
# C extensions
*.so
@@ -159,12 +157,12 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
#.idea/
# vscode related
.vscode
# Shark related artifacts
# Shark related artefacts
*venv/
shark_tmp/
*.vmfb
@@ -172,7 +170,6 @@ shark_tmp/
tank/dict_configs.py
*.csv
reproducers/
apps/shark_studio/web/configs
# ORT related artefacts
cache_models/
@@ -183,29 +180,10 @@ generated_imgs/
# Custom model related artefacts
variants.json
/models/
*.safetensors
models/
# models folder
apps/stable_diffusion/web/models/
# model artifacts (SHARK)
*.tempfile
*.mlir
*.vmfb
# Stencil annotators.
stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData
# Embedded browser cache and other artefacts
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/
# Webview2 runtime artefacts
EBWebView/

2
.gitmodules vendored
View File

@@ -1,4 +1,4 @@
[submodule "inference/thirdparty/shark-runtime"]
path = inference/thirdparty/shark-runtime
url =https://github.com/nod-ai/SRT.git
url =https://github.com/nod-ai/SHARK-Runtime.git
branch = shark-06032022

View File

@@ -2,20 +2,18 @@
High Performance Machine Learning Distribution
*We are currently rebuilding SHARK to take advantage of [Turbine](https://github.com/nod-ai/SHARK-Turbine). Until that is complete make sure you use an .exe release or a checkout of the `SHARK-1.0` branch, for a working SHARK*
[![Nightly Release](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
[![Validate torch-models on Shark Runtime](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
<details>
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
#### Linux Drivers
* MESA / RADV drivers won't work with FP16. Please use the latest AMDGPU-PRO drivers (non-pro OSS drivers also won't work) or the latest NVidia Linux Drivers.
@@ -24,23 +22,23 @@ Other users please ensure you have your latest vendor drivers and Vulkan SDK fro
</details>
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
Install the Driver from (Prerequisites)[https://github.com/nod-ai/SHARK#install-your-hardware-drivers] above
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Download the [stable release](https://github.com/nod-ai/shark/releases/latest) or the most recent [SHARK 1.0 pre-release](https://github.com/nod-ai/shark/releases).
Download the [stable release](https://github.com/nod-ai/shark/releases/latest)
Double click the .exe, or [run from the command line](#running) (recommended), and you should have the [UI](http://localhost:8080/) in the browser.
Double click the .exe and you should have the [UI](http://localhost:8080/) in the browser.
If you have custom models put them in a `models/` directory where the .exe is.
If you have custom models put them in a `models/` directory where the .exe is.
Enjoy.
Enjoy.
<details>
<summary>More installation notes</summary>
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
## Running
@@ -48,22 +46,17 @@ Enjoy.
* The first run may take a few minutes while the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.
* If you prefer to always run in the browser, use the `--ui=web` command argument when running the EXE.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
</details>
<details>
<summary>Advanced Installation (Only for developers)</summary>
## Advanced Installation (Windows, Linux and macOS) for developers
### Windows 10/11 Users
* Install Git for Windows from [here](https://git-scm.com/download/win) if you don't already have it.
## Check out the code
```shell
@@ -71,22 +64,14 @@ git clone https://github.com/nod-ai/SHARK.git
cd SHARK
```
## Switch to the Correct Branch (IMPORTANT!)
Currently SHARK is being rebuilt for [Turbine](https://github.com/nod-ai/SHARK-Turbine) on the `main` branch. For now you are strongly discouraged from using `main` unless you are working on the rebuild effort, and you should not expect the code there to produce a working application for image generation. Until the rebuild is complete, you'll need to switch over to the `SHARK-1.0` branch and use the stable code.
```shell
git checkout SHARK-1.0
```
The following setup instructions assume you are on this branch.
## Setup your Python VirtualEnvironment and Dependencies
### Windows 10/11 Users
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)
* Install Git for Windows from [here](https://git-scm.com/download/win)
#### Allow the install script to run in Powershell
```powershell
set-executionpolicy remotesigned
@@ -101,20 +86,21 @@ set-executionpolicy remotesigned
```shell
./setup_venv.sh
source shark1.venv/bin/activate
source shark.venv/bin/activate
```
### Run Stable Diffusion on your device - WebUI
#### Windows 10/11 Users
```powershell
(shark1.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark1.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
```
#### Linux / macOS Users
```shell
(shark1.venv) > cd apps/stable_diffusion/web
(shark1.venv) > python index.py
(shark.venv) > cd apps/stable_diffusion/web
(shark.venv) > python index.py
```
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
@@ -128,7 +114,7 @@ source shark1.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark1.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
@@ -156,7 +142,7 @@ Here are some samples generated:
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
<details>
@@ -184,7 +170,7 @@ python -m pip install --upgrade pip
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
### Run shark tank model tests.
@@ -219,7 +205,7 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
If you want to use Python3.11 and with TF Import tools you can use the environment variables like:
Set `USE_IREE=1` to use upstream IREE
```
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
```
### Run any of the hundreds of SHARK tank models via the test framework
@@ -228,7 +214,7 @@ python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use g
# Or a pytest
pytest tank/test_models.py -k "MiniLM"
```
### How to use your locally built IREE / Torch-MLIR with SHARK
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
@@ -254,12 +240,12 @@ Now the SHARK will use your locally build Torch-MLIR repo.
## Benchmarking Dispatches
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
If you only want to compile specific dispatches, you can specify them with a space-separated string instead of `"All"`, e.g. `--dispatch_benchmarks="0 1 2 10"`.
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
```
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
```
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
@@ -268,6 +254,7 @@ if you want to instead incorporate this into a python script, you can pass the `
```
shark_module = SharkInference(
mlir_model,
func_name,
device=args.device,
mlir_dialect="tm_tensor",
dispatch_benchmarks="all",
@@ -278,7 +265,7 @@ shark_module = SharkInference(
Output will include:
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
- An .mlir file containing the dispatch benchmark
- An .mlir file containing the dispatch benchmark
- A compiled .vmfb file containing the dispatch benchmark
- An .mlir file containing just the hal executable
- A compiled .vmfb file of the hal executable
@@ -310,7 +297,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
from shark.shark_inference import SharkInference
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input))
@@ -333,20 +320,15 @@ mhlo_ir = r"""builtin.module {
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
shark_module.compile()
result = shark_module.forward((arg0, arg1))
```
</details>
## Examples Using the REST API
* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md)
* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md)
## Supported and Validated Models
SHARK is maintained to support the latest innovations in ML Models:
SHARK is maintained to support the latest innovations in ML Models:
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------|----------|-------------|
@@ -372,7 +354,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
* [Upstream IREE issues](https://github.com/google/iree/issues): Feature requests,
bugs, and other work tracking
* [Upstream IREE Discord server](https://discord.gg/wEWh6Z9nMU): Daily development
* [Upstream IREE Discord server](https://discord.gg/26P4xW4): Daily development
discussions with the core team and collaborators
* [iree-discuss email list](https://groups.google.com/forum/#!forum/iree-discuss):
Announcements, general and low-priority discussion
@@ -387,7 +369,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
* Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information.
* [MLIR topic within LLVM Discourse](https://llvm.discourse.group/c/llvm-project/mlir/31): SHARK and IREE are enabled by and heavily rely on [MLIR](https://mlir.llvm.org).
</details>
## License
nod.ai SHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions.

View File

@@ -0,0 +1,210 @@
import torch
import torch_mlir
from transformers import (
AutoTokenizer,
StoppingCriteria,
)
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the newest token is a stop id.

    Only the last token of the first sequence in ``input_ids`` is inspected;
    ``scores`` and any extra keyword arguments are ignored.
    """

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        last_token = input_ids[0][-1]
        # Same id list (and order) as the module-level shouldStop helper.
        return any(
            last_token == stop_id for stop_id in (50278, 50279, 50277, 1, 0)
        )
def shouldStop(tokens):
    """Return True when the last element of ``tokens[0]`` equals a stop id.

    Args:
        tokens: Indexable batch of token sequences; only ``tokens[0][-1]``
            is examined.

    Returns:
        bool: True if that element equals one of the hard-coded stop ids,
        False otherwise.
    """
    final_token = tokens[0][-1]
    return any(final_token == stop_id for stop_id in (50278, 50279, 50277, 1, 0))
MAX_SEQUENCE_LENGTH = 256
def user(message, history):
    """Record a user message as a new turn with an empty reply slot.

    Args:
        message: The user's message.
        history: List of ``[user_message, reply]`` pairs; it is not mutated.

    Returns:
        tuple: An empty string and a new history list with
        ``[message, ""]`` appended.
    """
    extended_history = history + [[message, ""]]
    return "", extended_history
def compile_stableLM(
    model,
    model_inputs,
    model_name,
    model_vmfb_name,
    device="cuda",
    precision="fp32",
):
    """Return a SHARK module for StableLM, reusing cached artifacts if present.

    Lookup order:
      1. Ask ``get_vmfb_from_path`` for the vmfb; if it returns a module,
         that module is returned as-is.
      2. Otherwise read ``<model_name>.mlir`` if it exists.
      3. Otherwise lower ``model`` through torch-mlir, save the bytecode to
         ``<model_name>.mlir``, then compile it and save the vmfb.

    Args:
        model: Torch model to lower; only used when no cached vmfb or mlir
            file is found.
        model_inputs: Example inputs used for lowering.
        model_name: Base name for the ``.mlir`` file and the default vmfb.
        model_vmfb_name: Explicit vmfb path, or None to use
            ``<model_name>_<device>.vmfb``.
        device: Device string forwarded to SHARK.
        precision: Currently unused by this function.

    Returns:
        The loaded or freshly compiled SHARK module.
    """
    from shark.shark_inference import SharkInference

    # TODO: vmfb and mlir name should include precision and device
    vmfb_path = (
        Path(model_name + f"_{device}.vmfb")
        if model_vmfb_name is None
        else Path(model_vmfb_name)
    )
    shark_module = get_vmfb_from_path(
        vmfb_path, device, mlir_dialect="tm_tensor"
    )
    if shark_module is not None:
        return shark_module

    mlir_path = Path(model_name + ".mlir")
    print(
        f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
    )
    if mlir_path.exists():
        with open(mlir_path, "rb") as f:
            bytecode = f.read()
    else:
        ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
        module = torch_mlir.compile(
            ts_graph,
            [*model_inputs],
            torch_mlir.OutputType.LINALG_ON_TENSORS,
            use_tracing=False,
            verbose=False,
        )
        bytecode_stream = BytesIO()
        module.operation.write_bytecode(bytecode_stream)
        bytecode = bytecode_stream.getvalue()
        # Context manager closes the handle even if the write fails.
        with open(mlir_path, "wb") as f_:
            f_.write(bytecode)
        print("Saved mlir")

    shark_module = SharkInference(
        mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
    )
    shark_module.compile()

    path = shark_module.save_module(
        vmfb_path.parent.absolute(), vmfb_path.stem
    )
    print("Saved vmfb at ", str(path))

    return shark_module
class StableLMModel(torch.nn.Module):
    """Wrapper module whose forward pass returns only the wrapped model's logits."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Call the wrapped model with keyword inputs and return ``.logits``."""
        result = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return result.logits
# System prompt text for StableLM Tuned (Alpha)
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
def get_tokenizer():
    """Load the StableLM tuned-alpha 3B tokenizer and register a pad token.

    Returns:
        The tokenizer from ``AutoTokenizer.from_pretrained`` with
        ``<PAD>`` added as its ``pad_token`` special token.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        "stabilityai/stablelm-tuned-alpha-3b"
    )
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    print("Sucessfully loaded the tokenizer to the memory")
    return tokenizer
# sharkStableLM = compile_stableLM
# (
# None,
# tuple([input_ids, attention_mask]),
# "stableLM_linalg_f32_seqLen256",
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
# )
def generate(
    new_text,
    max_new_tokens,
    sharkStableLM,
    tokenizer=None,
):
    """Generate up to ``max_new_tokens`` pieces of text, one token per step.

    Each step passes the running text to ``generate_new_token``, prints the
    decoded piece, and appends it to the running text.

    Args:
        new_text: Prompt text to continue.
        max_new_tokens: Upper bound on the number of generation steps.
        sharkStableLM: Compiled model handle forwarded to
            ``generate_new_token``.
        tokenizer: Tokenizer to use; loaded via ``get_tokenizer`` when None.

    Returns:
        list: Decoded pieces in generation order. If a step decodes to an
        empty string, that empty string is the last element.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer()

    pieces = []
    running_text = new_text
    for _ in range(max_new_tokens):
        step = generate_new_token(
            sharkStableLM, tokenizer, {"new_text": running_text}
        )
        piece = step["detok"]

        # A stop token ends generation before the piece is printed or stored.
        if step["stop_generation"]:
            break

        print(piece, end="", flush=True)
        pieces.append(piece)

        # An empty decode is still recorded, but ends generation.
        if piece == "":
            break
        running_text = running_text + piece

    return pieces
def generate_new_token(shark_model, tokenizer, params):
    """Run one forward pass and pick the top-1 next token.

    Args:
        shark_model: Callable invoked as ``shark_model("forward", [ids, mask])``;
            its result is passed to ``torch.from_numpy``.
        tokenizer: Tokenizer used to encode the text and decode the new token.
        params: Dict holding the current text under ``"new_text"``.

    Returns:
        dict: ``"new_token"`` (top-1 index at the last attended position),
        ``"detok"`` (its decoded text with special tokens skipped), and
        ``"stop_generation"`` (bool from ``shouldStop``).
    """
    encoded = tokenizer(
        [params["new_text"]],
        padding="max_length",
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    # Number of attended positions; used below to find the last real position.
    attended_count = torch.sum(encoded.attention_mask)

    raw_output = shark_model(
        "forward", [encoded.input_ids, encoded.attention_mask]
    )
    top1 = torch.topk(torch.from_numpy(raw_output), 1)

    # NOTE(review): shouldStop inspects the final sequence position, while the
    # new token below is read at attended_count - 1 — confirm this is intended.
    stop_generation = True if shouldStop(top1.indices) else False

    new_token = top1.indices[0][int(attended_count) - 1]
    detok = tokenizer.decode(
        new_token,
        skip_special_tokens=True,
    )

    return {
        "new_token": new_token,
        "detok": detok,
        "stop_generation": stop_generation,
    }

View File

@@ -0,0 +1,122 @@
import argparse
from pathlib import Path
from apps.language_models.src.pipelines import vicuna_pipeline as vp
from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
import torch
# Command-line options for the vicuna runner script.
parser = argparse.ArgumentParser(
    description="runs a vicuna model",
    prog="vicuna runner",
)

# Numeric precision and execution device.
parser.add_argument(
    "--precision", "-p", help="fp32, fp16, int8, int4", default="fp32"
)
parser.add_argument("--device", "-d", help="vulkan, cpu, cuda", default="cuda")

parser.add_argument(
    "--first_vicuna_vmfb_path", help="path to first vicuna vmfb", default=None
)
parser.add_argument(
    "-s",
    "--sharded",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Run model as sharded",
)
# TODO: sharded config

# Remaining artifact locations; None means "derive a default file name".
parser.add_argument(
    "--second_vicuna_vmfb_path", help="path to second vicuna vmfb", default=None
)
parser.add_argument(
    "--first_vicuna_mlir_path",
    help="path to first vicuna mlir file",
    default=None,
)
parser.add_argument(
    "--second_vicuna_mlir_path", help="path to second vicuna mlir", default=None
)

# Boolean switches; each also accepts its --no-<name> form.
parser.add_argument(
    "--load_mlir_from_shark_tank",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="download precompile mlir from shark tank",
)
parser.add_argument(
    "--cli",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Run model in cli mode",
)
if __name__ == "__main__":
    import gc

    # Unrecognised flags are collected instead of aborting the run.
    args, unknown = parser.parse_known_args()

    vic = None
    if args.sharded:
        vic = vsp.Vicuna(
            "vicuna",
            device=args.device,
            precision=args.precision,
        )
    else:
        # Each artifact path falls back to a name derived from the precision
        # (and, for vmfb files, the device) when no explicit path was given.
        if args.first_vicuna_mlir_path is None:
            first_vic_mlir_path = Path(f"first_vicuna_{args.precision}.mlir")
        else:
            first_vic_mlir_path = Path(args.first_vicuna_mlir_path)

        if args.second_vicuna_mlir_path is None:
            second_vic_mlir_path = Path(f"second_vicuna_{args.precision}.mlir")
        else:
            second_vic_mlir_path = Path(args.second_vicuna_mlir_path)

        if args.first_vicuna_vmfb_path is None:
            first_vic_vmfb_path = Path(
                f"first_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
            )
        else:
            first_vic_vmfb_path = Path(args.first_vicuna_vmfb_path)

        if args.second_vicuna_vmfb_path is None:
            second_vic_vmfb_path = Path(
                f"second_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
            )
        else:
            second_vic_vmfb_path = Path(args.second_vicuna_vmfb_path)

        vic = vp.Vicuna(
            "vicuna",
            device=args.device,
            precision=args.precision,
            first_vicuna_mlir_path=first_vic_mlir_path,
            second_vicuna_mlir_path=second_vic_mlir_path,
            first_vicuna_vmfb_path=first_vic_vmfb_path,
            second_vicuna_vmfb_path=second_vic_vmfb_path,
            load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
        )

    prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
    prologue_prompt = "ASSISTANT:\n"

    # Interactive chat loop: the whole conversation so far is re-sent each turn.
    while True:
        # TODO: Add break condition from user input
        user_prompt = input("User: ")
        prompt_history = "".join(
            (prompt_history, "USER:\n", user_prompt, prologue_prompt)
        )
        prompt = prompt_history.strip()
        res_str = vic.generate(prompt, cli=True)
        # Release cached memory between turns.
        torch.cuda.empty_cache()
        gc.collect()
        print(
            "\n-----\nAssistant: Here's the complete formatted reply:\n",
            res_str,
        )
        prompt_history = prompt_history + f"\n{res_str}\n"

View File

@@ -0,0 +1,22 @@
import torch
class FalconModel(torch.nn.Module):
    """Wrapper that exposes only the last-position slice of a model's first output."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Call the wrapped model without a cache input and slice its first output.

        Returns:
            ``outputs[0][:, -1, :]`` where ``outputs`` is the wrapped model's
            result.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=None,
            use_cache=True,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        first_output = outputs[0]
        return first_output[:, -1, :]

View File

@@ -0,0 +1,15 @@
import torch
class StableLMModel(torch.nn.Module):
    """Wrapper whose forward returns only the ``logits`` of the wrapped model."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Forward both inputs by keyword and return the result's ``logits``."""
        result = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return result.logits

View File

@@ -0,0 +1,261 @@
import torch
from transformers import AutoModelForCausalLM
class FirstVicuna(torch.nn.Module):
    """Prompt-pass wrapper returning logits and the key/value cache as a flat tuple."""

    def __init__(self, model_path):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, torch_dtype=torch.float32
        )

    def forward(self, input_ids):
        """Run the model with caching enabled.

        Returns:
            tuple: ``(logits, k0, v0, k1, v1, ...)`` where each ``(k, v)`` is
            the first two entries of one item in ``past_key_values``.
        """
        result = self.model(input_ids=input_ids, use_cache=True)
        flattened = [result.logits]
        for layer_cache in result.past_key_values:
            flattened.extend((layer_cache[0], layer_cache[1]))
        return tuple(flattened)
class SecondVicuna(torch.nn.Module):
    """Decode-step wrapper taking the token and the key/value cache as flat arguments."""

    def __init__(self, model_path):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, torch_dtype=torch.float32
        )

    def forward(
        self, i0,
        i1, i2, i3, i4, i5, i6, i7, i8,
        i9, i10, i11, i12, i13, i14, i15, i16,
        i17, i18, i19, i20, i21, i22, i23, i24,
        i25, i26, i27, i28, i29, i30, i31, i32,
        i33, i34, i35, i36, i37, i38, i39, i40,
        i41, i42, i43, i44, i45, i46, i47, i48,
        i49, i50, i51, i52, i53, i54, i55, i56,
        i57, i58, i59, i60, i61, i62, i63, i64,
    ):
        """Run one step with an explicit cache.

        Args:
            i0: Passed to the model as ``input_ids``.
            i1..i64: Cache entries; consecutive arguments are paired as
                ``(i1, i2), (i3, i4), ..., (i63, i64)`` to form
                ``past_key_values``.

        Returns:
            tuple: ``(logits, k0, v0, k1, v1, ...)`` flattened from the
            model's ``logits`` and ``past_key_values``.
        """
        flat_cache = (
            i1, i2, i3, i4, i5, i6, i7, i8,
            i9, i10, i11, i12, i13, i14, i15, i16,
            i17, i18, i19, i20, i21, i22, i23, i24,
            i25, i26, i27, i28, i29, i30, i31, i32,
            i33, i34, i35, i36, i37, i38, i39, i40,
            i41, i42, i43, i44, i45, i46, i47, i48,
            i49, i50, i51, i52, i53, i54, i55, i56,
            i57, i58, i59, i60, i61, i62, i63, i64,
        )
        # Odd-numbered arguments are first in each pair, even-numbered second.
        past_key_values = tuple(zip(flat_cache[0::2], flat_cache[1::2]))

        result = self.model(
            input_ids=i0, use_cache=True, past_key_values=past_key_values
        )
        flattened = [result.logits]
        for layer_cache in result.past_key_values:
            flattened.extend((layer_cache[0], layer_cache[1]))
        return tuple(flattened)
class CombinedModel(torch.nn.Module):
    """Chain ``FirstVicuna`` and ``SecondVicuna`` into a single forward pass.

    The first model processes ``input_ids``; the arg-max token at its last
    position, together with the returned cache tensors, is then fed to the
    second model.
    """

    def __init__(
        self,
        first_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
        second_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
    ):
        super().__init__()
        self.first_vicuna = FirstVicuna(first_vicuna_model_path)
        self.second_vicuna = SecondVicuna(second_vicuna_model_path)

    def forward(self, input_ids):
        """Run both stages and return the second stage's output tuple.

        Args:
            input_ids: Token ids handed to the first model.

        Returns:
            Whatever ``self.second_vicuna`` returns for the selected token
            and the first model's cache entries.
        """
        # FirstVicuna.forward accepts only input_ids (it enables use_cache
        # itself), so no extra keyword may be passed here.
        first_output = self.first_vicuna(input_ids)
        logits = first_output[0]
        pkv = first_output[1:]
        # as_tensor avoids the copy (and warning) of re-wrapping a tensor.
        token = torch.argmax(torch.as_tensor(logits)[:, -1, :], dim=1)
        token = token.to(torch.int64).reshape([1, 1])
        # SecondVicuna.forward takes the token and every cache tensor as
        # separate positional arguments, so the tuple must be unpacked.
        return self.second_vicuna(token, *pkv)

View File

@@ -0,0 +1,178 @@
import torch
class FirstVicunaLayer(torch.nn.Module):
    """Layer wrapper (no incoming cache) that flattens the layer's outputs."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, hidden_states, attention_mask, position_ids):
        """Run the wrapped layer with caching enabled.

        Returns:
            tuple: ``(outputs[0], outputs[-1][0], outputs[-1][1])`` of the
            wrapped layer's result.
        """
        layer_out = self.model(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=True,
        )
        hidden = layer_out[0]
        present = layer_out[-1]
        return hidden, present[0], present[1]
class SecondVicunaLayer(torch.nn.Module):
    """Layer wrapper that takes the incoming cache as two separate arguments."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value0,
        past_key_value1,
    ):
        """Run the wrapped layer with an explicit ``past_key_value`` pair.

        Returns:
            tuple: ``(outputs[0], outputs[-1][0], outputs[-1][1])`` of the
            wrapped layer's result.
        """
        incoming_cache = (past_key_value0, past_key_value1)
        layer_out = self.model(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=incoming_cache,
            use_cache=True,
        )
        hidden = layer_out[0]
        present = layer_out[-1]
        return hidden, present[0], present[1]
class CompiledFirstVicunaLayer(torch.nn.Module):
    """Adapter that runs a compiled module in place of a layer (no incoming cache)."""

    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value=None,
        output_attentions=False,
        use_cache=True,
    ):
        """Invoke the compiled ``forward`` entry point on detached inputs.

        ``past_key_value``, ``output_attentions`` and ``use_cache`` are
        accepted but not used.

        Returns:
            tuple: ``(out0, (out1, out2))`` with each of the module's first
            three outputs converted via ``torch.tensor``.
        """
        detached_inputs = (
            hidden_states.detach(),
            attention_mask.detach(),
            position_ids.detach(),
        )
        raw = self.model("forward", detached_inputs)
        hidden = torch.tensor(raw[0])
        cache_first = torch.tensor(raw[1])
        cache_second = torch.tensor(raw[2])
        return hidden, (cache_first, cache_second)
class CompiledSecondVicunaLayer(torch.nn.Module):
    """Adapter that runs a compiled module in place of a layer (with incoming cache)."""

    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions=False,
        use_cache=True,
    ):
        """Invoke the compiled ``forward`` entry point on detached inputs.

        The first two entries of ``past_key_value`` are passed as the fourth
        and fifth inputs. ``output_attentions`` and ``use_cache`` are
        accepted but not used.

        Returns:
            tuple: ``(out0, (out1, out2))`` with each of the module's first
            three outputs converted via ``torch.tensor``.
        """
        detached_inputs = (
            hidden_states.detach(),
            attention_mask.detach(),
            position_ids.detach(),
            past_key_value[0].detach(),
            past_key_value[1].detach(),
        )
        raw = self.model("forward", detached_inputs)
        hidden = torch.tensor(raw[0])
        cache_first = torch.tensor(raw[1])
        cache_second = torch.tensor(raw[2])
        return hidden, (cache_first, cache_second)
class ShardedVicunaModel(torch.nn.Module):
    """Model wrapper that swaps between two layer lists depending on the call phase."""

    def __init__(self, model, layers0, layers1):
        super().__init__()
        self.model = model
        assert len(layers0) == len(model.model.layers)
        # self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
        inner_config = self.model.model.config
        inner_config.use_cache = True
        inner_config.output_attentions = False
        self.layers0, self.layers1 = layers0, layers1

    def forward(
        self,
        input_ids,
        is_first=True,
        past_key_values=None,
        attention_mask=None,
    ):
        """Install the layer list for this phase and run the wrapped model.

        ``layers0`` is installed (and ``past_key_values`` not forwarded) when
        ``is_first`` is truthy; otherwise ``layers1`` is installed and
        ``past_key_values`` is forwarded.
        """
        active_layers = self.layers0 if is_first else self.layers1
        self.model.model.layers = torch.nn.ModuleList(active_layers)
        if is_first:
            return self.model.forward(input_ids, attention_mask=attention_mask)
        return self.model.forward(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )

View File

@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod
class SharkLLMBase(ABC):
    """Abstract base declaring the hooks a language-model pipeline must provide.

    A subclass cannot be instantiated until it overrides ``compile``,
    ``generate``, ``generate_new_token``, ``get_tokenizer`` and
    ``get_src_model``.
    """

    def __init__(
        self, model_name, hf_model_path=None, max_num_tokens=512
    ) -> None:
        self.model_name, self.hf_model_path = model_name, hf_model_path
        self.max_num_tokens = max_num_tokens
        # Initial values for the remaining attributes.
        self.shark_model = None
        self.device, self.precision = "cpu", "fp32"

    @classmethod
    @abstractmethod
    def compile(self):
        """Abstract hook; the base implementation does nothing."""

    @classmethod
    @abstractmethod
    def generate(self, prompt):
        """Abstract hook; the base implementation does nothing."""

    @classmethod
    @abstractmethod
    def generate_new_token(self, params):
        """Abstract hook; the base implementation does nothing."""

    @classmethod
    @abstractmethod
    def get_tokenizer(self):
        """Abstract hook; the base implementation does nothing."""

    @classmethod
    @abstractmethod
    def get_src_model(self):
        """Abstract hook; the base implementation does nothing."""

    def load_init_from_config(self):
        """Optional hook; does nothing in the base class."""

View File

@@ -0,0 +1,473 @@
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
)
from io import BytesIO
from pathlib import Path
from contextlib import redirect_stdout
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
StoppingCriteriaList,
)
import copy
import re
import torch
import torch_mlir
import os
import argparse
# Command-line options for the falcon runner script.
parser = argparse.ArgumentParser(
    description="runs a falcon model",
    prog="falcon runner",
)

# Numeric precision and execution device.
parser.add_argument(
    "--precision", "-p", help="fp32, fp16, int8, int4", default="fp32"
)
parser.add_argument("--device", "-d", help="vulkan, cpu, cuda", default="cuda")

# Artifact locations; None means "use the default file name".
parser.add_argument(
    "--falcon_vmfb_path", help="path to falcon's vmfb", default=None
)
parser.add_argument(
    "--falcon_mlir_path", help="path to falcon's mlir file", default=None
)

# Boolean switches; each also accepts its --no-<name> form.
parser.add_argument(
    "--load_mlir_from_shark_tank",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="download precompile mlir from shark tank",
)
parser.add_argument(
    "--cli",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Run model in cli mode",
)
class Falcon(SharkLLMBase):
    """Falcon text-generation pipeline driven by a SHARK-compiled module.

    The Hugging Face source model supplies the generation configuration and
    input preparation helpers, and is traced when no cached MLIR/vmfb is
    available; the compiled module produces the logits for every step.
    """

    def __init__(
        self,
        model_name,
        hf_model_path="tiiuae/falcon-7b-instruct",
        max_num_tokens=150,
        device="cuda",
        precision="fp32",
        falcon_mlir_path=Path("falcon.mlir"),
        falcon_vmfb_path=Path("falcon.vmfb"),
        load_mlir_from_shark_tank=None,
    ) -> None:
        """Load the tokenizer and source model, then obtain the compiled module.

        Args:
            model_name: Name stored on the pipeline.
            hf_model_path: Hugging Face model id or path.
            max_num_tokens: Used as ``max_length`` and as the step bound in
                ``generate``.
            device: Target device string.
            precision: Precision string; ``"fp16"`` enables f16 import.
            falcon_mlir_path: ``Path`` of the cached MLIR file.
            falcon_vmfb_path: ``Path`` of the cached vmfb file.
            load_mlir_from_shark_tank: Whether to try downloading prebuilt
                MLIR. ``None`` (default) uses the script's parsed ``args``
                when they exist and ``False`` otherwise.
        """
        super().__init__(model_name, hf_model_path, max_num_tokens)
        self.max_padding_length = 100
        self.device = device
        self.precision = precision
        self.falcon_vmfb_path = falcon_vmfb_path
        self.falcon_mlir_path = falcon_mlir_path
        if load_mlir_from_shark_tank is None:
            # ``args`` is a module global only when this file runs as a
            # script; fall back to False when the class is imported.
            cli_args = globals().get("args")
            load_mlir_from_shark_tank = bool(
                getattr(cli_args, "load_mlir_from_shark_tank", False)
            )
        self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
        self.tokenizer = self.get_tokenizer()
        # The source model must exist before compile(): compile_falcon traces
        # self.src_model when it has to generate MLIR itself.
        self.src_model = self.get_src_model()
        self.shark_model = self.compile()

    def get_tokenizer(self):
        """Return the tokenizer, configured for left padding with pad id 11."""
        tokenizer = AutoTokenizer.from_pretrained(
            self.hf_model_path, trust_remote_code=True
        )
        tokenizer.padding_side = "left"
        tokenizer.pad_token_id = 11
        return tokenizer

    def get_src_model(self):
        """Load and return the Hugging Face source model."""
        print("Loading src model")
        kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
        falcon_model = AutoModelForCausalLM.from_pretrained(
            self.hf_model_path, **kwargs
        )
        return falcon_model

    def compile_falcon(self):
        """Return a loaded SHARK module, building it when no vmfb is cached.

        Resolution order for the MLIR bytes: existing ``falcon_mlir_path``,
        download from shark_tank (fp32 only, when enabled), otherwise trace
        the source model and lower it with torch-mlir.

        Raises:
            ValueError: If the MLIR file is still missing after a download.
        """
        vmfb = get_vmfb_from_path(self.falcon_vmfb_path, self.device, "linalg")
        if vmfb is not None:
            return vmfb

        print(
            f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
            f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
        )
        if self.falcon_mlir_path.exists():
            with open(self.falcon_mlir_path, "rb") as f:
                bytecode = f.read()
        else:
            mlir_generated = False
            if self.load_mlir_from_shark_tank:
                if self.precision == "fp32":
                    # download MLIR from shark_tank for fp32
                    download_public_file(
                        "gs://shark_tank/falcon/7b/cuda/falcon.mlir",
                        self.falcon_mlir_path.absolute(),
                        single_file=True,
                    )
                    if self.falcon_mlir_path.exists():
                        with open(self.falcon_mlir_path, "rb") as f:
                            bytecode = f.read()
                        mlir_generated = True
                    else:
                        raise ValueError(
                            f"MLIR not found at {self.falcon_mlir_path.absolute()}"
                            " after downloading! Please check path and try again"
                        )
                else:
                    print(
                        "Only fp32 mlir added to tank, generating mlir on device."
                    )

            if not mlir_generated:
                compilation_input_ids = torch.randint(
                    low=1, high=10000, size=(1, 100)
                )
                compilation_attention_mask = torch.ones(
                    1, 100, dtype=torch.int64
                )
                falconCompileInput = (
                    compilation_input_ids,
                    compilation_attention_mask,
                )
                model = FalconModel(self.src_model)

                print(f"[DEBUG] generating torchscript graph")
                ts_graph = import_with_fx(
                    model,
                    falconCompileInput,
                    is_f16=self.precision == "fp16",
                    f16_input_mask=[False, False],
                    mlir_type="torchscript",
                )
                del model
                print(f"[DEBUG] generating torch mlir")

                module = torch_mlir.compile(
                    ts_graph,
                    [*falconCompileInput],
                    torch_mlir.OutputType.LINALG_ON_TENSORS,
                    use_tracing=False,
                    verbose=False,
                )
                del ts_graph

                print(f"[DEBUG] converting to bytecode")
                bytecode_stream = BytesIO()
                module.operation.write_bytecode(bytecode_stream)
                bytecode = bytecode_stream.getvalue()
                del module

                # Cache the bytes at the path the read branch above loads
                # from. The module was already released, so only the
                # serialized bytecode can be written.
                print(f"[DEBUG] writing mlir to file")
                with open(self.falcon_mlir_path, "wb") as f_:
                    f_.write(bytecode)

        shark_module = SharkInference(
            mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
        )
        path = shark_module.save_module(
            self.falcon_vmfb_path.parent.absolute(),
            self.falcon_vmfb_path.stem,
            extra_args=[
                "--iree-hal-dump-executable-sources-to=ies",
                "--iree-vm-target-truncate-unsupported-floats",
                "--iree-codegen-check-ir-before-llvm-conversion=false",
                "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
            ],
        )
        print("Saved falcon vmfb at ", str(path))
        shark_module.load_module(path)

        return shark_module

    def compile(self):
        """Fetch the prebuilt cuda/fp32 vmfb when missing, then compile/load."""
        if (
            not self.falcon_vmfb_path.exists()
            and self.device == "cuda"
            and self.precision == "fp32"
        ):
            download_public_file(
                "gs://shark_tank/falcon/7b/cuda/falcon.vmfb",
                self.falcon_vmfb_path.absolute(),
                single_file=True,
            )

        falcon_shark_model = self.compile_falcon()
        return falcon_shark_model

    def generate(self, prompt):
        """Sample a continuation of ``prompt`` and return prompt + continuation.

        Each new piece is also printed as it is decoded. Sampling state
        (processors, warpers, running ids, finished mask) is stored on the
        instance for ``generate_new_token``.

        Args:
            prompt: Prompt text.

        Returns:
            str: ``prompt`` followed by every decoded piece.
        """
        # Local import: the module only imports gc under ``__main__``.
        import gc

        model_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.max_padding_length,
            add_special_tokens=False,
            return_tensors="pt",
        )
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)

        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None

        generate_kwargs = {
            "max_length": self.max_num_tokens,
            "do_sample": True,
            "top_k": 10,
            "num_return_sequences": 1,
            "eos_token_id": 11,
        }
        generate_kwargs["input_ids"] = input_ids
        generate_kwargs["attention_mask"] = attention_mask
        generation_config_ = GenerationConfig.from_model_config(
            self.src_model.config
        )
        generation_config = copy.deepcopy(generation_config_)
        model_kwargs = generation_config.update(**generate_kwargs)

        logits_processor = LogitsProcessorList()
        stopping_criteria = StoppingCriteriaList()

        eos_token_id = generation_config.eos_token_id
        generation_config.pad_token_id = eos_token_id

        (
            inputs_tensor,
            model_input_name,
            model_kwargs,
        ) = self.src_model._prepare_model_inputs(
            None, generation_config.bos_token_id, model_kwargs
        )

        model_kwargs["output_attentions"] = generation_config.output_attentions
        model_kwargs[
            "output_hidden_states"
        ] = generation_config.output_hidden_states
        model_kwargs["use_cache"] = generation_config.use_cache

        input_ids = (
            inputs_tensor
            if model_input_name == "input_ids"
            else model_kwargs.pop("input_ids")
        )

        self.logits_processor = self.src_model._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids.shape[-1],
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=None,
            logits_processor=logits_processor,
        )

        self.stopping_criteria = self.src_model._get_stopping_criteria(
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
        )

        self.logits_warper = self.src_model._get_logits_warper(
            generation_config
        )

        (
            self.input_ids,
            self.model_kwargs,
        ) = self.src_model._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_return_sequences,  # 1
            is_encoder_decoder=self.src_model.config.is_encoder_decoder,  # False
            **model_kwargs,
        )

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        self.eos_token_id_tensor = (
            torch.tensor(eos_token_id) if eos_token_id is not None else None
        )

        self.pad_token_id = generation_config.pad_token_id
        self.eos_token_id = eos_token_id

        output_scores = generation_config.output_scores  # False
        return_dict_in_generate = (
            generation_config.return_dict_in_generate  # False
        )

        # init scores tuple
        self.scores = (
            () if (return_dict_in_generate and output_scores) else None
        )

        # keep track of which sequences are already finished
        self.unfinished_sequences = torch.ones(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )

        all_text = prompt

        for i in range(self.max_num_tokens - 1):
            next_token = self.generate_new_token()
            new_word = self.tokenizer.decode(
                next_token.cpu().numpy(),
                add_special_tokens=False,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            all_text = all_text + new_word

            print(f"{new_word}", end="", flush=True)

            # if eos_token was found in one sentence, set sentence to finished
            if self.eos_token_id_tensor is not None:
                self.unfinished_sequences = self.unfinished_sequences.mul(
                    next_token.tile(self.eos_token_id_tensor.shape[0], 1)
                    .ne(self.eos_token_id_tensor.unsqueeze(1))
                    .prod(dim=0)
                )
            # stop when each sentence is finished
            if (
                self.unfinished_sequences.max() == 0
                or self.stopping_criteria(input_ids, self.scores)
            ):
                break

        torch.cuda.empty_cache()
        gc.collect()

        return all_text

    def generate_new_token(self):
        """Sample one token and slide the running inputs forward by one.

        Uses the state prepared by ``generate``. After sampling, the token is
        appended to ``self.input_ids`` and the attention mask is extended;
        both then drop their first column so their length stays constant.

        Returns:
            The sampled token tensor.

        Raises:
            ValueError: If an eos id is set but no pad id is available.
        """
        model_inputs = self.src_model.prepare_inputs_for_generation(
            self.input_ids, **self.model_kwargs
        )
        outputs = torch.from_numpy(
            self.shark_model(
                "forward",
                (model_inputs["input_ids"], model_inputs["attention_mask"]),
            )
        )
        next_token_logits = outputs

        # pre-process distribution
        next_token_scores = self.logits_processor(
            self.input_ids, next_token_logits
        )
        next_token_scores = self.logits_warper(
            self.input_ids, next_token_scores
        )

        # sample
        probs = torch.nn.functional.softmax(next_token_scores, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if self.eos_token_id is not None:
            if self.pad_token_id is None:
                raise ValueError(
                    "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
                )
            next_token = (
                next_token * self.unfinished_sequences
                + self.pad_token_id * (1 - self.unfinished_sequences)
            )

        self.input_ids = torch.cat(
            [self.input_ids, next_token[:, None]], dim=-1
        )

        # No cache is carried between steps.
        self.model_kwargs["past_key_values"] = None
        if "attention_mask" in self.model_kwargs:
            attention_mask = self.model_kwargs["attention_mask"]
            self.model_kwargs["attention_mask"] = torch.cat(
                [
                    attention_mask,
                    attention_mask.new_ones((attention_mask.shape[0], 1)),
                ],
                dim=-1,
            )

        # Drop the oldest column so the inputs keep a fixed length.
        self.input_ids = self.input_ids[:, 1:]
        self.model_kwargs["attention_mask"] = self.model_kwargs[
            "attention_mask"
        ][:, 1:]

        return next_token
if __name__ == "__main__":
    import gc

    def _answered_true(reply):
        """Interpret a typed True/False answer; unrecognised or empty input is False."""
        return reply.strip().lower() in ("true", "t", "yes", "y", "1")

    args = parser.parse_args()

    falcon_mlir_path = (
        Path("falcon.mlir")
        if args.falcon_mlir_path is None
        else Path(args.falcon_mlir_path)
    )
    falcon_vmfb_path = (
        Path("falcon.vmfb")
        if args.falcon_vmfb_path is None
        else Path(args.falcon_vmfb_path)
    )

    falcon = Falcon(
        "falcon",
        device=args.device,
        precision=args.precision,
        falcon_mlir_path=falcon_mlir_path,
        falcon_vmfb_path=falcon_vmfb_path,
    )

    default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
    continue_execution = True

    while continue_execution:
        # input() returns a string, and any non-empty string (including
        # "False") is truthy, so the reply has to be parsed explicitly.
        use_default_prompt = _answered_true(
            input(
                "\nDo you wish to use the default prompt text? True or False?: "
            )
        )
        if use_default_prompt:
            prompt = default_prompt_text
        else:
            prompt = input("Please enter the prompt text: ")
        print("\nPrompt Text: ", prompt)

        res_str = falcon.generate(prompt)
        torch.cuda.empty_cache()
        gc.collect()
        print(
            "\n\n-----\nHere's the complete formatted result: \n\n",
            res_str,
        )
        continue_execution = _answered_true(
            input(
                "\nDo you wish to run script one more time? True or False?: "
            )
        )

View File

@@ -0,0 +1,185 @@
import torch
import torch_mlir
from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.stablelm_model import (
StableLMModel,
)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the last id of the first row is a stop id."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        last_id = input_ids[0][-1]
        # Same stop ids and comparison order as a plain loop with early return.
        return any(last_id == stop_id for stop_id in (50278, 50279, 50277, 1, 0))
class SharkStableLM(SharkLLMBase):
    """StableLM pipeline that generates greedily with a SHARK-compiled module."""

    def __init__(
        self,
        model_name,
        hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
        max_num_tokens=512,
        device="cuda",
        precision="fp32",
    ) -> None:
        """Load the tokenizer and obtain the compiled module.

        Args:
            model_name: Name stored on the pipeline.
            hf_model_path: Hugging Face model id or path.
            max_num_tokens: Maximum number of generation steps.
            device: Target device string (also part of the vmfb file name).
            precision: Precision string (part of the artifact file names).
        """
        super().__init__(model_name, hf_model_path, max_num_tokens)
        self.max_sequence_len = 256
        self.device = device
        self.precision = precision
        self.tokenizer = self.get_tokenizer()
        self.shark_model = self.compile()

    def shouldStop(self, tokens):
        """Return True when ``tokens[0][-1]`` equals one of the stop ids."""
        stop_ids = [50278, 50279, 50277, 1, 0]
        for stop_id in stop_ids:
            if tokens[0][-1] == stop_id:
                return True
        return False

    def get_src_model(self):
        """Load and return the Hugging Face source model in float32."""
        model = AutoModelForCausalLM.from_pretrained(
            self.hf_model_path, torch_dtype=torch.float32
        )
        return model

    def get_model_inputs(self):
        """Return random ``(input_ids, attention_mask)`` of shape (1, max_sequence_len)."""
        input_ids = torch.randint(3, (1, self.max_sequence_len))
        attention_mask = torch.randint(3, (1, self.max_sequence_len))
        return input_ids, attention_mask

    def compile(self):
        """Return a SHARK module, reusing cached vmfb/MLIR files when present.

        File names are derived from the precision, the sequence length and
        (for the vmfb) the device.
        """
        tmp_model_name = (
            f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
        )

        # TODO: vmfb and mlir name should include precision and device
        vmfb_path = Path(tmp_model_name + f"_{self.device}.vmfb")
        shark_module = get_vmfb_from_path(
            vmfb_path, self.device, mlir_dialect="tm_tensor"
        )
        if shark_module is not None:
            return shark_module

        mlir_path = Path(tmp_model_name + ".mlir")
        print(
            f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
        )
        if mlir_path.exists():
            with open(mlir_path, "rb") as f:
                bytecode = f.read()
        else:
            model = StableLMModel(self.get_src_model())
            model_inputs = self.get_model_inputs()
            ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
            module = torch_mlir.compile(
                ts_graph,
                [*model_inputs],
                torch_mlir.OutputType.LINALG_ON_TENSORS,
                use_tracing=False,
                verbose=False,
            )
            bytecode_stream = BytesIO()
            module.operation.write_bytecode(bytecode_stream)
            bytecode = bytecode_stream.getvalue()
            # Context manager guarantees the handle is closed even if the
            # write fails.
            with open(tmp_model_name + ".mlir", "wb") as f_:
                f_.write(bytecode)
            print("Saved mlir")

        from shark.shark_inference import SharkInference

        shark_module = SharkInference(
            mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
        )
        shark_module.compile()

        path = shark_module.save_module(
            vmfb_path.parent.absolute(), vmfb_path.stem
        )
        print("Saved vmfb at ", str(path))

        return shark_module

    def get_tokenizer(self):
        """Return the tokenizer with ``<PAD>`` registered as its pad token."""
        tok = AutoTokenizer.from_pretrained(self.hf_model_path)
        tok.add_special_tokens({"pad_token": "<PAD>"})
        # print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
        return tok

    def generate(self, prompt):
        """Extend ``prompt`` token by token and return the decoded pieces.

        Returns:
            list: Decoded pieces in generation order.
        """
        words_list = []
        for i in range(self.max_num_tokens):
            params = {
                "new_text": prompt,
            }
            generated_token_op = self.generate_new_token(params)

            detok = generated_token_op["detok"]
            stop_generation = generated_token_op["stop_generation"]

            if stop_generation:
                break

            print(detok, end="", flush=True)  # this is for CLI and DEBUG
            words_list.append(detok)
            if detok == "":
                break
            prompt = prompt + detok
        return words_list

    def generate_new_token(self, params):
        """Run one forward pass and return the top-1 next token.

        Args:
            params: Dict holding the current text under ``"new_text"``.

        Returns:
            dict: ``"new_token"``, ``"detok"`` (decoded string) and
            ``"stop_generation"`` (True when the new token is a stop id).
        """
        new_text = params["new_text"]
        model_inputs = self.tokenizer(
            [new_text],
            padding="max_length",
            max_length=self.max_sequence_len,
            truncation=True,
            return_tensors="pt",
        )
        sum_attentionmask = torch.sum(model_inputs.attention_mask)
        output = self.shark_model(
            "forward", [model_inputs.input_ids, model_inputs.attention_mask]
        )
        output = torch.from_numpy(output)
        next_toks = torch.topk(output, 1)
        new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
        # Test the token that is actually emitted. Checking the final
        # sequence position instead would look at the prediction for a
        # padded slot whenever the text is shorter than max_sequence_len.
        stop_generation = self.shouldStop(new_token.reshape(1, -1))
        detok = self.tokenizer.decode(
            new_token,
            skip_special_tokens=True,
        )
        ret_dict = {
            "new_token": new_token,
            "detok": detok,
            "stop_generation": stop_generation,
        }
        return ret_dict
# System prompt text for StableLM Tuned (Alpha): a "<|SYSTEM|>" header line
# followed by bullet points describing the assistant's behaviour.
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""

View File

@@ -0,0 +1,559 @@
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
)
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import torch
import torch_mlir
import os
class Vicuna(SharkLLMBase):
def __init__(
    self,
    model_name,
    hf_model_path="TheBloke/vicuna-7B-1.1-HF",
    max_num_tokens=512,
    device="cuda",
    precision="fp32",
    first_vicuna_mlir_path=Path("first_vicuna.mlir"),
    second_vicuna_mlir_path=Path("second_vicuna.mlir"),
    first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
    second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
    load_mlir_from_shark_tank=True,
) -> None:
    """Store the configuration, load the tokenizer and compile the model.

    Args:
        model_name: Name stored on the pipeline.
        hf_model_path: Hugging Face model id or path.
        max_num_tokens: Token budget stored on the pipeline.
        device: Target device string.
        precision: Precision string.
        first_vicuna_mlir_path: ``Path`` of the first model's MLIR file.
        second_vicuna_mlir_path: ``Path`` of the second model's MLIR file.
        first_vicuna_vmfb_path: ``Path`` of the first model's vmfb file.
        second_vicuna_vmfb_path: ``Path`` of the second model's vmfb file.
        load_mlir_from_shark_tank: Whether the compile helpers may download
            prebuilt MLIR.
    """
    super().__init__(model_name, hf_model_path, max_num_tokens)
    self.max_sequence_length = 256
    self.device = device
    self.precision = precision
    self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
    self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
    self.first_vicuna_mlir_path = first_vicuna_mlir_path
    self.second_vicuna_mlir_path = second_vicuna_mlir_path
    # Must be assigned before compile(): the compile helpers read this
    # attribute when neither a vmfb nor an MLIR file is cached.
    self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
    self.tokenizer = self.get_tokenizer()
    self.shark_model = self.compile()
def get_tokenizer(self):
    """Return the slow (``use_fast=False``) tokenizer for ``self.hf_model_path``."""
    return AutoTokenizer.from_pretrained(self.hf_model_path, use_fast=False)
def get_src_model(self):
    """Load and return the Hugging Face source model with ``torch.float`` weights."""
    return AutoModelForCausalLM.from_pretrained(
        self.hf_model_path, torch_dtype=torch.float
    )
def compile_first_vicuna(self):
    """Return a loaded SHARK module for the first (prompt) vicuna model.

    Resolution order for the MLIR bytes: existing ``first_vicuna_mlir_path``,
    download from shark_tank (fp32 only, when enabled), otherwise trace
    ``FirstVicuna`` and lower it with torch-mlir. The generated text is
    rewritten so the hard-coded sequence length 19 becomes dynamic.

    Raises:
        ValueError: If the MLIR file is still missing after a download.
    """
    vmfb = get_vmfb_from_path(
        self.first_vicuna_vmfb_path, self.device, "tm_tensor"
    )
    if vmfb is not None:
        return vmfb

    # Compilation path needs some more work before it is functional

    print(
        f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
        f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
    )
    if self.first_vicuna_mlir_path.exists():
        with open(self.first_vicuna_mlir_path, "rb") as f:
            bytecode = f.read()
    else:
        mlir_generated = False
        if self.load_mlir_from_shark_tank:
            if self.precision == "fp32":
                # download MLIR from shark_tank for fp32
                download_public_file(
                    "gs://shark_tank/vicuna/unsharded/mlir/first_vicuna.mlir",
                    self.first_vicuna_mlir_path.absolute(),
                    single_file=True,
                )
                if self.first_vicuna_mlir_path.exists():
                    with open(self.first_vicuna_mlir_path, "rb") as f:
                        bytecode = f.read()
                    mlir_generated = True
                else:
                    raise ValueError(
                        f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
                        " after downloading! Please check path and try again"
                    )
            else:
                print(
                    "Only fp32 mlir added to tank, generating mlir on device."
                )

        if not mlir_generated:
            compilation_prompt = "".join(["0" for _ in range(17)])
            compilation_input_ids = self.tokenizer(
                compilation_prompt
            ).input_ids
            compilation_input_ids = torch.tensor(
                compilation_input_ids
            ).reshape([1, 19])
            firstVicunaCompileInput = (compilation_input_ids,)
            model = FirstVicuna(self.hf_model_path)

            print(f"[DEBUG] generating torchscript graph")
            ts_graph = import_with_fx(
                model,
                firstVicunaCompileInput,
                is_f16=self.precision == "fp16",
                f16_input_mask=[False, False],
                mlir_type="torchscript",
            )
            del model
            print(f"[DEBUG] generating torch mlir")

            # Mark the sequence axis of the single input as dynamic.
            firstVicunaCompileInput = list(firstVicunaCompileInput)
            firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like(
                firstVicunaCompileInput[0], dynamic_axes=[1]
            )
            firstVicunaCompileInput = tuple(firstVicunaCompileInput)
            module = torch_mlir.compile(
                ts_graph,
                [*firstVicunaCompileInput],
                torch_mlir.OutputType.LINALG_ON_TENSORS,
                use_tracing=False,
                verbose=False,
            )
            del ts_graph

            def remove_constant_dim(line):
                """Replace the traced length 19 in one MLIR line with %dim."""
                # Raw strings keep the regex escapes valid Python literals;
                # the pattern text itself is unchanged.
                if "19x" in line:
                    line = re.sub("19x", "?x", line)
                    line = re.sub(
                        r"tensor.empty\(\)", "tensor.empty(%dim)", line
                    )
                if "tensor.empty" in line and "?x?" in line:
                    line = re.sub(
                        r"tensor.empty\(%dim\)",
                        "tensor.empty(%dim, %dim)",
                        line,
                    )
                if "arith.cmpi" in line:
                    line = re.sub("c19", "dim", line)
                if " 19," in line:
                    line = re.sub(" 19,", " %dim,", line)
                return line

            module = str(module)
            new_lines = []

            print(f"[DEBUG] rewriting torch_mlir file")
            for line in module.splitlines():
                line = remove_constant_dim(line)
                # Emit the %dim definition just before this tensor.empty line
                # and drop the definition wherever else it appears.
                if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
                    new_lines.append(
                        "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
                    )
                if (
                    "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
                    in line
                ):
                    continue

                new_lines.append(line)

            module = "\n".join(new_lines)

            print(f"[DEBUG] converting to bytecode")
            del new_lines
            bytecode = module.encode("UTF-8")
            del module

            print(f"[DEBUG] writing mlir to file")
            # Context manager guarantees the handle is closed even if the
            # write fails.
            with open(self.first_vicuna_mlir_path, "wb") as f_:
                f_.write(bytecode)

    shark_module = SharkInference(
        mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
    )
    path = shark_module.save_module(
        self.first_vicuna_vmfb_path.parent.absolute(),
        self.first_vicuna_vmfb_path.stem,
        extra_args=[
            "--iree-hal-dump-executable-sources-to=ies",
            "--iree-vm-target-truncate-unsupported-floats",
            "--iree-codegen-check-ir-before-llvm-conversion=false",
            "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
        ],
    )
    print("Saved first vic vmfb at vmfb at ", str(path))
    shark_module.load_module(path)

    return shark_module
# Build (or load) the second Vicuna model and return it as a loaded
# SharkInference module.
#
# Resolution order, as implemented below:
#   1. an existing vmfb at self.second_vicuna_vmfb_path (get_vmfb_from_path);
#   2. an existing MLIR file at self.second_vicuna_mlir_path;
#   3. the fp32 MLIR downloaded from shark_tank, when
#      self.load_mlir_from_shark_tank is set;
#   4. MLIR generated locally from SecondVicuna and then text-rewritten so the
#      sample sequence length becomes a dynamic dimension.
#
# Raises ValueError when the shark_tank download does not leave an MLIR file at
# self.second_vicuna_mlir_path.
def compile_second_vicuna(self):
# Fast path: reuse a previously compiled vmfb if one is on disk.
vmfb = get_vmfb_from_path(
self.second_vicuna_vmfb_path, self.device, "tm_tensor"
)
if vmfb is not None:
return vmfb
# Compilation path needs some more work before it is functional
print(
f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.second_vicuna_mlir_path.exists():
with open(self.second_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
else:
# Tracks whether `bytecode` has been obtained from the download below;
# when it has not, the MLIR is generated locally.
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
self.second_vicuna_mlir_path.absolute(),
single_file=True,
)
if self.second_vicuna_mlir_path.exists():
with open(self.second_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
if not mlir_generated:
# Sample inputs: one token id plus 64 past-key-value tensors whose
# third axis has length 19. That literal 19 is what the text rewrite
# further down searches for and replaces with a dynamic dimension.
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(self.hf_model_path)
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
# Every input except the token ids (index 0) is replaced by a
# placeholder that is dynamic along axis 2.
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[
i
] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
# Per-line textual rewrite: occurrences of the sample length 19 become
# %dim / "?" and occurrences of 20 (sample length + 1) become %dimp1.
# NOTE(review): the "tensor.empty\(...\)" patterns are non-raw strings
# containing "\("; recent Python versions warn about such invalid escape
# sequences - consider raw strings.
def remove_constant_dim(line):
if "c19_i64" in line:
line = re.sub("c19_i64", "dim_i64", line)
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim)", line
)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dimp1)", line
)
if " 20," in line:
line = re.sub(" 20,", " %dimp1,", line)
return line
module_str = str(module)
new_lines = []
for line in module_str.splitlines():
# The constant 19 is replaced by a runtime query of axis 2 of %arg1,
# cast to i64 as %dim_i64.
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128xf32>"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
)
continue
# Any pre-existing definition of %c2 is dropped; the branch above emits
# its own.
if "%c2 = arith.constant 2 : index" in line:
continue
# The constant 20 is recomputed as %dim_i64 + 1 and exposed as %dimp1.
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append(
"%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
)
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
)
continue
line = remove_constant_dim(line)
new_lines.append(line)
module_str = "\n".join(new_lines)
bytecode = module_str.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
# Persist the rewritten MLIR so later runs can take the "MLIR exists" path.
f_ = open(self.second_vicuna_mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
path = shark_module.save_module(
self.second_vicuna_vmfb_path.parent.absolute(),
self.second_vicuna_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
print("Saved vmfb at ", str(path))
# NOTE(review): the module is loaded from self.second_vicuna_vmfb_path, not
# from the `path` returned by save_module above - confirm the two always match.
shark_module.load_module(self.second_vicuna_vmfb_path)
# self.shark_module = shark_module
return shark_module
def compile(self):
    """Fetch prebuilt vmfbs when applicable; never loads a model.

    Both models cannot be held in memory at once, so the actual compilation
    happens on demand (compile_first_vicuna / compile_second_vicuna). Here a
    prebuilt vmfb is downloaded for each model only when it is missing locally
    and the target is cuda with fp32 precision.

    Returns:
        None. A tuple of shark modules may be returned once memory allows.
    """
    prebuilt_vmfbs = (
        (
            "gs://shark_tank/vicuna/unsharded/first_vicuna.vmfb",
            self.first_vicuna_vmfb_path,
        ),
        (
            "gs://shark_tank/vicuna/unsharded/second_vicuna.vmfb",
            self.second_vicuna_vmfb_path,
        ),
    )
    for remote_uri, local_vmfb in prebuilt_vmfbs:
        wants_download = (
            not local_vmfb.exists()
            and self.device == "cuda"
            and self.precision == "fp32"
        )
        if wants_download:
            download_public_file(
                remote_uri,
                local_vmfb.absolute(),
                single_file=True,
            )
        # TODO: Remove after testing to avoid memory overload.
        # Otherwise nothing is compiled here; see the on-demand compile_* methods.
    return None
def generate(self, prompt, cli=False):
    """Generate a completion for `prompt` and return the decoded text.

    The first model produces the first token; it is then released and the
    second model produces the remaining tokens, stopping at token id 2 or
    after max_num_tokens - 2 further steps.

    Args:
        prompt: text handed to the first model.
        cli: when true, each detokenized piece is echoed to stdout as it is
            produced.

    Returns:
        The string produced by self.tokenizer.decode over the collected
        token ids.
    """
    # TODO: refactor for cleaner integration
    import gc

    def _free_memory():
        # Clears both the CUDA cache and Python garbage.
        torch.cuda.empty_cache()
        gc.collect()

    # The params dict is built inline so that no local keeps the first model
    # alive once this call returns.
    step = self.generate_new_token(
        params={
            "prompt": prompt,
            "is_first": True,
            "fv": self.compile_first_vicuna(),
        }
    )
    token = step["token"]
    logits = step["logits"]
    pkv = step["pkv"]
    detok = step["detok"]
    pieces = [detok]
    token_ids = [token]
    if cli:
        print(f"Assistant: {detok}", end=" ", flush=True)

    # Clear First Vic from Memory (main and cuda)
    _free_memory()

    second_model = self.compile_second_vicuna()
    for _ in range(self.max_num_tokens - 2):
        step = self.generate_new_token(
            params={
                "prompt": None,
                "is_first": False,
                "logits": logits,
                "pkv": pkv,
                "sv": second_model,
            }
        )
        token = step["token"]
        logits = step["logits"]
        pkv = step["pkv"]
        detok = step["detok"]
        if token == 2:
            break
        token_ids.append(token)
        is_newline = detok == "<0x0A>"
        pieces.append("\n" if is_newline else detok)
        if cli:
            if is_newline:
                print("\n", end="", flush=True)
            else:
                print(f"{detok}", end=" ", flush=True)

    del second_model, pkv, logits
    _free_memory()

    # Tokens that are not plain ints are indexable; take their first element.
    token_ids = [
        tok if type(tok) is int else int(tok[0]) for tok in token_ids
    ]
    # print(f"[DEBUG] final output : \n{res_str}")
    return self.tokenizer.decode(token_ids)
def generate_new_token(self, params, debug=False):
    """Run one decoding step and return the new token plus updated state.

    Args:
        params: dict with key "is_first". When true, "prompt" and "fv" (the
            first model) are read; otherwise "logits", "pkv" and "sv" (the
            second model) are read.
        debug: when true, print the token and its detokenized form.

    Returns:
        dict with keys "token", "logits", "pkv" and "detok".
    """

    def _pick_last(logits):
        # Greedy choice over the last position of the logits.
        return torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)

    def _run_first(model, text, cache_outputs=False):
        ids = self.tokenizer(text).input_ids
        ids = torch.tensor(ids).reshape([1, len(ids)])
        assert model is not None
        raw = model("forward", (ids,))
        pkv_out = torch.tensor(raw[1:])
        logits_out = torch.tensor(raw[0])
        if cache_outputs:
            torch.save(logits_out, "logits_first_vicuna_tensor.pt")
            torch.save(pkv_out, "output_first_vicuna_tensor.pt")
        return _pick_last(logits_out), logits_out, pkv_out

    def _run_second(model, inputs=None, load_inputs=False):
        if inputs is not None:
            prev_logits, prev_pkv = inputs[0], inputs[1:]
        elif load_inputs:
            prev_pkv = tuple(
                torch.tensor(x)
                for x in torch.load("output_first_vicuna_tensor.pt")
            )
            prev_logits = torch.load("logits_first_vicuna_tensor.pt")
        else:
            print("Either inputs must be given, or load_inputs must be true")
            return None
        # The token chosen from the previous logits is the input of this step.
        fed_token = _pick_last(prev_logits).to(torch.int64).reshape([1, 1])
        raw = model("forward", (fed_token,) + tuple(prev_pkv))
        fresh_logits = raw[0]
        return _pick_last(fresh_logits), fresh_logits, raw[1:]

    is_first = params["is_first"]
    if is_first:
        token, logits, pkv = _run_first(
            params["fv"],  # self.shark_model[0],
            text=params["prompt"],
            cache_outputs=False,
        )
    else:
        step_inputs = (params["logits"],) + tuple(params["pkv"])
        token, logits, pkv = _run_second(
            params["sv"],  # self.shark_model[1],
            inputs=step_inputs,
            load_inputs=False,
        )

    detok = self.tokenizer.decode(token)
    if debug:
        print(
            f"[DEBUG] is_first: {is_first} |"
            f" token : {token} | detok : {detok}"
        )
    return {
        "token": token,
        "logits": logits,
        "pkv": pkv,
        "detok": detok,
    }
def autocomplete(self, prompt):
    """Placeholder: complete a story / prompt / sentence with First vic alone.

    Not implemented yet; the prompt is ignored and None is returned.
    """
    return None

View File

@@ -0,0 +1,408 @@
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
CompiledFirstVicunaLayer,
CompiledSecondVicunaLayer,
ShardedVicunaModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from shark.shark_importer import import_with_fx
from io import BytesIO
from pathlib import Path
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from torch_mlir import TensorPlaceholder
import re
import torch
import torch_mlir
import os
class Vicuna(SharkLLMBase):
def __init__(
    self,
    model_name,
    hf_model_path="TheBloke/vicuna-7B-1.1-HF",
    max_num_tokens=512,
    device="cuda",
    precision="fp32",
) -> None:
    """Set up the tokenizer and build the sharded model.

    Args:
        model_name: forwarded to the SharkLLMBase constructor.
        hf_model_path: Hugging Face model id, forwarded to the base class.
        max_num_tokens: generation limit, forwarded to the base class.
        device: stored on the instance as `self.device`.
        precision: "fp16" enables f16 import in compile_vicuna_layer.
    """
    super().__init__(model_name, hf_model_path, max_num_tokens)
    self.max_sequence_length = 256
    self.device, self.precision = device, precision
    self.tokenizer = self.get_tokenizer()
    # compile() builds the whole sharded model during construction.
    self.shark_model = self.compile()
def get_tokenizer(self):
    """Return the slow (use_fast=False) tokenizer for self.hf_model_path."""
    return AutoTokenizer.from_pretrained(self.hf_model_path, use_fast=False)
def get_src_model(self):
    """Load the source causal-LM from self.hf_model_path with torch.float weights."""
    return AutoModelForCausalLM.from_pretrained(
        self.hf_model_path, torch_dtype=torch.float
    )
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
    """Rewrite static occurrences of `dynamic_input_size` in MLIR text as dynamic.

    Works line by line on the textual module:
      * "<size>x" becomes "?x";
      * on lines that now contain "?x", "tensor.empty()" gets a "%dim" operand,
        and a second "%dim" when the line contains "?x?";
      * " <size>," becomes " %dim,";
      * on "arith.cmpi" lines, "c<size>" becomes "dim".

    Args:
        module: MLIR module as a string.
        dynamic_input_size: the integer sample length to replace.

    Returns:
        The rewritten module text, lines joined with a newline.
    """
    # NOTE(review): the source this was taken from had lost its indentation;
    # the " <size>," substitution is assumed to run on every line rather than
    # only on lines containing "?x" - confirm against the original file.
    rewritten = []
    for line in module.splitlines():
        line = re.sub(f"{dynamic_input_size}x", "?x", line)
        if "?x" in line:
            # Raw strings: "\(" is not a valid escape in a normal string
            # literal and triggers a warning on recent Python versions.
            line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim)", line)
        line = re.sub(f" {dynamic_input_size},", " %dim,", line)
        if "tensor.empty" in line and "?x?" in line:
            line = re.sub(
                r"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
            )
        if "arith.cmpi" in line:
            line = re.sub(f"c{dynamic_input_size}", "dim", line)
        rewritten.append(line)
    return "\n".join(rewritten)
def write_in_dynamic_inputs1(self, module, dynamic_input_size):
    """Rewrite static occurrences of `dynamic_input_size` using %dim_42.

    Works line by line on the textual module:
      * lines containing "dim_42 =" are dropped;
      * the line defining "%c<size>_i64" is replaced by a tensor.dim query on
        axis %c3 of %arg1 (as %dim_42) and its i64 cast (%dim_42_i64);
      * "<size>x" becomes "?x"; on lines that now contain "?x",
        "tensor.empty()" gets a "%dim_42" operand, and a second one when the
        line contains "?x?";
      * " <size>," becomes " %dim_42,";
      * on "arith.cmpi" lines, "c<size>" becomes "dim_42".

    Args:
        module: MLIR module as a string.
        dynamic_input_size: the integer sample length to replace.

    Returns:
        The rewritten module text, lines joined with a newline.
    """
    # NOTE(review): the source this was taken from had lost its indentation;
    # the " <size>," substitution is assumed to run on every line rather than
    # only on lines containing "?x" - confirm against the original file.
    rewritten = []
    for line in module.splitlines():
        if "dim_42 =" in line:
            continue
        if f"%c{dynamic_input_size}_i64 =" in line:
            rewritten.append(
                "%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
            )
            rewritten.append(
                "%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
            )
            continue
        line = re.sub(f"{dynamic_input_size}x", "?x", line)
        if "?x" in line:
            # Raw strings: "\(" is not a valid escape in a normal string
            # literal and triggers a warning on recent Python versions.
            line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim_42)", line)
        line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
        if "tensor.empty" in line and "?x?" in line:
            line = re.sub(
                r"tensor.empty\(%dim_42\)",
                "tensor.empty(%dim_42, %dim_42)",
                line,
            )
        if "arith.cmpi" in line:
            line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
        rewritten.append(line)
    return "\n".join(rewritten)
def compile_vicuna_layer(
    self,
    vicuna_layer,
    hidden_states,
    attention_mask,
    position_ids,
    past_key_value0=None,
    past_key_value1=None,
):
    """Import one layer through import_with_fx as a torchscript graph.

    The past key/value pair is appended to the example inputs unless both
    values are None.

    Returns:
        Whatever import_with_fx returns for mlir_type="torchscript".
    """
    example_inputs = (hidden_states, attention_mask, position_ids)
    has_past = not (past_key_value0 is None and past_key_value1 is None)
    if has_past:
        example_inputs = example_inputs + (past_key_value0, past_key_value1)
    return import_with_fx(
        vicuna_layer,
        example_inputs,
        is_f16=self.precision == "fp16",
        f16_input_mask=[False, False],
        mlir_type="torchscript",
    )
# Produce one SharkInference module per layer.
#
# Pass 1 ("Getting mlirs"): for every layer without a vmfb on disk, obtain its
# MLIR bytes - read from "{idx}_0.mlir" / "{idx}_1.mlir" when present, otherwise
# generated with torch_mlir.compile and rewritten to dynamic shapes - and append
# them to `mlirs`.
# Pass 2 ("compiling modules"): for every layer, load the existing vmfb or
# compile one from `mlirs[idx]`, then append the module to `modules`.
#
# Args:
#   inputs: example input tensors; indices 0-2 are always read, indices 3-4 only
#           when is_first is False.
#   layers: iterable of layer wrappers to compile.
#   is_first: selects the "_0" (first pass) or "_1" (second pass) file names and
#             rewrite helper.
# Returns: (mlirs, modules).
def compile_to_vmfb(self, inputs, layers, is_first=True):
mlirs, modules = [], []
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
if is_first:
mlir_path = Path(f"{idx}_0.mlir")
vmfb_path = Path(f"{idx}_0.vmfb")
else:
mlir_path = Path(f"{idx}_1.mlir")
vmfb_path = Path(f"{idx}_1.vmfb")
# NOTE(review): skipping here means nothing is appended to `mlirs` for this
# layer, so `mlirs` is no longer indexed by layer number. The second loop
# reads `mlirs[idx]`; with a mix of existing and missing vmfbs that lookup
# can return another layer's MLIR or raise IndexError - confirm and fix.
if vmfb_path.exists():
continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
# Placeholders marking which axis of each example input is dynamic.
# NOTE(review): position_ids_placeholder is never used below, and each
# branch passes some of the concrete `inputs` instead of placeholders.
hidden_states_placeholder = TensorPlaceholder.like(
inputs[0], dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
inputs[1], dynamic_axes=[3]
)
position_ids_placeholder = TensorPlaceholder.like(
inputs[2], dynamic_axes=[1]
)
if not is_first:
pkv0_placeholder = TensorPlaceholder.like(
inputs[3], dynamic_axes=[2]
)
pkv1_placeholder = TensorPlaceholder.like(
inputs[4], dynamic_axes=[2]
)
print(f"Compiling layer {idx} mlir")
if is_first:
ts_g = self.compile_vicuna_layer(
layer, inputs[0], inputs[1], inputs[2]
)
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
else:
ts_g = self.compile_vicuna_layer(
layer,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4],
)
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
# The literals 137 and 138 below equal SAMPLE_INPUT_LEN and
# SAMPLE_INPUT_LEN + 1 as used in get_sharded_model.
if is_first:
module = self.write_in_dynamic_inputs0(str(module), 137)
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
else:
module = self.write_in_dynamic_inputs1(str(module), 138)
# For layers 0, 5, 6 and 7 a copy with every line cut to at most 999
# characters is written to "{idx}_1_test.mlir"; the untruncated text
# is what gets encoded below.
if idx in [0, 5, 6, 7]:
module_str = module
module_str = module_str.splitlines()
new_lines = []
for line in module_str:
if len(line) < 1000:
new_lines.append(line)
else:
new_lines.append(line[:999])
module_str = "\n".join(new_lines)
f1_ = open(f"{idx}_1_test.mlir", "w+")
f1_.write(module_str)
f1_.close()
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
mlirs.append(bytecode)
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
if is_first:
vmfb_path = Path(f"{idx}_0.vmfb")
# Both branches currently select "cpu".
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_0",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
else:
vmfb_path = Path(f"{idx}_1.vmfb")
# Both branches currently select "cpu".
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_1",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
def get_sharded_model(self):
    """Compile every decoder layer twice and wrap the results in a ShardedVicunaModel.

    Each layer of the source model is wrapped as a FirstVicunaLayer and as a
    SecondVicunaLayer, compiled through compile_to_vmfb, and the compiled
    modules are wrapped in CompiledFirstVicunaLayer /
    CompiledSecondVicunaLayer.

    Returns:
        ShardedVicunaModel built from the source model and both layer lists.
    """
    # SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is
    # currently an incredibly hacky process; please don't change it.
    SAMPLE_INPUT_LEN = 137
    vicuna_model = self.get_src_model()

    first_pass_inputs = (
        torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
        torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
        torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
    )
    second_pass_inputs = (
        torch.zeros([1, 1, 4096]),
        torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
        torch.zeros([1, 1], dtype=torch.int64),
        torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
        torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
    )

    first_wrapped = []
    for decoder_layer in vicuna_model.model.layers:
        first_wrapped.append(FirstVicunaLayer(decoder_layer))
    first_modules = self.compile_to_vmfb(
        first_pass_inputs, first_wrapped, is_first=True
    )[1]
    shark_layers0 = list(map(CompiledFirstVicunaLayer, first_modules))

    second_wrapped = []
    for decoder_layer in vicuna_model.model.layers:
        second_wrapped.append(SecondVicunaLayer(decoder_layer))
    second_modules = self.compile_to_vmfb(
        second_pass_inputs, second_wrapped, is_first=False
    )[1]
    shark_layers1 = list(map(CompiledSecondVicunaLayer, second_modules))

    return ShardedVicunaModel(vicuna_model, shark_layers0, shark_layers1)
def compile(self):
    """Build and return the sharded model (delegates to get_sharded_model)."""
    sharded_model = self.get_sharded_model()
    return sharded_model
def generate(self, prompt, cli=False):
    """Generate up to max_num_tokens tokens for `prompt` and decode them.

    The first step is flagged with is_first=True; later steps feed back the
    previous token and past key values. Generation stops early at token id 2.

    Args:
        prompt: text passed to every generate_new_token call.
        cli: accepted for interface compatibility; not used here.

    Returns:
        The string produced by self.tokenizer.decode over the generated ids.
    """
    # TODO: refactor for cleaner integration
    token_ids = []
    pieces = []
    previous_token = None
    cached_kv = None
    for step_index in range(self.max_num_tokens):
        step = self.generate_new_token(
            params={
                "prompt": prompt,
                "is_first": step_index == 0,
                "token": previous_token,
                "past_key_values": cached_kv,
            }
        )
        previous_token = step["token"]
        cached_kv = step["past_key_values"]
        piece = step["detok"]
        if previous_token == 2:
            break
        pieces.append(piece)
        token_ids.append(previous_token)

    # Tokens that are not plain ints are indexable; take their first element.
    token_ids = [
        tok if type(tok) is int else int(tok[0]) for tok in token_ids
    ]
    return self.tokenizer.decode(token_ids)
def generate_new_token(self, params):
    """Run one forward step of the sharded model and pick the next token.

    Args:
        params: dict with "is_first". On the first step "prompt" is
            tokenized; on later steps "token" and "past_key_values" are used.

    Returns:
        dict with "token" (int), "detok" (its decoded form) and
        "past_key_values" (taken from the model output).
    """
    is_first = params["is_first"]
    if is_first:
        ids = self.tokenizer(params["prompt"]).input_ids
        extra_kwargs = {}
    else:
        ids = [params["token"]]
        extra_kwargs = {"past_key_values": params["past_key_values"]}

    # Shape the ids as a single-row batch.
    input_ids = torch.tensor(ids).reshape([1, len(ids)])
    output = self.shark_model.forward(
        input_ids, is_first=is_first, **extra_kwargs
    )

    next_kv = output["past_key_values"]
    # Greedy choice over the last position of the logits.
    next_token = int(torch.argmax(output["logits"][:, -1, :], dim=1)[0])
    next_detok = self.tokenizer.decode(next_token)
    print(f" token : {next_token} | detok : {next_detok}")
    return {
        "token": next_token,
        "detok": next_detok,
        "past_key_values": next_kv,
    }
def autocomplete(self, prompt):
    """Placeholder: complete a story / prompt / sentence with First vic alone.

    Not implemented yet; the prompt is ignored and None is returned.
    """
    return None

View File

@@ -0,0 +1,25 @@
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from pathlib import Path
# expects a Path / str as arg
# returns None if path not found or SharkInference module
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
    """Load a compiled vmfb into a SharkInference module if the file exists.

    Args:
        vmfb_path: Path or str pointing at the vmfb file.
        device: forwarded to SharkInference as `device`.
        mlir_dialect: forwarded to SharkInference as `mlir_dialect`.

    Returns:
        The SharkInference module with the vmfb loaded, or None when
        `vmfb_path` does not exist.
    """
    if not isinstance(vmfb_path, Path):
        vmfb_path = Path(vmfb_path)
    if not vmfb_path.exists():
        return None

    # Imported only once a vmfb is known to exist, so the "not found" answer
    # does not pay for (or depend on) importing shark.
    from shark.shark_inference import SharkInference

    print("Loading vmfb from: ", vmfb_path)
    shark_module = SharkInference(
        None, device=device, mlir_dialect=mlir_dialect
    )
    shark_module.load_module(vmfb_path)
    print("Successfully loaded vmfb")
    return shark_module

View File

@@ -1,107 +0,0 @@
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
import os
import PIL
import numpy as np
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
)
from datetime import datetime
from PIL import Image
from gradio.components.image_editor import (
EditorValue,
)
class control_adapter:
    """Local stub for a control-adapter exporter; every exporter returns None."""

    def __init__(self, model: str):
        # The requested model name is accepted but not stored.
        self.model = None

    # Defined without `self`: these are looked up on the class and called as
    # plain functions.
    def export_control_adapter_model(model_keyword):
        """Stub exporter; returns None."""
        return None

    def export_xl_control_adapter_model(model_keyword):
        """Stub exporter for the XL variant; returns None."""
        return None
class preprocessors:
    """Local stub for a controlnet preprocessor exporter; returns None."""

    def __init__(self, model: str):
        # The requested model name is accepted but not stored.
        self.model = None

    # Defined without `self`: looked up on the class and called as a plain
    # function.
    def export_controlnet_model(model_keyword):
        """Stub exporter; returns None."""
        return None
# Maps base model -> control hint -> {"initializer": exporter}.
control_adapter_map = {
    "sd15": {
        hint: {"initializer": control_adapter.export_control_adapter_model}
        for hint in ("canny", "openpose", "scribble", "zoedepth")
    },
    "sdxl": {
        "canny": {
            "initializer": control_adapter.export_xl_control_adapter_model
        },
    },
}
# Maps control hint -> {"initializer": preprocessor exporter}.
preprocessor_model_map = {
    hint: {"initializer": preprocessors.export_controlnet_model}
    for hint in ("canny", "openpose", "scribble", "zoedepth")
}
class PreprocessorModel:
    """Placeholder preprocessor: stores its id and device, does no processing."""

    def __init__(self, hf_model_id, device="cpu"):
        self.model, self.device = hf_model_id, device

    def compile(self):
        """Not implemented; prints a notice and returns None."""
        print("compile not implemented for preprocessor.")
        return None

    def run(self, inputs):
        """Not implemented; prints a notice and returns `inputs` unchanged."""
        print("run not implemented for preprocessor.")
        return inputs
def cnet_preview(model, input_image):
curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
if not os.path.exists(control_imgs_path):
os.mkdir(control_imgs_path)
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
match model:
case "canny":
canny = PreprocessorModel("canny")
result = canny(
np.array(input_image),
100,
200,
)
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "openpose":
openpose = PreprocessorModel("openpose")
result = openpose(np.array(input_image))
Image.fromarray(result[0]).save(fp=img_dest)
return result, img_dest
case "zoedepth":
zoedepth = PreprocessorModel("ZoeDepth")
result = zoedepth(np.array(input_image))
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "scribble":
input_image.save(fp=img_dest)
return input_image, img_dest
case _:
return None, None

View File

@@ -1,125 +0,0 @@
import importlib
import os
import signal
import sys
import warnings
import json
from threading import Thread
from apps.shark_studio.modules.timer import startup_timer
from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
shark_tmp,
)
def imports():
    """Perform the heavy imports up front, recording each stage on startup_timer."""
    import torch  # noqa: F401

    startup_timer.record("import torch")

    # Silence torch / torchvision warnings.
    ignored_warnings = (
        (DeprecationWarning, "torch"),
        (UserWarning, "torchvision"),
        (UserWarning, "torch"),
    )
    for warning_category, module_name in ignored_warnings:
        warnings.filterwarnings(
            action="ignore", category=warning_category, module=module_name
        )

    import gradio  # noqa: F401

    startup_timer.record("import gradio")

    import apps.shark_studio.web.utils.globals as global_obj

    global_obj._init()
    startup_timer.record("initialize globals")

    from apps.shark_studio.modules import (
        img_processing,
    )  # noqa: F401

    startup_timer.record("other imports")
# One-time application start-up: installs the Ctrl+C handler, points gradio's
# temporary files at shark_tmp, clears stale temporary images, creates the model
# folders and only then imports gradio. The statement order matters (see below).
def initialize():
configure_sigint_handler()
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
config_tmp()
# clear_tmp_mlir()
clear_tmp_imgs()
from apps.shark_studio.web.utils.file_utils import (
create_model_folders,
)
# Create custom models folders if they don't exist
create_model_folders()
# `gr` is not used in this function; the import is done here so that it happens
# after config_tmp() above.
import gradio as gr
# initialize_rest(reload_script_modules=False)
def initialize_rest(*, reload_script_modules=False):
    """
    Called both from initialize() and when reloading the webui.
    """
    # Intentionally empty for now: kept as the hook for adding reload options
    # to the webUI.
    return None
def dumpstacks():
    """Write the current stack of every thread to "stack_dump.log" in shark_tmp."""
    import threading
    import traceback

    thread_names = {}
    for thread in threading.enumerate():
        thread_names[thread.ident] = thread.name

    report = []
    for thread_id, frame in sys._current_frames().items():
        report.append(
            f"\n# Thread: {thread_names.get(thread_id, '')}({thread_id})"
        )
        for entry in traceback.extract_stack(frame):
            filename, lineno, name, line = entry
            report.append(f"""File: "{filename}", line {lineno}, in {name}""")
            if line:
                report.append(" " + line.strip())

    dump_path = os.path.join(shark_tmp, "stack_dump.log")
    with open(dump_path, "w") as dump_file:
        dump_file.write("\n".join(report))
def setup_middleware(app):
    """Install gzip and CORS middleware on `app`, then rebuild its middleware stack."""
    from starlette.middleware.gzip import GZipMiddleware

    # Reset current middleware to allow modifying user provided list.
    app.middleware_stack = None
    app.add_middleware(GZipMiddleware, minimum_size=1000)
    configure_cors_middleware(app)
    # Rebuild middleware stack on-the-fly.
    app.build_middleware_stack()
def configure_cors_middleware(app):
    """Add CORSMiddleware to `app`, allowing all methods and headers.

    Allowed origins are set only when cmd_opts.api_accept_origin is non-empty;
    it is read as a comma-separated list.
    """
    from starlette.middleware.cors import CORSMiddleware
    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

    cors_options = dict(
        allow_methods=["*"],
        allow_headers=["*"],
        allow_credentials=True,
    )
    accepted_origins = cmd_opts.api_accept_origin
    if accepted_origins:
        cors_options["allow_origins"] = accepted_origins.split(",")
    app.add_middleware(CORSMiddleware, **cors_options)
def configure_sigint_handler():
    """Install a SIGINT handler that dumps all thread stacks and exits at once."""

    # Make the program just exit at ctrl+c without waiting for anything.
    def _exit_on_interrupt(sig, frame):
        print(f"Interrupted with signal {sig} in {frame}")
        dumpstacks()
        os._exit(0)

    signal.signal(signal.SIGINT, _exit_on_interrupt)

View File

@@ -1,475 +0,0 @@
from turbine_models.custom_models import stateless_llama
from turbine_models.model_runner import vmfbRunner
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
from apps.shark_studio.web.utils.file_utils import (
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import parse_device
from urllib.request import urlopen
import iree.runtime as ireert
from itertools import chain
import gc
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# System prompt shared by every entry of llm_model_map.
_LLAMA2_SYSTEM_PROMPT = """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>"""

# Per-model settings keyed by Hugging Face model id. Each row of the table is
# (model id, const-expr hoisting compile flag, max tokens).
llm_model_map = {
    hf_name: {
        "initializer": stateless_llama.export_transformer_model,
        "hf_model_name": hf_name,
        "compile_flags": [hoisting_flag],
        "stop_token": 2,
        "max_tokens": token_limit,
        "system_prompt": _LLAMA2_SYSTEM_PROMPT,
    }
    for hf_name, hoisting_flag, token_limit in (
        (
            "meta-llama/Llama-2-7b-chat-hf",
            "--iree-opt-const-expr-hoisting=False",
            4096,
        ),
        (
            "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
            "--iree-opt-const-expr-hoisting=False",
            4096,
        ),
        (
            "TinyPixel/small-llama2",
            "--iree-opt-const-expr-hoisting=True",
            1024,
        ),
    )
}
# Markers used when assembling chat prompts.
B_INST = "[INST]"
E_INST = "[/INST]"
B_SYS = "<s>"
E_SYS = "</s>"
# System preamble placed in front of the first user prompt.
DEFAULT_CHAT_SYS_PROMPT = """<s>[INST] <<SYS>>
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <</SYS>>\n\n
"""
def append_user_prompt(history, input_prompt):
    """Return `history` extended with `input_prompt` wrapped in B_INST / E_INST."""
    history += f"{B_INST} {input_prompt} {E_INST}"
    return history
class LanguageModel:
# Prepare a runnable language model for `model_name`.
#
# Derives file names from the model name, precision and quantization, generates
# the external weight file when one is requested and missing, and then either
#   * loads an existing vmfb through vmfbRunner,
#   * exports torch IR to a temp file and compiles it, or
#   * compiles an already existing temp file.
#
# Args:
#   model_name: key into llm_model_map.
#   hf_auth_token: passed to the tokenizer / exporter / weight generator.
#   device: device string; the text after "=>" is kept as self.device.
#   quantization: appended to the file spec unless it is the string "None".
#   precision: NOTE(review) - not read below; self.precision is derived from
#              `device` instead.
#   external_weights: "safetensors" or "gguf" enables an external weight file.
#   use_system_prompt: stored for sanitize_prompt.
#   streaming_llm: selects the streaming_state_update module.
def __init__(
self,
model_name,
hf_auth_token=None,
device=None,
quantization="int4",
precision="",
external_weights=None,
use_system_prompt=True,
streaming_llm=False,
):
# NOTE(review): `device` defaults to None, but it is split and searched below,
# so callers apparently must always pass a device string - confirm.
_, _, self.triple = parse_device(device)
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.device = device.split("=>")[-1].strip()
self.backend = self.device.split("://")[0]
self.driver = self.backend
if "cpu" in device:
self.device = "cpu"
self.backend = "llvm-cpu"
self.driver = "local-task"
print(f"Selected {self.backend} as IREE target backend.")
self.precision = "f32" if "cpu" in device else "f16"
self.quantization = quantization
self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_")
self.external_weight_file = None
# TODO: find a programmatic solution for model arch spec instead of hardcoding llama2
self.file_spec = "_".join(
[
self.safe_name,
self.precision,
]
)
if self.quantization != "None":
self.file_spec += "_" + self.quantization
if external_weights in ["safetensors", "gguf"]:
self.external_weight_file = get_resource_path(
os.path.join("..", self.file_spec + "." + external_weights)
)
else:
# NOTE(review): self.external_weights is assigned only on this branch; it is
# not set when external weights are requested.
self.external_weights = None
self.external_weight_file = None
if streaming_llm:
# Add streaming suffix to file spec after setting external weights filename.
self.file_spec += "_streaming"
self.streaming_llm = streaming_llm
self.tempfile_name = get_resource_path(
os.path.join("..", f"{self.file_spec}.tempfile")
)
# TODO: Tag vmfb with target triple of device instead of HAL backend
self.vmfb_name = str(
get_resource_path(
os.path.join("..", f"{self.file_spec}_{self.backend}.vmfb.tempfile")
)
)
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.use_system_prompt = use_system_prompt
self.global_iter = 0
self.prev_token_len = 0
self.first_input = True
self.hf_auth_token = hf_auth_token
# Generate the external weight file when one is expected but not on disk yet.
if self.external_weight_file is not None:
if not os.path.exists(self.external_weight_file):
print(
f"External weight file {self.external_weight_file} does not exist. Generating..."
)
gen_external_params(
hf_model_name=self.hf_model_name,
quantization=self.quantization,
weight_path=self.external_weight_file,
hf_auth_token=hf_auth_token,
precision=self.precision,
)
else:
print(
f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
)
self.external_weight_file = str(self.external_weight_file)
# Case 1: a compiled vmfb (and, if requested, its weights) already exists.
if os.path.exists(self.vmfb_name) and (
external_weights is None or os.path.exists(str(self.external_weight_file))
):
self.runner = vmfbRunner(
device=self.driver,
vmfb_path=self.vmfb_name,
external_weight_path=self.external_weight_file,
)
if self.streaming_llm:
self.model = self.runner.ctx.modules.streaming_state_update
else:
self.model = self.runner.ctx.modules.state_update
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
# Case 2: no torch IR temp file yet - export it, write it out, then compile.
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
"initializer"
](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
external_weights=external_weights,
precision=self.precision,
quantization=self.quantization,
streaming_llm=self.streaming_llm,
decomp_attn=True,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
# The IR text is dropped once written to disk.
del self.torch_ir
gc.collect()
self.compile()
# Case 3: the torch IR temp file exists - only the tokenizer and compile are needed.
else:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
self.compile()
# Reserved for running HF torch model as reference.
self.hf_mod = None
def compile(self) -> None:
    """Compile the IR stored at ``self.tempfile_name`` and load the result.

    Backend-specific compiler flags are assembled first, then the
    model-specific ``compile_flags`` from ``llm_model_map`` are appended.
    The compiled module is written to ``self.vmfb_name`` and immediately
    loaded into a ``vmfbRunner``; ``self.model`` is bound to the streaming or
    non-streaming entry module depending on ``self.streaming_llm``.
    """
    # ONLY architecture/api-specific compile-time flags for each backend, if needed.
    # hf_model_id-specific global flags currently in model map.
    flags = []
    if "cpu" in self.backend:
        flags.extend(
            [
                "--iree-global-opt-enable-quantized-matmul-reassociation",
            ]
        )
    elif self.backend == "vulkan":
        flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
    elif self.backend == "rocm":
        flags.extend(
            [
                "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
                "--iree-llvmgpu-enable-prefetch=true",
                "--iree-opt-outer-dim-concat=true",
                "--iree-flow-enable-aggressive-fusion",
            ]
        )
        if "gfx9" in self.triple:
            flags.extend(
                [
                    f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
                    "--iree-codegen-llvmgpu-use-vector-distribution=true",
                ]
            )
    flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])

    # The return value is intentionally not bound: the module is written to
    # `write_to` and re-read from disk by the runner below, so holding on to
    # the returned object would only extend its lifetime for no benefit.
    compile_module_to_flatbuffer(
        self.tempfile_name,
        device=self.device,
        frontend="auto",
        model_config_path=None,
        extra_args=flags,
        write_to=self.vmfb_name,
    )
    self.runner = vmfbRunner(
        device=self.driver,
        vmfb_path=self.vmfb_name,
        external_weight_path=self.external_weight_file,
    )
    if self.streaming_llm:
        self.model = self.runner.ctx.modules.streaming_state_update
    else:
        self.model = self.runner.ctx.modules.state_update
def sanitize_prompt(self, prompt):
    """Collapse ``prompt`` into a single-line string and wrap it for the model.

    A list input is flattened one level, non-string items are dropped and the
    remaining pieces are joined with spaces. Newlines, tabs and carriage
    returns are replaced by spaces. On the first turn with the system prompt
    enabled the text is combined with ``DEFAULT_CHAT_SYS_PROMPT`` through
    ``append_user_prompt``; otherwise it is wrapped in ``B_INST``/``E_INST``.
    """
    if isinstance(prompt, list):
        flattened = chain.from_iterable(prompt)
        prompt = " ".join(piece for piece in flattened if isinstance(piece, str))

    for control_char in ("\n", "\t", "\r"):
        prompt = prompt.replace(control_char, " ")

    is_first_turn_with_system_prompt = self.use_system_prompt and self.global_iter == 0
    if not is_first_turn_with_system_prompt:
        return f"{B_INST} {prompt} {E_INST}"
    return append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt)
def chat(self, prompt):
    """Generate a reply to ``prompt``, yielding the decoded text as it grows.

    This is a generator. After each decode step it yields
    ``(decoded_history, step_time)``; when generation ends it returns
    ``(decoded_history, last_step_time)`` as the generator's return value.
    Generation stops at the model's configured stop token or once
    ``self.max_tokens`` tokens have been produced.

    Side effects: updates ``self.first_input``, ``self.prev_token_len`` and
    increments ``self.global_iter``.
    """
    prompt = self.sanitize_prompt(prompt)

    input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids

    def format_out(results):
        return torch.tensor(results.to_host()[0][0])

    history = []
    # Defined up front so the final `return` is valid even if nothing runs.
    total_time = 0.0

    # Prefill and decode run exactly once. The original wrapped this in a
    # `for` over max_tokens, which re-ran prefill on every remaining iteration
    # whenever the token budget was exhausted without hitting the stop token.
    if self.max_tokens > 0:
        stop_token = llm_model_map[self.hf_model_name]["stop_token"]

        if self.streaming_llm:
            token_slice = max(self.prev_token_len - 1, 0)
            input_tensor = input_tensor[:, token_slice:]
        if self.streaming_llm and self.model["get_seq_step"]() > 600:
            print("Evicting cache space!")
            self.model["evict_kvcache_space"]()
        token_len = input_tensor.shape[-1]
        device_inputs = [
            ireert.asdevicearray(self.runner.config.device, input_tensor)
        ]
        if self.first_input or not self.streaming_llm:
            st_time = time.time()
            token = self.model["run_initialize"](*device_inputs)
            total_time = time.time() - st_time
            token_len += 1
            self.first_input = False
        else:
            st_time = time.time()
            token = self.model["run_cached_initialize"](*device_inputs)
            total_time = time.time() - st_time
            token_len += 1

        # Convert each device result once and reuse it for both the history
        # and the stop check instead of transferring it to the host twice.
        current = format_out(token)
        history.append(current)
        while current != stop_token and len(history) < self.max_tokens:
            dec_time = time.time()
            if self.streaming_llm and self.model["get_seq_step"]() > 600:
                print("Evicting cache space!")
                self.model["evict_kvcache_space"]()
            token = self.model["run_forward"](token)
            current = format_out(token)
            history.append(current)
            total_time = time.time() - dec_time
            yield self.tokenizer.decode(history), total_time

        self.prev_token_len = token_len + len(history)

    history = [int(tok) for tok in history]
    result_output = self.tokenizer.decode(history)
    self.global_iter += 1
    return result_output, total_time
# Reference HF model function for sanity checks.
def chat_hf(self, prompt):
    """Reference generation loop using the HuggingFace torch model.

    Lazily loads ``AutoModelForCausalLM`` into ``self.hf_mod`` on first use,
    then greedily decodes (argmax over the last position's logits). This is a
    generator: it yields ``(decoded_history, step_time)`` after each decode
    step and returns ``(decoded_history, last_step_time)`` when done.

    Side effects: sets ``self.first_input`` to False, updates
    ``self.prev_token_len`` and increments ``self.global_iter``.
    """
    if self.hf_mod is None:
        self.hf_mod = AutoModelForCausalLM.from_pretrained(
            self.hf_model_name,
            torch_dtype=torch.float,
            token=self.hf_auth_token,
        )
    prompt = self.sanitize_prompt(prompt)
    input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
    stop_token = llm_model_map[self.hf_model_name]["stop_token"]
    token_len = input_tensor.shape[-1]

    # Prefill always runs: the key/value cache is local to this call, so
    # skipping it when `first_input` was already False left `token` and `pkv`
    # undefined.
    st_time = time.time()
    result = self.hf_mod(input_tensor)
    token = torch.argmax(result.logits[:, -1, :], dim=1)
    total_time = time.time() - st_time
    token_len += 1
    pkv = result.past_key_values
    self.first_input = False
    history = [int(token)]

    # Bounded by max_tokens so a model that never emits the stop token
    # cannot loop forever.
    while token != stop_token and len(history) < self.max_tokens:
        dec_time = time.time()
        result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
        token = torch.argmax(result.logits[:, -1, :], dim=1)
        pkv = result.past_key_values
        # Record the token that was just produced; the original appended the
        # previous token again, duplicating the first token in the output.
        history.append(int(token))
        total_time = time.time() - dec_time
        yield self.tokenizer.decode(history), total_time

    self.prev_token_len = token_len + len(history)
    result_output = self.tokenizer.decode(history)
    self.global_iter += 1
    return result_output, total_time
def get_mfma_spec_path(target_chip, save_dir):
    """Return the local path of the attention/matmul transform spec.

    The spec is cached as ``attention_and_matmul_spec_mfma.mlir`` inside
    ``save_dir``. It is downloaded only when that file does not exist yet, so
    a cached copy keeps working without network access.

    :param target_chip: currently unused; kept for interface compatibility.
    :param save_dir: directory in which the spec file is cached.
    :return: path to the cached spec file.
    """
    spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
    # Check the cache before touching the network.
    if os.path.exists(spec_path):
        return spec_path

    url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
    with urlopen(url) as response:
        attn_spec = response.read().decode("utf-8")
    # Write with the same encoding the payload was decoded with.
    with open(spec_path, "w", encoding="utf-8") as f:
        f.write(attn_spec)
    return spec_path
def llm_chat_api(InputData: dict):
    """Serve one completion request and return a completion-style response.

    ``InputData`` is treated as a chat-completion request when it has a
    ``"messages"`` key and as a legacy completion request (``"prompt"`` key)
    otherwise. Optional keys: ``"model"``, ``"device"``, ``"max_tokens"``.

    The ``LanguageModel`` is created on first use and cached through
    ``global_obj``; later calls reuse it and only update ``max_tokens``.

    :return: dict with ``id``, ``object``, ``created`` and ``choices``.
    :raises KeyError: if the requested model is not in ``llm_model_map`` or a
        legacy request has no ``"prompt"``.
    """
    from datetime import datetime as dt

    import apps.shark_studio.web.utils.globals as global_obj

    print(f"Input keys : {InputData.keys()}")
    # print(f"model : {InputData['model']}")
    is_chat_completion_api = (
        "messages" in InputData.keys()
    )  # else it is the legacy `completion` api
    # For Debugging input data from API
    if is_chat_completion_api:
        print(f"message -> role : {InputData['messages'][0]['role']}")
        print(f"message -> content : {InputData['messages'][0]['content']}")
    else:
        print(f"prompt : {InputData['prompt']}")

    model_name = InputData.get("model", "meta-llama/Llama-2-7b-chat-hf")
    # Fail fast on an unknown model before any pipeline is built.
    llm_model_map[model_name]
    device = InputData.get("device", "cpu")
    max_tokens = InputData.get("max_tokens", 4096)

    if not global_obj.get_llm_obj():
        print("\n[LOG] Initializing new pipeline...")
        global_obj.clear_cache()
        gc.collect()
        # Reduce the device string to its driver name.
        # NOTE(review): any "://<index>" suffix is dropped here and was never
        # forwarded by the original code either.
        if "cuda" in device:
            device = "cuda"
        elif "vulkan" in device:
            device = "vulkan"
        elif "cpu" in device:
            device = "cpu"
        else:
            print("unrecognized device")

        llm_model = LanguageModel(
            model_name=model_name,
            hf_auth_token=cmd_opts.hf_auth_token,
            device=device,
            quantization=cmd_opts.quantization,
            external_weights="safetensors",
            use_system_prompt=True,
            streaming_llm=False,
        )
        global_obj.set_llm_obj(llm_model)
    else:
        llm_model = global_obj.get_llm_obj()

    llm_model.max_tokens = max_tokens
    # TODO: add role dict for different models
    if is_chat_completion_api:
        # TODO: add functionality for multiple messages
        prompt = append_user_prompt(
            InputData["messages"][0]["role"], InputData["messages"][0]["content"]
        )
    else:
        prompt = InputData["prompt"]
    print("prompt = ", prompt)

    # Drain the generator and keep only the last partial result. Starting from
    # an empty string keeps the response well-formed when nothing is yielded.
    res_op = ""
    for res_op, _ in llm_model.chat(prompt):
        pass

    if is_chat_completion_api:
        choices = [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": res_op,
                },
                "finish_reason": "stop",  # or length
            }
        ]
    else:
        choices = [
            {
                "text": res_op,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",  # or length
            }
        ]
    end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
    return {
        "id": end_time,
        "object": "chat.completion" if is_chat_completion_api else "text_completion",
        "created": int(end_time),
        "choices": choices,
    }
if __name__ == "__main__":
    # Manual smoke test: build a CPU pipeline and stream one reply to stdout.
    language_model = LanguageModel(
        "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
        hf_auth_token=None,
        device="cpu-task",
        external_weights="safetensors",
    )

    print("model loaded")
    for partial_reply in language_model.chat("hi, what are you?"):
        print(partial_reply)

View File

@@ -1,505 +0,0 @@
import gc
import torch
import gradio as gr
import time
import os
import json
import numpy as np
import copy
import importlib.util
import sys
from tqdm.auto import tqdm
from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
SharkSDXLPipeline,
)
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.api.utils import parse_device
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import (
safe_name,
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.img_processing import (
save_output_img,
)
from apps.shark_studio.modules.ckpt_processing import (
preprocessCKPT,
save_irpa,
)
# Templates keyed by submodel/flag-group name, with every value unset (None).
EMPTY_SD_MAP = dict.fromkeys(("clip", "scheduler", "unet", "vae_decode"))

EMPTY_SDXL_MAP = dict.fromkeys(
    ("prompt_encoder", "scheduled_unet", "vae_decode", "pipeline", "full_pipeline")
)

EMPTY_FLAGS = dict.fromkeys(("clip", "unet", "vae", "pipeline"))
def load_script(source, module_name):
    """
    reads file source and loads it as a module

    The module is registered in sys.modules before it is executed. If
    execution fails, the registration is rolled back (any module previously
    registered under the same name is restored) and the error is re-raised.

    :param source: file to load
    :param module_name: name of module to register in sys.modules
    :return: loaded module
    :raises ImportError: if no import spec/loader can be built for source
    """
    spec = importlib.util.spec_from_file_location(module_name, source)
    # spec_from_file_location returns None when it cannot pick a loader for
    # the file; report that clearly instead of failing on None below.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {module_name!r} from {source!r}")

    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module importable.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
class StableDiffusion:
    # This class is responsible for executing image generation and creating
    # /managing a set of compiled modules to run Stable Diffusion. The init
    # aims to be as general as possible, and the class will infer and compile
    # a list of necessary modules or a combined "pipeline module" for a
    # specified job based on the inference task.

    def __init__(
        self,
        base_model_id,
        height: int,
        width: int,
        batch_size: int,
        steps: int,
        scheduler: str,
        precision: str,
        device: str,
        target_triple: str = None,
        custom_vae: str = None,
        num_loras: int = 0,
        import_ir: bool = True,
        is_controlled: bool = False,
        external_weights: str = "safetensors",
    ):
        """Select the pipeline class and construct it.

        The pipeline class is chosen from ``base_model_id``: a name containing
        ".py" loads a custom script from the "scripts" checkpoint directory, a
        name containing "xl" selects the SDXL pipeline, anything else the SD
        pipeline. ``import_ir`` is accepted but not used in this method.
        """
        self.precision = precision
        self.compiled_pipeline = False
        self.base_model_id = base_model_id
        self.custom_vae = custom_vae
        self.is_sdxl = "xl" in self.base_model_id.lower()
        self.is_custom = ".py" in self.base_model_id.lower()
        if self.is_custom:
            custom_module = load_script(
                os.path.join(get_checkpoints_path("scripts"), self.base_model_id),
                "custom_pipeline",
            )
            self.turbine_pipe = custom_module.StudioPipeline
            self.model_map = custom_module.MODEL_MAP
        elif self.is_sdxl:
            self.turbine_pipe = SharkSDXLPipeline
            self.model_map = EMPTY_SDXL_MAP
        else:
            self.turbine_pipe = SharkSDPipeline
            self.model_map = EMPTY_SD_MAP
        max_length = 64
        target_backend, self.rt_device, triple = parse_device(device, target_triple)
        # pipe_id names the pipeline directory, so it encodes every setting
        # listed here.
        pipe_id_list = [
            safe_name(base_model_id),
            str(batch_size),
            str(max_length),
            f"{str(height)}x{str(width)}",
            precision,
            triple,
        ]
        if num_loras > 0:
            pipe_id_list.append(str(num_loras) + "lora")
        if is_controlled:
            pipe_id_list.append("controlled")
        if custom_vae:
            pipe_id_list.append(custom_vae)
        self.pipe_id = "_".join(pipe_id_list)
        self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
        self.weights_path = Path(
            os.path.join(
                get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)
            )
        )
        # makedirs also creates missing parents and tolerates the directory
        # appearing between a check and the creation.
        os.makedirs(self.weights_path, exist_ok=True)
        decomp_attn = True
        attn_spec = None
        if triple in ["gfx940", "gfx942", "gfx90a"]:
            decomp_attn = False
            attn_spec = "mfma"
        elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
            decomp_attn = False
            attn_spec = "wmma"
            if triple in ["gfx1103", "gfx1150"]:
                # external weights have issues on igpu
                external_weights = None
        elif target_backend == "llvm-cpu":
            decomp_attn = False

        self.sd_pipe = self.turbine_pipe(
            hf_model_name=base_model_id,
            scheduler_id=scheduler,
            height=height,
            width=width,
            precision=precision,
            max_length=max_length,
            batch_size=batch_size,
            num_inference_steps=steps,
            device=target_backend,
            iree_target_triple=triple,
            # Pass a copy so the pipeline cannot mutate the shared
            # module-level template.
            ireec_flags=dict(EMPTY_FLAGS),
            attn_spec=attn_spec,
            decomp_attn=decomp_attn,
            pipeline_dir=self.pipeline_dir,
            external_weights_dir=self.weights_path,
            external_weights=external_weights,
            custom_vae=custom_vae,
        )
        print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
        gc.collect()

    def prepare_pipe(
        self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
    ):
        """Resolve weights and load the pipeline onto the runtime device.

        When ``custom_weights`` is given, the checkpoint is converted with
        ``preprocessCKPT`` and each submodel's safetensors file is repackaged
        with ``save_irpa``. ``adapters``, ``embeddings`` and ``is_img2img`` are
        accepted but not used here; ``self.is_img2img`` is always set to False.
        ``compiled_pipeline`` is forced to False for non-SDXL models.
        """
        print("\n[LOG] Preparing pipeline...")
        self.is_img2img = False
        # Independent copies so the shared template map is never mutated.
        mlirs = copy.deepcopy(self.model_map)
        vmfbs = copy.deepcopy(self.model_map)
        weights = copy.deepcopy(self.model_map)
        if not self.is_sdxl:
            compiled_pipeline = False
        self.compiled_pipeline = compiled_pipeline

        if custom_weights:
            custom_weights = os.path.join(
                get_checkpoints_path("checkpoints"),
                safe_name(self.base_model_id.split("/")[-1]),
                custom_weights,
            )
            diffusers_weights_path = preprocessCKPT(custom_weights, self.precision)
            for key in weights:
                if key in ["scheduled_unet", "unet"]:
                    unet_weights_path = os.path.join(
                        diffusers_weights_path,
                        "unet",
                        "diffusion_pytorch_model.safetensors",
                    )
                    weights[key] = save_irpa(unet_weights_path, "unet.")

                elif key in ["clip", "prompt_encoder"]:
                    if not self.is_sdxl:
                        sd1_path = os.path.join(
                            diffusers_weights_path, "text_encoder", "model.safetensors"
                        )
                        weights[key] = save_irpa(sd1_path, "text_encoder_model.")
                    else:
                        # SDXL has two text encoders; both are converted.
                        clip_1_path = os.path.join(
                            diffusers_weights_path, "text_encoder", "model.safetensors"
                        )
                        clip_2_path = os.path.join(
                            diffusers_weights_path,
                            "text_encoder_2",
                            "model.safetensors",
                        )
                        weights[key] = [
                            save_irpa(clip_1_path, "text_encoder_model_1."),
                            save_irpa(clip_2_path, "text_encoder_model_2."),
                        ]

                elif key in ["vae_decode"] and weights[key] is None:
                    vae_weights_path = os.path.join(
                        diffusers_weights_path,
                        "vae",
                        "diffusion_pytorch_model.safetensors",
                    )
                    weights[key] = save_irpa(vae_weights_path, "vae.")

        vmfbs, weights = self.sd_pipe.check_prepared(
            mlirs, vmfbs, weights, interactive=False
        )
        print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
        self.sd_pipe.load_pipeline(
            vmfbs, weights, self.rt_device, self.compiled_pipeline
        )
        print(
            "\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
        )
        return

    def generate_images(
        self,
        prompt,
        negative_prompt,
        image,
        strength,
        guidance_scale,
        seed,
        ondemand,
        resample_type,
        control_mode,
        hints,
    ):
        """Run the loaded pipeline once and return its images.

        Only ``prompt``, ``negative_prompt``, ``guidance_scale`` and ``seed``
        are forwarded to the underlying pipeline; the remaining parameters are
        accepted for interface compatibility and ignored.
        """
        img = self.sd_pipe.generate_images(
            prompt,
            negative_prompt,
            1,
            guidance_scale,
            seed,
            return_imgs=True,
        )
        return img
def shark_sd_fn_dict_input(
    sd_kwargs: dict,
):
    """Normalise and validate ``sd_kwargs``, then delegate to ``shark_sd_fn``.

    ``sd_kwargs`` is modified in place: ``None``/``[]`` values become ``None``,
    the string ``"None"`` becomes ``""`` and ``"seed"`` is converted to int.
    Each failed validation raises a gradio warning and ends the generator
    without yielding. Otherwise everything ``shark_sd_fn`` yields is re-yielded
    and its return value is returned.
    """
    print("\n[LOG] Submitting Request...")

    for name in sd_kwargs:
        if sd_kwargs[name] in [None, []]:
            sd_kwargs[name] = None
        elif sd_kwargs[name] in ["None"]:
            sd_kwargs[name] = ""
        if name == "seed":
            sd_kwargs[name] = int(sd_kwargs[name])

    # TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
    if not sd_kwargs["device"]:
        gr.Warning("No device specified. Please specify a device.")
        return None, ""
    if sd_kwargs["height"] not in [512, 1024]:
        gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
        return None, ""
    if sd_kwargs["height"] != sd_kwargs["width"]:
        gr.Warning("Height and width must be the same. This is a temporary limitation.")
        return None, ""

    uses_sdxl_turbo = sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo"
    if uses_sdxl_turbo and sd_kwargs["steps"] > 10:
        gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
        return None, ""
    if uses_sdxl_turbo and sd_kwargs["guidance_scale"] > 3:
        gr.Warning(
            "sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
        )
        return None, ""

    # Only ask parse_device to infer the architecture when no triple is given.
    if (
        sd_kwargs["target_triple"] == ""
        and parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == ""
    ):
        gr.Warning(
            "Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
        )
        return None, ""

    return (yield from shark_sd_fn(**sd_kwargs))
def shark_sd_fn(
    prompt,
    negative_prompt,
    sd_init_image: list,
    height: int,
    width: int,
    steps: int,
    strength: float,
    guidance_scale: float,
    seed: list,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    base_model_id: str,
    custom_weights: str,
    custom_vae: str,
    precision: str,
    device: str,
    target_triple: str,
    ondemand: bool,
    compiled_pipeline: bool,
    resample_type: str,
    controlnets: dict,
    embeddings: dict,
):
    """Generate ``batch_count`` batches of images with a cached pipeline.

    The pipeline object and its prepared state live in ``global_obj``. The
    pipeline is rebuilt only when the construction arguments changed, and
    re-prepared only when the preparation arguments changed.

    This is a generator: after each batch it yields
    ``(all_images_so_far, status_label)``; its return value is
    ``(all_images, "")``. Every image is also saved via ``save_output_img``.
    """
    # Must stay the first statement: snapshot of the call arguments only.
    sd_kwargs = locals()
    if not isinstance(sd_init_image, list):
        sd_init_image = [sd_init_image]
    is_img2img = sd_init_image[0] is not None

    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
    import apps.shark_studio.web.utils.globals as global_obj

    adapters = {}
    is_controlled = False
    control_mode = None
    hints = []
    num_loras = 0
    import_ir = True
    for i in embeddings:
        num_loras += 1 if embeddings[i] else 0
    if "model" in controlnets:
        for i, model in enumerate(controlnets["model"]):
            if "xl" not in base_model_id.lower():
                adapters[f"control_adapter_{model}"] = {
                    "hf_id": control_adapter_map["runwayml/stable-diffusion-v1-5"][
                        model
                    ],
                    "strength": controlnets["strength"][i],
                }
            else:
                adapters[f"control_adapter_{model}"] = {
                    "hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][
                        model
                    ],
                    "strength": controlnets["strength"][i],
                }
            if model is not None:
                is_controlled = True
        control_mode = controlnets["control_mode"]
        for hint in controlnets["hint"]:
            # Fixed: this used to be `hints.append[i]`, which subscripts the
            # bound method and raises TypeError.
            hints.append(hint)

    # Arguments that determine which pipeline object must exist.
    submit_pipe_kwargs = {
        "base_model_id": base_model_id,
        "height": height,
        "width": width,
        "batch_size": batch_size,
        "precision": precision,
        "device": device,
        "target_triple": target_triple,
        "custom_vae": custom_vae,
        "num_loras": num_loras,
        "import_ir": import_ir,
        "is_controlled": is_controlled,
        "steps": steps,
        "scheduler": scheduler,
    }
    # Arguments that determine how that pipeline is prepared.
    submit_prep_kwargs = {
        "custom_weights": custom_weights,
        "adapters": adapters,
        "embeddings": embeddings,
        "is_img2img": is_img2img,
        "compiled_pipeline": compiled_pipeline,
    }
    # Per-generation arguments.
    submit_run_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": sd_init_image,
        "strength": strength,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "ondemand": ondemand,
        "resample_type": resample_type,
        "control_mode": control_mode,
        "hints": hints,
    }
    if (
        not global_obj.get_sd_obj()
        or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
    ):
        print("\n[LOG] Initializing new pipeline...")
        global_obj.clear_cache()
        gc.collect()

        # Initializes the pipeline and retrieves IR based on all
        # parameters that are static in the turbine output format,
        # which is currently MLIR in the torch dialect.

        sd_pipe = StableDiffusion(
            **submit_pipe_kwargs,
        )
        global_obj.set_sd_obj(sd_pipe)
        global_obj.set_pipe_kwargs(submit_pipe_kwargs)
    if (
        not global_obj.get_prep_kwargs()
        or global_obj.get_prep_kwargs() != submit_prep_kwargs
    ):
        global_obj.set_prep_kwargs(submit_prep_kwargs)
        global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)

    generated_imgs = []
    for current_batch in range(batch_count):
        # Only needed by the disabled timing report below.
        start_time = time.time()
        out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
        if not isinstance(out_imgs, list):
            out_imgs = [out_imgs]
        # total_time = time.time() - start_time
        # text_output = f"Total image(s) generation time: {total_time:.4f}sec"
        # print(f"\n[LOG] {text_output}")
        # if global_obj.get_sd_status() == SD_STATE_CANCEL:
        #     break
        # else:
        for batch in range(batch_size):
            save_output_img(
                out_imgs[batch],
                seed,
                sd_kwargs,
            )
        generated_imgs.extend(out_imgs)
        # TODO: make seed changes over batch counts more configurable.
        submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
        yield generated_imgs, status_label(
            "Stable Diffusion", current_batch + 1, batch_count, batch_size
        )
    return (generated_imgs, "")
def unload_sd():
    """Clear the shared object cache and run a garbage-collection pass."""
    print("Unloading models.")
    import apps.shark_studio.web.utils.globals as studio_globals

    studio_globals.clear_cache()
    gc.collect()
def cancel_sd():
    """Placeholder cancel hook: only logs a message and returns None."""
    print("Inject call to cancel longer API calls.")
def view_json_file(file_path):
    """Return the raw text of the file at ``file_path``.

    The file is decoded as UTF-8 explicitly so JSON containing non-ASCII
    characters reads the same on every platform, whatever the locale's
    default encoding is. The content is returned unparsed.
    """
    with open(file_path, "r", encoding="utf-8") as fopen:
        return fopen.read()
def safe_name(name):
    """Return ``name`` with every "/", "\\" and "." replaced by "_"."""
    substitutions = str.maketrans({"/": "_", "\\": "_", ".": "_"})
    return name.translate(substitutions)
if __name__ == "__main__":
    # Command-line entry point: start from the default SD config, override it
    # with any matching command-line options, then print each result.
    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
    import apps.shark_studio.web.utils.globals as global_obj

    global_obj._init()

    default_config_path = get_resource_path(
        os.path.join(cmd_opts.config_dir, "default_sd_config.json")
    )
    sd_kwargs = json.loads(view_json_file(default_config_path))
    for option_name in vars(cmd_opts):
        if option_name not in sd_kwargs:
            continue
        sd_kwargs[option_name] = getattr(cmd_opts, option_name)

    for progress in shark_sd_fn_dict_input(sd_kwargs):
        print(progress)

View File

@@ -1,389 +0,0 @@
import numpy as np
import json
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info
# TODO: migrate these utils to studio
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)
def get_available_devices():
    """Return display strings for every device the runtime drivers report.

    Each entry has the form ``"<device name> => <driver>[://<index>]"``.
    Drivers are probed in this order: rocm, cpu-sync, cpu-task, vulkan, metal,
    cuda, hip. A driver that cannot be queried is reported on stdout and
    contributes no entries.
    """

    def get_devices_by_name(driver_name):
        from shark.iree_utils._common import iree_device_map

        device_list = []
        try:
            driver_name = iree_device_map(driver_name)
            device_list_dict = get_all_devices(driver_name)
            print(f"{driver_name} devices are available.")
        except Exception:
            # Best effort: a missing driver must not abort enumeration.
            # Catching Exception (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate.
            print(f"{driver_name} devices are not available.")
        else:
            cpu_name = get_cpu_info()["brand_raw"]
            for i, device in enumerate(device_list_dict):
                device_name = (
                    cpu_name if device["name"] == "default" else device["name"]
                )
                if "local" in driver_name:
                    device_list.append(
                        f"{device_name} => {driver_name.replace('local', 'cpu')}"
                    )
                else:
                    # for drivers with single devices
                    # let the default device be selected without any indexing
                    if len(device_list_dict) == 1:
                        device_list.append(f"{device_name} => {driver_name}")
                    else:
                        device_list.append(f"{device_name} => {driver_name}://{i}")
        return device_list

    set_iree_runtime_flags()

    available_devices = []
    for driver in ("rocm", "cpu-sync", "cpu-task"):
        available_devices.extend(get_devices_by_name(driver))

    from shark.iree_utils.vulkan_utils import (
        get_all_vulkan_devices,
    )

    vulkan_devices = [
        f"{device.strip()} => vulkan://{index}"
        for index, device in enumerate(get_all_vulkan_devices())
    ]
    if vulkan_devices:
        print("vulkan devices are available.")
    available_devices.extend(vulkan_devices)

    for driver in ("metal", "cuda", "hip"):
        available_devices.extend(get_devices_by_name(driver))

    # Replace the generic "AMD Radeon(TM) Graphics" label with a more specific
    # name taken from another entry that contains "M Graphics", if there is one.
    for idx, device_str in enumerate(available_devices):
        if "AMD Radeon(TM) Graphics =>" in device_str:
            igpu_id_candidates = [
                x.split("w/")[-1].split("=>")[0]
                for x in available_devices
                if "M Graphics" in x
            ]
            for igpu_name in igpu_id_candidates:
                if igpu_name:
                    available_devices[idx] = device_str.replace(
                        "AMD Radeon(TM) Graphics", igpu_name
                    )
                    break
    return available_devices
def set_init_device_flags():
    """Normalise ``cmd_opts.device`` and fill in missing target-triple options.

    The first of "vulkan", "cuda", "metal", "cpu" found in ``cmd_opts.device``
    decides what happens; any other device string leaves the options untouched.
    """
    requested = cmd_opts.device

    if "vulkan" in requested:
        # set runtime flags for vulkan.
        set_iree_runtime_flags()

        # set triple flag to avoid multiple calls to get_vulkan_triple_flag
        device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
        if not cmd_opts.iree_vulkan_target_triple:
            detected = get_vulkan_target_triple(device_name)
            if detected is not None:
                cmd_opts.iree_vulkan_target_triple = detected
        print(
            f"Found device {device_name}. Using target triple "
            f"{cmd_opts.iree_vulkan_target_triple}."
        )
        return

    if "cuda" in requested:
        cmd_opts.device = "cuda"
        return

    if "metal" in requested:
        device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
        if not cmd_opts.iree_metal_target_platform:
            from shark.iree_utils.metal_utils import get_metal_target_triple

            detected = get_metal_target_triple(device_name)
            if detected is not None:
                # Only the last dash-separated component is stored.
                cmd_opts.iree_metal_target_platform = detected.split("-")[-1]
        print(
            f"Found device {device_name}. Using target triple "
            f"{cmd_opts.iree_metal_target_platform}."
        )
        return

    if "cpu" in requested:
        cmd_opts.device = "cpu"
def set_iree_runtime_flags():
    """Extend the Vulkan runtime flags from command-line options and apply them."""
    # TODO: This function should be device-agnostic and piped properly
    # to general runtime driver init.
    runtime_flags = get_iree_vulkan_runtime_flags()

    if cmd_opts.enable_rgp:
        runtime_flags.extend(
            [
                "--enable_rgp=true",
                "--vulkan_debug_utils=true",
            ]
        )

    heap_key = cmd_opts.device_allocator_heap_key
    if heap_key:
        runtime_flags.extend(
            [
                f"--device_allocator=caching:device_local={heap_key}",
            ]
        )

    set_iree_vulkan_runtime_flags(flags=runtime_flags)
def parse_device(device_str, target_override=""):
from shark.iree_utils.compile_utils import (
clean_device_info,
get_iree_target_triple,
iree_target_map,
)
rt_driver, device_id = clean_device_info(device_str)
target_backend = iree_target_map(rt_driver)
if device_id:
rt_device = f"{rt_driver}://{device_id}"
else:
rt_device = rt_driver
if target_override:
return target_backend, rt_device, target_override
match target_backend:
case "vulkan-spirv":
triple = get_iree_target_triple(device_str)
return target_backend, rt_device, triple
case "rocm":
triple = get_rocm_target_chip(device_str)
return target_backend, rt_device, triple
case "llvm-cpu":
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
def get_rocm_target_chip(device_str):
    """Map a ROCm device description to its target chip name.

    The first marker (in table order) that occurs as a substring of
    ``device_str`` decides the result.

    :raises AssertionError: if no marker occurs in ``device_str``.
    """
    # TODO: Use a data file to map device_str to target chip.
    marker_to_chip = (
        ("6700", "gfx1031"),
        ("6800", "gfx1030"),
        ("6900", "gfx1030"),
        ("7900", "gfx1100"),
        ("MI300X", "gfx942"),
        ("MI300A", "gfx940"),
        ("MI210", "gfx90a"),
        ("MI250", "gfx90a"),
        ("MI100", "gfx908"),
        ("MI50", "gfx906"),
        ("MI60", "gfx906"),
        ("780M", "gfx1103"),
    )
    matched_chip = next(
        (chip for marker, chip in marker_to_chip if marker in device_str),
        None,
    )
    if matched_chip is None:
        raise AssertionError(
            f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues."
        )
    return matched_chip
def get_all_devices(driver_name):
    """
    Query the given IREE driver for its available devices.

    Returns the driver's device list, sorted in place by each entry's
    "path" value.
    """
    from iree.runtime import get_driver

    devices = get_driver(driver_name).query_available_devices()
    devices.sort(key=lambda entry: entry["path"])
    return devices
def get_device_mapping(driver, key_combination=3):
    """This method ensures consistent device ordering when choosing
    specific devices for execution
    Args:
        driver (str): execution driver (vulkan, cuda, rocm, etc)
        key_combination (int, optional): choice for mapping value for
            device name.
            1 : path
            2 : name
            3 : (name, path)
            Defaults to 3.
    Returns:
        dict: map to possible device names user can input mapped to desired
            combination of name/path. Empty when the driver reports no
            devices. Values are None for any other key_combination.
    """
    from shark.iree_utils._common import iree_device_map

    driver = iree_device_map(driver)
    device_list = get_all_devices(driver)
    device_map = dict()

    # No devices: return an empty map rather than failing on device_list[0],
    # so a lookup in the result simply misses.
    if not device_list:
        return device_map

    def get_output_value(dev_dict):
        if key_combination == 1:
            return f"{driver}://{dev_dict['path']}"
        if key_combination == 2:
            return dev_dict["name"]
        if key_combination == 3:
            return dev_dict["name"], f"{driver}://{dev_dict['path']}"

    # mapping driver name to default device (driver://0)
    device_map[f"{driver}"] = get_output_value(device_list[0])
    for i, device in enumerate(device_list):
        # mapping with index
        device_map[f"{driver}://{i}"] = get_output_value(device)
        # mapping with full path
        device_map[f"{driver}://{device['path']}"] = get_output_value(device)
    return device_map
def get_opt_flags(model, precision="fp16"):
    """Build the list of compiler flags implied by the command-line options.

    ``model`` is only checked for the substring "vae"; ``precision`` is
    accepted but not used.
    """
    flags = []

    if len(cmd_opts.iree_vulkan_target_triple) > 0:
        flags.append(
            f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
        )

    if "rocm" in cmd_opts.device:
        from shark.iree_utils.gpu_utils import get_iree_rocm_args

        flags.extend(get_iree_rocm_args())

    if cmd_opts.iree_constant_folding == False:
        flags.extend(
            [
                "--iree-opt-const-expr-hoisting=False",
                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
            ]
        )

    if cmd_opts.data_tiling == False:
        flags.append("--iree-opt-data-tiling=False")

    if "vae" not in model:
        # Due to lack of support for multi-reduce, we always collapse reduction
        # dims before dispatch formation right now.
        flags.append("--iree-flow-collapse-reduction-dims")

    return flags
def map_device_to_name_path(device, key_combination=3):
    """Look up the name/path data for a user-selected execution device.

    Args:
        device (str): device string; the part before "://" is the driver.
        key_combination (int, optional): shape of the returned value.
            1 : path
            2 : name
            3 : (name, path)
            Defaults to 3.
    Raises:
        ValueError: if ``device`` is not a key of the driver's device map.
    Returns:
        str / tuple: the mapped value for the device, shaped according to
        key_combination.
    """
    driver, _, _ = device.partition("://")
    known_devices = get_device_mapping(driver, key_combination)
    if device not in known_devices:
        raise ValueError(f"Device '{device}' is not a valid device.")
    return known_devices[device]
def get_devices_by_name(driver_name):
    """Return display strings for the devices of a single driver.

    Each entry has the form ``"<device name> => <driver>[://<index>]"``. A
    device named "default" is shown under the CPU brand string, and "local"
    in the driver name is displayed as "cpu". If the driver cannot be queried,
    a message is printed and an empty list is returned.

    The statements that used to follow the ``return`` (a second, unreachable
    device-enumeration routine) were dead code and have been removed.
    """
    from shark.iree_utils._common import iree_device_map

    device_list = []
    try:
        driver_name = iree_device_map(driver_name)
        device_list_dict = get_all_devices(driver_name)
        print(f"{driver_name} devices are available.")
    except Exception:
        # Best effort: report and return no devices. Catching Exception (not
        # a bare except) lets KeyboardInterrupt and SystemExit propagate.
        print(f"{driver_name} devices are not available.")
    else:
        cpu_name = get_cpu_info()["brand_raw"]
        for i, device in enumerate(device_list_dict):
            device_name = (
                cpu_name if device["name"] == "default" else device["name"]
            )
            if "local" in driver_name:
                device_list.append(
                    f"{device_name} => {driver_name.replace('local', 'cpu')}"
                )
            else:
                # for drivers with single devices
                # let the default device be selected without any indexing
                if len(device_list_dict) == 1:
                    device_list.append(f"{device_name} => {driver_name}")
                else:
                    device_list.append(f"{device_name} => {driver_name}://{i}")
    return device_list
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
if isinstance(seed_input, int):
return [seed_input]
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)

View File

@@ -1,145 +0,0 @@
import os
import json
import re
import requests
import torch
import safetensors
from shark_turbine.aot.params import (
ParameterArchiveBuilder,
)
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from omegaconf import OmegaConf
from diffusers import StableDiffusionPipeline
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
create_vae_diffusers_config,
convert_ldm_vae_checkpoint,
)
def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"):
    """Return (and create) the diffusers output directory for a checkpoint.

    The directory is ``<checkpoint dir>/diffusers/<checkpoint stem>_<precision>``,
    resolved to an absolute path. It is created if missing and returned as a
    POSIX-style string.
    """
    weights_file = Path(custom_weights)
    relative_dir = os.path.join("diffusers", f"{weights_file.stem}_{precision}")
    target_dir = weights_file.parent.absolute() / relative_dir
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.as_posix()
def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False):
    """Convert an original SD checkpoint into a diffusers directory.

    The target directory comes from ``get_path_to_diffusers_checkpoint``. If it
    already contains anything, conversion is skipped. Otherwise the checkpoint
    is loaded through diffusers' converter, cast to fp16 when requested, and
    saved there.

    Returns:
        The path of the diffusers directory.
    """
    path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision)

    # Any existing entry in the directory counts as "already converted".
    if any(Path(path_to_diffusers).iterdir()):
        print("Checkpoint already loaded at : ", path_to_diffusers)
        return path_to_diffusers

    print(
        "Diffusers' checkpoint will be identified here : ",
        path_to_diffusers,
    )

    # EMA weights usually yield higher quality images for inference but
    # non-EMA weights have been yielding better results in our case.
    # TODO: Add an option `--ema` (`--no-ema`) for users to specify if
    # they want to go for EMA weight extraction or not.
    print("Loading diffusers' pipeline from original stable diffusion checkpoint")
    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path_or_dict=custom_weights,
        extract_ema=False,
        from_safetensors=custom_weights.lower().endswith(".safetensors"),
        # Inpainting requests a 9-channel input instead of the usual 4.
        num_in_channels=9 if is_inpaint else 4,
    )
    if precision == "fp16":
        pipe.to(dtype=torch.float16)
    pipe.save_pretrained(path_to_diffusers)
    del pipe
    print("Loading complete")
    return path_to_diffusers
def save_irpa(weights_path, prepend_str):
    """Repack a ``.safetensors`` file as an ``.irpa`` parameter archive.

    Every tensor is added to the archive under ``prepend_str + <original key>``.
    The archive is written next to the input, with ``.safetensors`` replaced by
    ``.irpa`` in the path string.

    Returns:
        The path of the written ``.irpa`` file.
    """
    tensors = safetensors.torch.load_file(weights_path)
    builder = ParameterArchiveBuilder()
    for name, tensor in tensors.items():
        builder.add_tensor(prepend_str + name, tensor)
    irpa_file = weights_path.replace(".safetensors", ".irpa")
    builder.save(irpa_file)
    return irpa_file
def convert_original_vae(vae_checkpoint):
    """Convert an original-format (LDM) VAE state dict to diffusers layout.

    Keys are prefixed with ``first_stage_model.`` and passed to diffusers'
    ``convert_ldm_vae_checkpoint`` together with a VAE config derived from the
    upstream Stable Diffusion v1 inference YAML, which is downloaded on each
    call.

    Args:
        vae_checkpoint: Mapping of parameter name to tensor.

    Returns:
        The converted VAE state dict.

    Raises:
        requests.RequestException: if the config cannot be downloaded. Without
            this check an HTTP error page would be parsed as the YAML config.
    """
    vae_state_dict = {
        "first_stage_model." + key: value for key, value in vae_checkpoint.items()
    }

    config_url = (
        "https://raw.githubusercontent.com/CompVis/stable-diffusion/"
        "main/configs/stable-diffusion/v1-inference.yaml"
    )
    # Bound the wait and fail loudly on HTTP errors.
    response = requests.get(config_url, timeout=60)
    response.raise_for_status()
    original_config = OmegaConf.load(BytesIO(response.content))

    vae_config = create_vae_diffusers_config(original_config, image_size=512)
    return convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
def process_custom_pipe_weights(custom_weights):
    """Resolve custom weights to a local file and its diffusers directory.

    ``custom_weights`` is either a civitai API URL (downloaded if needed) or a
    local ``.ckpt`` / ``.safetensors`` path.

    Returns:
        ``(weights_file, diffusers_dir)`` for a non-empty input. An empty
        string produces no result (``None``).
    """
    if custom_weights == "":
        return None

    if custom_weights.startswith("https://civitai.com/api/"):
        # Fetch the checkpoint from civitai unless it is already on disk, then
        # carry on as if the local file had been passed in.
        custom_weights_params = get_civitai_checkpoint(custom_weights)
    else:
        assert custom_weights.lower().endswith(
            (".ckpt", ".safetensors")
        ), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
        custom_weights_params = custom_weights

    custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights_params)
    return custom_weights_params, custom_weights_tgt
def get_civitai_checkpoint(url: str):
    """Download a civitai model into the models directory if not yet present.

    The file name is taken from the quoted value in the response's
    ``Content-Disposition`` header. The destination is
    ``<cwd>/<cmd_opts.model_dir or "models">/<file name>``.

    Args:
        url: civitai API download URL.

    Returns:
        POSIX-style path of the local model file.

    Raises:
        requests.HTTPError: on a non-success HTTP status.
        KeyError / IndexError: if the response has no usable
            ``Content-Disposition`` file name.
    """
    with requests.get(url, allow_redirects=True, stream=True) as response:
        response.raise_for_status()

        # civitai api returns the filename in the content disposition
        header_filename = re.findall(
            '"([^"]*)"', response.headers["Content-Disposition"]
        )[0]
        # The name comes from the remote side: keep only its final component
        # so it cannot point outside the models directory.
        base_filename = Path(header_filename).name
        destination_path = (
            Path.cwd() / (cmd_opts.model_dir or "models") / base_filename
        )

        if destination_path.is_file():
            # we already have this model downloaded
            print(f"civitai model already downloaded to {destination_path}")
        else:
            # we don't have this model downloaded yet
            print(f"downloading civitai model from {url} to {destination_path}")
            destination_path.parent.mkdir(parents=True, exist_ok=True)

            # A missing content-length leaves the progress bar open-ended.
            size = int(response.headers.get("content-length", 0)) or None

            # Stream into a sibling ".part" file and rename once complete, so
            # an interrupted download is never mistaken for a finished model.
            partial_path = destination_path.with_name(destination_path.name + ".part")
            try:
                with tqdm(total=size, unit="iB", unit_scale=True) as progress_bar:
                    with open(partial_path, "wb") as f:
                        for chunk in response.iter_content(chunk_size=65536):
                            f.write(chunk)
                            progress_bar.update(len(chunk))
                partial_path.replace(destination_path)
            finally:
                # No-op after a successful rename; removes debris otherwise.
                partial_path.unlink(missing_ok=True)

    return destination_path.as_posix()

View File

@@ -1,185 +0,0 @@
import os
import sys
import torch
import json
import safetensors
from dataclasses import dataclass
from safetensors.torch import load_file
from apps.shark_studio.web.utils.file_utils import (
get_checkpoint_pathfile,
get_path_stem,
)
@dataclass
class LoRAweight:
up: torch.tensor
down: torch.tensor
mid: torch.tensor
alpha: torch.float32 = 1.0
def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
# Merge the LoRA weights stored in file `use_lora` directly into `model`'s
# layer weights (in place) and return `model`.
#
# model:            object whose nested attributes are walked to find layers
# use_lora:         path of the LoRA file; loaded with safetensors when the
#                   path contains ".safetensors", otherwise with torch.load
# splitting_prefix: key prefix selecting the entries to apply (callers in this
#                   file pass "lora_unet_" or "lora_te_")
# lora_strength:    multiplier combined with each layer's alpha
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
# NOTE(review): torch.load unpickles the file; only safe for trusted files.
state_dict = torch.load(use_lora)
# gather the weights from the LoRA in a more convenient form, assumes
# everything will have an up.weight.
weight_dict: dict[str, LoRAweight] = {}
for key in state_dict:
if key.startswith(splitting_prefix) and key.endswith("up.weight"):
# stem is the key up to (not including) "up.weight"; weight_key is the stem
# with whichever of the three known LoRA suffix spellings it ends in removed.
stem = key.split("up.weight")[0]
weight_key = stem.removesuffix(".lora_")
weight_key = weight_key.removesuffix("_lora_")
weight_key = weight_key.removesuffix(".lora_linear_layer.")
if weight_key not in weight_dict:
weight_dict[weight_key] = LoRAweight(
state_dict[f"{stem}up.weight"],
state_dict[f"{stem}down.weight"],
state_dict.get(f"{stem}mid.weight", None),
(
# alpha is stored un-normalised: divide by the second dimension of the up
# weight. Falls back to 1.0 when the file has no alpha entry for this layer.
state_dict[f"{weight_key}.alpha"]
/ state_dict[f"{stem}up.weight"].shape[1]
if f"{weight_key}.alpha" in state_dict
else 1.0
),
)
# Directly update weight in model
# Mostly adaptions of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
# and similar code in https://github.com/huggingface/diffusers/issues/3064
# TODO: handle mid weights (how do they even work?)
for key, lora_weight in weight_dict.items():
curr_layer = model
# Drop everything after the first "." and the prefix, then split on "_" to
# get candidate attribute-name pieces.
layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
# find the target layer
# Attribute names may themselves contain "_": when a lookup fails, the next
# piece is appended to the current name and the lookup is retried.
# NOTE(review): the loop condition is always true; the loop ends via `break`
# or via an IndexError from pop(0) on an empty list in the except branch.
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
weight = curr_layer.weight.data
scale = lora_weight.alpha * lora_strength
# Target weight is 2-D.
if len(weight.size()) == 2:
if len(lora_weight.up.shape) == 4:
# NOTE(review): `change` ends up 4-D here while `weight` is 2-D, so the
# in-place add below looks like it cannot broadcast - confirm this branch.
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
change = torch.mm(lora_weight.up, lora_weight.down)
# Target weight is not 2-D and the down weight's trailing dims are (1, 1):
# multiply as matrices, then restore the two trailing singleton dims.
elif lora_weight.down.size()[2:4] == (1, 1):
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
# Remaining case: combine down and up with a conv2d over permuted axes.
change = torch.nn.functional.conv2d(
lora_weight.down.permute(1, 0, 2, 3),
lora_weight.up,
).permute(1, 0, 2, 3)
curr_layer.weight.data += change * scale
return model
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
    """Apply a LoRA to ``unet`` and return the updated model.

    If ``use_lora`` contains none of ``.bin`` / ``.safetensors`` / ``.pt`` it is
    passed to ``unet.load_attn_procs`` unchanged (treated as a HF ID with
    standalone LoRA weights). Otherwise ``load_attn_procs`` is tried on the
    file's directory and name, and on failure the weights are merged manually
    with ``processLoRA`` using the ``lora_unet_`` key prefix.

    Args:
        unet: Model exposing ``load_attn_procs``.
        use_lora: HF ID or path to a LoRA weight file.
        lora_strength: Strength forwarded to ``processLoRA`` on fallback.
    """
    extensions = [".bin", ".safetensors", ".pt"]
    # First extension (in the order above) that occurs anywhere in the string.
    matched_extension = next((ext for ext in extensions if ext in use_lora), None)
    if matched_extension is None:
        # We assume if it is a HF ID with standalone LoRA weights.
        unet.load_attn_procs(use_lora)
        return unet

    main_file_name = get_path_stem(use_lora) + matched_extension

    try:
        dir_name = os.path.dirname(use_lora)
        unet.load_attn_procs(dir_name, weight_name=main_file_name)
        return unet
    except Exception:
        # Not a format load_attn_procs understands: merge the weights manually.
        # Catch only Exception so Ctrl-C / SystemExit are not swallowed.
        return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
def update_lora_weight(model, use_lora, model_name, lora_strength=1.0):
    """Apply a LoRA to ``model``, choosing the strategy from ``model_name``.

    Names containing ``"unet"`` go through ``update_lora_weight_for_unet``.
    Everything else is merged with ``processLoRA`` using the ``lora_te_`` key
    prefix.

    Returns:
        The updated model, or ``None`` when the non-unet merge fails.
    """
    if "unet" in model_name:
        return update_lora_weight_for_unet(model, use_lora, lora_strength)
    try:
        return processLoRA(model, use_lora, "lora_te_", lora_strength)
    except Exception as error:
        # Best effort: callers receive None on failure. Report the reason
        # instead of hiding it, and let Ctrl-C / SystemExit propagate.
        print(f"Failed to apply LoRA {use_lora} to {model_name}: {error}")
        return None
def get_lora_metadata(lora_filename):
    """Summarise the training metadata embedded in a LoRA safetensors file.

    Args:
        lora_filename: Name resolved through ``get_checkpoint_pathfile(..., "lora")``.

    Returns:
        ``None`` when the file carries no metadata; otherwise a dict with

        * ``"model"``: the ``ss_sd_model_hash``, ``ss_sd_model_name`` and
          ``ss_base_model_version`` values joined by spaces and stripped;
        * ``"frequencies"``: ``(tag, ratio)`` pairs sorted by ratio, highest
          first, where ratio is the tag count divided by the dataset's image
          count.
    """
    # get the metadata from the file
    filename = get_checkpoint_pathfile(lora_filename, "lora")
    with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
        metadata = f.metadata()

    # guard clause for if there isn't any metadata
    if not metadata:
        return None

    # metadata is a dictionary of strings, the values of the keys we're
    # interested in are actually json, and need to be loaded as such
    tag_frequencies = json.loads(metadata.get("ss_tag_frequency", "{}"))
    dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", "{}"))

    # gather the tag frequency information for all the datasets trained
    all_frequencies = {}
    for dataset, tag_counts in tag_frequencies.items():
        frequencies = sorted(tag_counts.items(), reverse=True, key=lambda x: x[1])

        # Image count for this dataset: the number listed in its dataset_dirs
        # entry, or the highest tag count when that is missing. The lookup
        # must use this dataset's name.
        img_count = dataset_dirs.get(dataset, {}).get("img_count")
        if not img_count:
            img_count = frequencies[0][1] if frequencies else 0
        if not img_count:
            # No tags, or nothing to divide by: nothing to report here.
            continue

        # add the dataset frequencies to the overall frequencies replacing the
        # frequency counts on the tags with a percentage/ratio
        all_frequencies.update((tag, count / img_count) for tag, count in frequencies)

    trained_model_id = " ".join(
        [
            metadata.get("ss_sd_model_hash", ""),
            metadata.get("ss_sd_model_name", ""),
            metadata.get("ss_base_model_version", ""),
        ]
    ).strip()

    # return all frequencies across all datasets, most frequent first
    return {
        "model": trained_model_id,
        "frequencies": sorted(
            all_frequencies.items(), reverse=True, key=lambda x: x[1]
        ),
    }

View File

@@ -1,202 +0,0 @@
import os
import re
import json
import torch
import numpy as np
from csv import DictWriter
from PIL import Image, PngImagePlugin
from pathlib import Path
from datetime import datetime as dt
from base64 import decode
# Display name -> Pillow resampling filter. The keys are the strings callers
# pass as `resampler_type` / `resample_type`.
resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
"Nearest Neighbor": Image.Resampling.NEAREST,
"Bilinear": Image.Resampling.BILINEAR,
"Bicubic": Image.Resampling.BICUBIC,
"Hamming": Image.Resampling.HAMMING,
"Box": Image.Resampling.BOX,
}
# The available display names (a live keys view of the dict above, not a list).
resampler_list = resamplers.keys()
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info=None):
    """Save a generated image together with a record of how it was made.

    Writes three things into today's generated-images directory:

    * the image, as JPEG when ``cmd_opts.output_img_format == "jpg"`` and as
      PNG otherwise (with a ``parameters`` text chunk when
      ``cmd_opts.write_metadata_to_png`` is set);
    * one row appended to ``imgs_details.csv`` (header written on creation);
    * a ``<image name>.json`` file with the same fields.

    Args:
        output_img: Image object exposing ``save``.
        img_seed: Seed used for this image; part of the file name.
        extra_info: Generation settings. The code reads the keys ``prompt``,
            ``negative_prompt``, ``base_model_id``, ``custom_weights``,
            ``custom_vae``, ``embeddings``, ``width``, ``height``, ``steps``,
            ``scheduler`` and ``guidance_scale``.
    """
    from apps.shark_studio.web.utils.file_utils import (
        get_generated_imgs_path,
        get_generated_imgs_todays_subdir,
    )
    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

    if extra_info is None:
        extra_info = {}
    generated_imgs_path = Path(
        get_generated_imgs_path(), get_generated_imgs_todays_subdir()
    )
    generated_imgs_path.mkdir(parents=True, exist_ok=True)
    csv_path = Path(generated_imgs_path, "imgs_details.csv")

    # File name: time of day, a sanitised prefix of the first prompt, the seed.
    prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15])
    out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"

    img_model = extra_info["base_model_id"]
    if extra_info["custom_weights"] not in [None, "None"]:
        img_model = Path(os.path.basename(extra_info["custom_weights"])).stem

    img_vae = None
    if extra_info["custom_vae"]:
        img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem

    img_loras = None
    if extra_info["embeddings"]:
        img_lora = []
        for i in extra_info["embeddings"]:
            # append(), not "+=": extending a list with a string adds it one
            # character at a time.
            # NOTE(review): the name is taken from cmd_opts.use_lora on every
            # iteration, not from `i` - confirm which was intended.
            img_lora.append(Path(os.path.basename(cmd_opts.use_lora)).stem)
        img_loras = ", ".join(img_lora)

    if cmd_opts.output_img_format == "jpg":
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
        output_img.save(out_img_path, quality=95, subsampling=0)
    else:
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
        pngInfo = PngImagePlugin.PngInfo()

        if cmd_opts.write_metadata_to_png:
            # Using a conditional expression caused problems, so setting a new
            # variable for now.
            # if cmd_opts.use_hiresfix:
            #     png_size_text = (
            #         f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
            #     )
            # else:
            png_size_text = f"{extra_info['width']}x{extra_info['height']}"

            pngInfo.add_text(
                "parameters",
                f"{extra_info['prompt'][0]}"
                f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
                f"\nSteps: {extra_info['steps']},"
                f"Sampler: {extra_info['scheduler']}, "
                f"CFG scale: {extra_info['guidance_scale']}, "
                f"Seed: {img_seed},"
                f"Size: {png_size_text}, "
                f"Model: {img_model}, "
                f"VAE: {img_vae}, "
                f"LoRA: {img_loras}",
            )

        output_img.save(out_img_path, "PNG", pnginfo=pngInfo)

        if cmd_opts.output_img_format not in ["png", "jpg"]:
            print(
                f"[ERROR] Format {cmd_opts.output_img_format} is not "
                f"supported yet. Image saved as png instead."
                f"Supported formats: png / jpg"
            )

    # To be as low-impact as possible to the existing CSV format, we append
    # "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
    # importance for each data point. Something to consider.
    new_entry = dict(extra_info)

    csv_mode = "a" if os.path.isfile(csv_path) else "w"
    # newline="" lets the csv module control line endings itself.
    with open(csv_path, csv_mode, encoding="utf-8", newline="") as csv_obj:
        dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
        if csv_mode == "w":
            dictwriter_obj.writeheader()
        dictwriter_obj.writerow(new_entry)

    json_path = Path(generated_imgs_path, f"{out_img_name}.json")
    with open(json_path, "w") as f:
        json.dump(new_entry, f, indent=4)
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms with our model constraints :-
#   Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image, width, height, resampler_type=None):
    """Resize ``image`` to model-friendly dimensions derived from width/height.

    The smaller side is raised to 128 when below it, and sizes whose smaller
    side exceeds 768 are scaled down so the larger side is 768. The aspect
    ratio is kept and both sides are rounded down to a multiple of 8.

    Args:
        image: Image to resize.
        width: Requested width.
        height: Requested height.
        resampler_type: Key of ``resamplers``; unknown or missing values fall
            back to ``"Nearest Neighbor"``.

    Returns:
        ``(resized_image, new_width, new_height)``.
    """
    aspect_ratio = width / height

    min_size = min(width, height)
    if min_size < 128:
        n_size = 128
        if width == min_size:
            width = n_size
            height = n_size / aspect_ratio
        else:
            height = n_size
            width = n_size * aspect_ratio
    width = int(width)
    height = int(height)
    n_width = (width // 8) * 8
    n_height = (height // 8) * 8

    # NOTE(review): this cap only triggers when the *smaller* side exceeds 768,
    # so a size such as 512x1024 passes through unchanged - confirm intent.
    min_size = min(width, height)
    if min_size > 768:
        n_size = 768
        if width == min_size:
            height = n_size
            width = n_size * aspect_ratio
        else:
            width = n_size
            height = n_size / aspect_ratio
        width = int(width)
        height = int(height)
        n_width = (width // 8) * 8
        n_height = (height // 8) * 8

    if resampler_type in resamplers:
        resampler = resamplers[resampler_type]
    else:
        resampler = resamplers["Nearest Neighbor"]

    # Pillow's keyword is `resample`; `resampler=` raises TypeError.
    new_image = image.resize((n_width, n_height), resample=resampler)
    return new_image, n_width, n_height
def process_sd_init_image(self, sd_init_image, resample_type):
    """Turn an init-image input into a normalised tensor for img2img.

    Accepted inputs:

    * a list: every element is processed recursively;
    * a string: a file path, opened when it exists;
    * a PIL image;
    * any other truthy object indexable with ``"image"`` whose value is an
      image;
    * anything falsy: no init image.

    Args:
        self: Object providing ``width``, ``height``, ``dtype`` and a
            ``process_sd_init_image`` method used for recursion.
        sd_init_image: The input described above.
        resample_type: Key of ``resamplers``; unknown values fall back to
            Lanczos.

    Returns:
        ``(image, is_img2img)``. ``image`` is ``None`` when there is no init
        image, a list of results for list input, and otherwise a tensor scaled
        from 0..255 to -1..1 with the channel axis moved to position 1.
    """
    if isinstance(sd_init_image, list):
        images = [
            self.process_sd_init_image(img, resample_type)[0]
            for img in sd_init_image
        ]
        return images, True

    if isinstance(sd_init_image, str):
        if not os.path.isfile(sd_init_image):
            return None, False
        loaded = Image.open(sd_init_image, mode="r").convert("RGB")
        # The recursive call already returns the finished tensor. Return it
        # as-is: truth-testing or resizing a tensor below would raise.
        return self.process_sd_init_image(loaded, resample_type)

    if isinstance(sd_init_image, Image.Image):
        image = sd_init_image.convert("RGB")
    elif sd_init_image:
        image = sd_init_image["image"].convert("RGB")
    else:
        return None, False

    resample_filter = (
        resamplers[resample_type]
        if resample_type in resampler_list
        # Fallback to Lanczos
        else Image.Resampling.LANCZOS
    )
    image = image.resize((self.width, self.height), resample=resample_filter)

    # Add a leading axis, scale 0..255 to 0..1, move channels to axis 1, then
    # shift to -1..1.
    image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
    image_arr = image_arr / 255.0
    image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
    image_arr = 2 * (image_arr - 0.5)
    return image_arr, True

View File

@@ -1,37 +0,0 @@
import sys
class Logger:
    """File-like stdout replacement that diverts marked lines to a log file.

    Lines containing ``filter`` are written to the log file; everything else
    goes to the stream that was ``sys.stdout`` when the logger was created.
    With ``filter=None`` nothing is diverted.
    """

    def __init__(self, filename, filter=None):
        """Open ``filename`` for writing (truncating it) and remember stdout."""
        self.terminal = sys.stdout
        self.log = open(filename, "w")
        self.filter = filter

    def write(self, message):
        """Route each line of ``message`` to the log file or the terminal.

        Each line is written exactly once, with its original line ending, so
        multi-line messages are not duplicated and a bare newline is not
        doubled.
        """
        for line in message.splitlines(keepends=True):
            if self.filter is not None and self.filter in line:
                self.log.write(line)
            else:
                self.terminal.write(line)

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        """Report that this stream is not a TTY."""
        return False
def logger_test(x):
    """Print one line with the ``[LOG]`` marker and one without; return ``x``."""
    for line in (
        "[LOG] This is a test",
        "This is another test, without the filter",
    ):
        print(line)
    return x
def read_sd_logs():
    """Flush stdout, then return the full contents of ``shark_tmp/sd.log``."""
    sys.stdout.flush()
    log_file = open("shark_tmp/sd.log", "r")
    try:
        return log_file.read()
    finally:
        log_file.close()
# Import-time side effect: replace sys.stdout for the whole process, so that
# output containing "[LOG]" is written to shark_tmp/sd.log (read back by
# read_sd_logs) and everything else goes to the original stdout. Opening the
# file here means the shark_tmp directory must already exist at import time.
sys.stdout = Logger("shark_tmp/sd.log", filter="[LOG]")

View File

@@ -1,205 +0,0 @@
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
clean_device_info,
get_iree_target_triple,
)
from apps.shark_studio.web.utils.file_utils import (
get_checkpoints_path,
get_resource_path,
)
from apps.shark_studio.modules.shared_cmd_opts import (
cmd_opts,
)
from iree import runtime as ireert
from pathlib import Path
import gc
import os
class SharkPipelineBase:
    """Lightweight base for managing an inference API class.

    It should provide methods for:
    - compiling a set (model map) of torch IR modules
    - preparing weights for an inference job
    - loading weights for an inference job
    - utilities like benchmarks, tests
    """

    def __init__(
        self,
        model_map: dict,
        base_model_id: str,
        static_kwargs: dict,
        device: str,
        import_mlir: bool = True,
    ):
        """Record the pipeline configuration and make sure the temp dir exists.

        Args:
            model_map: Per-submodel settings; the code reads the optional
                ``"ireec_flags"`` and the ``"initializer"`` entries.
            base_model_id: Identifier of the base model.
            static_kwargs: Per-submodel initializer kwargs, plus a ``"pipe"``
                entry whose values are merged into each submodel's kwargs.
            device: Device string, parsed by ``clean_device_info``.
            import_mlir: Stored on the instance; not read by this class.
        """
        self.model_map = model_map
        self.pipe_map = {}
        self.static_kwargs = static_kwargs
        self.base_model_id = base_model_id
        self.triple = get_iree_target_triple(device)
        self.device, self.device_id = clean_device_info(device)
        self.import_mlir = import_mlir
        self.iree_module_dict = {}
        self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
        # makedirs copes with nested paths and with the directory appearing
        # between an existence check and the creation.
        os.makedirs(self.tmp_dir, exist_ok=True)
        self.tempfiles = {}
        self.pipe_vmfb_path = ""

    def get_compiled_map(self, pipe_id, submodel="None", init_kwargs=None) -> None:
        """Make sure a compiled artifact exists for one or all submodels.

        First checks whether we have .vmfbs precompiled, then populates the map
        with the precompiled executables and fetches executables for the rest
        of the map. The weights aren't static here anymore so this function
        should be a part of pipeline initialization. As soon as you have a
        pipeline ID unique to your static torch IR parameters, and your model
        map is populated with any IR - unique model IDs and their static
        params, call this method to get the artifacts associated with your map.

        Args:
            pipe_id: Pipeline identifier; sanitised with ``safe_name`` and used
                as the artifact directory name.
            submodel: Key of ``model_map``, or the string ``"None"`` to process
                every submodel.
            init_kwargs: Initializer kwargs used when the submodel has no entry
                in ``static_kwargs``. Defaults to a fresh empty dict.
        """
        if init_kwargs is None:
            # A shared `{}` default would be mutated by the merge below.
            init_kwargs = {}

        self.pipe_id = self.safe_name(pipe_id)
        self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
        self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)

        if submodel == "None":
            print("\n[LOG] Gathering any pre-compiled artifacts....")
            for key in self.model_map:
                self.get_compiled_map(pipe_id, submodel=key)
            return

        self.pipe_map[submodel] = {}
        self.get_precompiled(self.pipe_id, submodel)

        if submodel in self.iree_module_dict:
            # Already loaded.
            return
        if "vmfb_path" in self.pipe_map[submodel]:
            # A precompiled .vmfb was found on disk; loading happens later.
            return

        if submodel not in self.tempfiles:
            print(
                f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..."
            )
            if submodel in self.static_kwargs:
                # Intentional alias: the merge below writes the pipe-level
                # defaults into static_kwargs[submodel] itself.
                init_kwargs = self.static_kwargs[submodel]
            for key in self.static_kwargs["pipe"]:
                if key not in init_kwargs:
                    init_kwargs[key] = self.static_kwargs["pipe"][key]
            self.import_torch_ir(submodel, init_kwargs)
            # The tempfile now exists; re-enter to take the compile path.
            self.get_compiled_map(pipe_id, submodel)
            return

        # Copy the configured flags so appending does not grow the list stored
        # in model_map on every compile.
        ireec_flags = list(
            self.model_map[submodel]["ireec_flags"]
            if "ireec_flags" in self.model_map[submodel]
            else []
        )
        weights_path = self.get_io_params(submodel)
        if weights_path:
            ireec_flags.append("--iree-opt-const-eval=False")

        self.iree_module_dict[submodel] = get_iree_compiled_module(
            self.tempfiles[submodel],
            device=self.device,
            frontend="torch",
            mmap=True,
            external_weight_file=weights_path,
            extra_args=ireec_flags,
            write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
        )
        return

    def get_io_params(self, submodel):
        """Return the external weights location for ``submodel``, if any.

        ``external_weight_file`` (custom weights) takes precedence over
        ``external_weight_path`` (default weights for the HF model). ``None``
        means the torch IR is assumed to contain the weights.
        """
        submodel_kwargs = self.static_kwargs[submodel]
        if "external_weight_file" in submodel_kwargs:
            # we are using custom weights
            return submodel_kwargs["external_weight_file"]
        if "external_weight_path" in submodel_kwargs:
            # we are using the default weights for the HF model
            return submodel_kwargs["external_weight_path"]
        # assume the torch IR contains the weights.
        return None

    def get_precompiled(self, pipe_id, submodel="None"):
        """Record any artifact on disk whose file name contains ``submodel``.

        Only the top level of ``self.pipe_vmfb_path`` is inspected. With
        ``submodel == "None"`` every key of ``model_map`` is checked.
        """
        if submodel == "None":
            for model in self.model_map:
                self.get_precompiled(pipe_id, model)
            # "None" is a sentinel, not a submodel name: nothing to match.
            return

        # First entry yielded by os.walk describes the top-level directory.
        _, _, vmfbs = next(os.walk(self.pipe_vmfb_path), (None, [], []))
        for file in vmfbs:
            if submodel in file:
                self.pipe_map.setdefault(submodel, {})["vmfb_path"] = os.path.join(
                    self.pipe_vmfb_path, file
                )
        return

    def import_torch_ir(self, submodel, kwargs):
        """Generate torch IR for ``submodel`` and write it to a temp file.

        The submodel's ``"initializer"`` is called with the flattened kwargs
        and ``compile_to="torch"``. The resulting path is stored in
        ``self.tempfiles``.
        """
        torch_ir = self.model_map[submodel]["initializer"](
            **self.safe_dict(kwargs), compile_to="torch"
        )
        if submodel == "clip":
            # clip.export_clip_model returns (torch_ir, tokenizer)
            torch_ir = torch_ir[0]

        self.tempfiles[submodel] = os.path.join(
            self.tmp_dir, f"{submodel}.torch.tempfile"
        )
        with open(self.tempfiles[submodel], "w+") as f:
            f.write(torch_ir)
        # Release the IR text before compilation starts.
        del torch_ir
        gc.collect()
        return

    def load_submodels(self, submodels: list):
        """Load, or compile when no artifact exists, each listed submodel."""
        for submodel in submodels:
            if submodel in self.iree_module_dict:
                print(f"\n[LOG] {submodel} is ready for inference.")
                continue
            if "vmfb_path" in self.pipe_map[submodel]:
                weights_path = self.get_io_params(submodel)
                self.iree_module_dict[submodel] = {}
                # NOTE(review): device_idx is fixed at 0 although
                # self.device_id is available - confirm this is intended.
                (
                    self.iree_module_dict[submodel]["vmfb"],
                    self.iree_module_dict[submodel]["config"],
                    self.iree_module_dict[submodel]["temp_file_to_unlink"],
                ) = load_vmfb_using_mmap(
                    self.pipe_map[submodel]["vmfb_path"],
                    self.device,
                    device_idx=0,
                    rt_flags=[],
                    external_weight_file=weights_path,
                )
            else:
                self.get_compiled_map(self.pipe_id, submodel)
        return

    def unload_submodels(self, submodels: list):
        """Drop the listed submodels from memory, collecting after each."""
        for submodel in submodels:
            if submodel in self.iree_module_dict:
                del self.iree_module_dict[submodel]
                gc.collect()
        return

    def run(self, submodel, inputs):
        """Run ``submodel``'s ``main`` entry point on ``inputs``.

        A single non-list input is wrapped in a list; each input is converted
        to a device array first.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]
        module = self.iree_module_dict[submodel]
        device_inputs = [
            ireert.asdevicearray(module["config"].device, value) for value in inputs
        ]
        return module["vmfb"]["main"](*device_inputs)

    def safe_name(self, name):
        """Replace ``/``, ``-`` and ``\\`` in ``name`` with underscores."""
        return name.replace("/", "_").replace("-", "_").replace("\\", "_")

    def safe_dict(self, kwargs: dict):
        """Flatten nested dict values into lists of their values.

        A dict value is left untouched when it contains the key
        ``"pass_dict"``; non-dict values are copied through as-is.
        """
        flat_args = {}
        for name, value in kwargs.items():
            if isinstance(value, dict) and "pass_dict" not in value:
                flat_args[name] = list(value.values())
            else:
                flat_args[name] = value
        return flat_args

View File

@@ -1,376 +0,0 @@
from typing import List, Optional, Union
from iree import runtime as ireert
import re
import torch
import numpy as np
# Tokenizer for the prompt-attention syntax handled by parse_prompt_attention.
# Compiled in verbose mode (re.X), so whitespace inside the pattern is ignored.
# The alternatives are tried in order: backslash-escaped brackets and an
# escaped backslash, a lone backslash, opening "(" / "[", an explicit weight
# suffix ":<number>)" (the number is capture group 1), closing ")" / "]", a run
# of ordinary text, and finally a lone ":".
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
    r"""
    Split a prompt with attention markup into ``[text, weight]`` pairs.

    Recognised markup:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    Brackets left open at the end of the prompt are treated as if they were
    closed there. Neighbouring pieces with equal weight are merged.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    round_factor = 1.1
    square_factor = 1 / 1.1

    segments = []
    # Index into `segments` at which each still-open bracket started.
    open_round = []
    open_square = []

    def scale_from(start, factor):
        """Multiply the weight of every segment from ``start`` onwards."""
        for segment in segments[start:]:
            segment[1] *= factor

    for match in re_attention.finditer(text):
        token = match.group(0)
        explicit_weight = match.group(1)

        if token.startswith("\\"):
            # Escaped character: keep it literally, without the backslash.
            segments.append([token[1:], 1.0])
        elif token == "(":
            open_round.append(len(segments))
        elif token == "[":
            open_square.append(len(segments))
        elif explicit_weight is not None and len(open_round) > 0:
            scale_from(open_round.pop(), float(explicit_weight))
        elif token == ")" and len(open_round) > 0:
            scale_from(open_round.pop(), round_factor)
        elif token == "]" and len(open_square) > 0:
            scale_from(open_square.pop(), square_factor)
        else:
            segments.append([token, 1.0])

    # Close whatever is still open, in the order it was opened.
    for start in open_round:
        scale_from(start, round_factor)
    for start in open_square:
        scale_from(start, square_factor)

    if len(segments) == 0:
        segments = [["", 1.0]]

    # merge runs of identical weights
    idx = 0
    while idx + 1 < len(segments):
        current, following = segments[idx], segments[idx + 1]
        if current[1] == following[1]:
            current[0] += following[0]
            segments.pop(idx + 1)
        else:
            idx += 1

    return segments
def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
    r"""
    Tokenize each prompt and pair every token with its attention weight.

    Each prompt is split with ``parse_prompt_attention``; every piece is
    tokenized with ``pipe.tokenizer`` and its first and last ids are dropped,
    so the result contains no starting, ending or padding tokens. Prompts
    longer than ``max_length`` tokens are cut to that length and a notice is
    printed.

    Returns ``(tokens, weights)``: two lists with one entry per prompt.
    """
    all_tokens = []
    all_weights = []
    was_truncated = False

    for entry in prompt:
        entry_tokens = []
        entry_weights = []
        for fragment, fragment_weight in parse_prompt_attention(entry):
            # tokenize and discard the starting and the ending token
            fragment_tokens = pipe.tokenizer(fragment).input_ids[1:-1]
            entry_tokens += fragment_tokens
            # one copy of the weight per token of this fragment
            entry_weights += [fragment_weight] * len(fragment_tokens)

            # no point tokenizing further once the limit is exceeded
            if len(entry_tokens) > max_length:
                break

        # truncate
        if len(entry_tokens) > max_length:
            was_truncated = True
            entry_tokens = entry_tokens[:max_length]
            entry_weights = entry_weights[:max_length]

        all_tokens.append(entry_tokens)
        all_weights.append(entry_weights)

    if was_truncated:
        print(
            "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
        )
    return all_tokens, all_weights
def pad_tokens_and_weights(
    tokens,
    weights,
    max_length,
    bos,
    eos,
    no_boseos_middle=True,
    chunk_length=77,
):
    r"""
    Pad every token list to ``max_length`` and every weight list to match.

    Tokens get ``bos`` in front and are filled with ``eos`` at the end. Weights
    are padded with 1.0. With ``no_boseos_middle`` the weights simply mirror
    the tokens; otherwise they are laid out chunk by chunk with a 1.0 on both
    sides of every ``chunk_length - 2`` slice. Both input lists are updated in
    place and also returned.
    """
    body_length = chunk_length - 2
    chunk_count = (max_length - 2) // body_length
    weights_length = max_length if no_boseos_middle else chunk_count * chunk_length

    for idx, row_tokens in enumerate(tokens):
        row_weights = weights[idx]
        tokens[idx] = [bos] + row_tokens + [eos] * (max_length - 1 - len(row_tokens))

        if no_boseos_middle:
            fill = max_length - 1 - len(row_weights)
            weights[idx] = [1.0] + row_weights + [1.0] * fill
            continue

        if len(row_weights) == 0:
            weights[idx] = [1.0] * weights_length
            continue

        padded = []
        for chunk in range(chunk_count):
            start = chunk * body_length
            # 1.0 for this chunk's starting token, the chunk's own weights,
            # then 1.0 for its ending token
            padded.append(1.0)
            padded.extend(row_weights[start : start + body_length])
            padded.append(1.0)
        padded.extend([1.0] * (weights_length - len(padded)))
        weights[idx] = padded

    return tokens, weights
def get_unweighted_text_embeddings(
    pipe,
    text_input,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    Encode token ids with the ``"clip"`` submodel, chunking long inputs.

    When the length of tokens is a multiple of the capacity of the text
    encoder, it should be split into chunks and sent to the text encoder
    individually.

    Args:
        pipe: Object whose ``run("clip", ids)`` returns a sequence; element 0
            must provide ``to_host()`` returning a NumPy array.
        text_input: 2-D tensor of token ids; axis 1 is the token position.
        chunk_length: Encoder capacity in tokens, including the starting and
            ending tokens.
        no_boseos_middle: When true, the embeddings of the starting/ending
            tokens are dropped at every internal chunk boundary.

    Returns:
        A torch tensor with the per-chunk embeddings joined along axis 1.
    """
    body_length = chunk_length - 2
    max_embeddings_multiples = (text_input.shape[1] - 2) // body_length

    if max_embeddings_multiples <= 1:
        text_embeddings = pipe.run("clip", text_input)[0]
        return torch.from_numpy(text_embeddings.to_host())

    text_embeddings = []
    for i in range(max_embeddings_multiples):
        # extract the i-th chunk
        text_input_chunk = text_input[
            :, i * body_length : (i + 1) * body_length + 2
        ].clone()

        # cover the head and the tail by the starting and the ending tokens
        text_input_chunk[:, 0] = text_input[0, 0]
        text_input_chunk[:, -1] = text_input[0, -1]

        text_embedding = pipe.run("clip", text_input_chunk)[0].to_host()

        if no_boseos_middle:
            if i == 0:
                # discard the ending token
                text_embedding = text_embedding[:, :-1]
            elif i == max_embeddings_multiples - 1:
                # discard the starting token
                text_embedding = text_embedding[:, 1:]
            else:
                # discard both starting and ending tokens
                text_embedding = text_embedding[:, 1:-1]

        text_embeddings.append(text_embedding)

    # SHARK: Convert the result to tensor
    # Join the chunks along the token axis. Stacking them with np.array first
    # would join along the batch axis instead, and fails outright when the
    # middle chunks are shorter than the first and last ones.
    text_embeddings_np = np.concatenate(text_embeddings, axis=1)
    return torch.from_numpy(text_embeddings_np)
# This function deals with NoneType values occurring in tokens after padding.
# It switches out None with 49407, as truncating None values causes matrix
# dimension errors.
def filter_nonetype_tokens(tokens: List[List], pad_token: int = 49407):
    """Replace every ``None`` entry in ``tokens`` with ``pad_token``.

    Args:
        tokens: One list of token ids per prompt.
        pad_token: Id substituted for ``None``; defaults to 49407.

    Returns:
        A new list of lists with the same shape as ``tokens``. Every row is
        processed, so prompts after the first are no longer dropped.
    """
    return [
        [pad_token if token is None else token for token in row] for row in tokens
    ]
def get_weighted_text_embeddings(
pipe,
prompt: List[str],
uncond_prompt: List[str] = None,
max_embeddings_multiples: Optional[int] = 8,
no_boseos_middle: Optional[bool] = True,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
# Build text embeddings for `prompt` (and `uncond_prompt`, when given),
# optionally applying the attention weights parsed from the prompt markup.
#
# pipe:                     provides tokenizer, model_max_length and the
#                           "clip" submodel used by the helpers called below
# prompt / uncond_prompt:   lists of prompt strings
# max_embeddings_multiples: upper bound on how many encoder-sized chunks a
#                           prompt may span
# no_boseos_middle:         forwarded to the padding and embedding helpers
# skip_parsing:             tokenize prompts directly, with every weight 1.0
# skip_weighting:           parse the prompts but do not apply the weights
#
# Returns (text_embeddings, uncond_embeddings); the second element is None
# when no uncond_prompt was given.
#
# Token budget across all chunks: each chunk contributes model_max_length - 2
# tokens, plus one starting and one ending token overall.
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
# No markup parsing: tokenize directly and strip the first and last id of
# each prompt; every token gets weight 1.0.
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
# Use only as many chunks as the longest prompt needs, but at least one.
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
prompt_tokens = filter_nonetype_tokens(prompt_tokens)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
uncond_tokens = filter_nonetype_tokens(uncond_tokens)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu")
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu")
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu")
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
# Scale each token's embedding by its weight, then rescale so the mean over
# the last two axes matches its value from before the weighting.
previous_mean = (
text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None

View File

@@ -1,118 +0,0 @@
# from shark_turbine.turbine_models.schedulers import export_scheduler_model
from diffusers import (
LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
def get_schedulers(model_id):
    """Build the enabled diffusers schedulers for ``model_id``.

    Each scheduler is loaded with ``from_pretrained(model_id,
    subfolder="scheduler")``.

    Args:
        model_id: Identifier handed unchanged to ``from_pretrained``.

    Returns:
        dict mapping scheduler display name -> scheduler instance, in the
        order PNDM, EulerDiscrete, EulerAncestralDiscrete.
    """
    # TODO: switch over to turbine and run all on GPU
    print(f"\n[LOG] Initializing schedulers from model id: {model_id}")

    enabled_schedulers = (
        ("PNDM", PNDMScheduler),
        ("EulerDiscrete", EulerDiscreteScheduler),
        ("EulerAncestralDiscrete", EulerAncestralDiscreteScheduler),
    )
    # Currently disabled entries (display name -> class, extra kwargs):
    #   "DDPM"                       -> DDPMScheduler
    #   "KDPM2Discrete"              -> KDPM2DiscreteScheduler
    #   "LMSDiscrete"                -> LMSDiscreteScheduler
    #   "DDIM"                       -> DDIMScheduler
    #   "LCMScheduler"               -> LCMScheduler
    #   "DPMSolverMultistep"         -> DPMSolverMultistepScheduler,
    #                                   algorithm_type="dpmsolver"
    #   "DPMSolverMultistep++"       -> DPMSolverMultistepScheduler,
    #                                   algorithm_type="dpmsolver++"
    #   "DPMSolverMultistepKarras"   -> DPMSolverMultistepScheduler,
    #                                   use_karras_sigmas=True
    #   "DPMSolverMultistepKarras++" -> DPMSolverMultistepScheduler,
    #                                   algorithm_type="dpmsolver++",
    #                                   use_karras_sigmas=True
    #   "DEISMultistep"              -> DEISMultistepScheduler
    #   "DPMSolverSinglestep"        -> DPMSolverSinglestepScheduler
    #   "KDPM2AncestralDiscrete"     -> KDPM2AncestralDiscreteScheduler
    #   "HeunDiscrete"               -> HeunDiscreteScheduler
    # To re-enable one, add it to ``enabled_schedulers`` (entries with extra
    # kwargs need those passed to ``from_pretrained`` as well).
    return {
        display_name: scheduler_cls.from_pretrained(
            model_id,
            subfolder="scheduler",
        )
        for display_name, scheduler_cls in enabled_schedulers
    }
def export_scheduler_model(model):
    """Stub standing in for the (commented-out) turbine scheduler exporter.

    ``model`` is accepted but not used; the result is always a pair of the
    literal string "None".
    """
    placeholder = "None"
    return placeholder, placeholder
# Scheduler display name -> result of export_scheduler_model for the
# corresponding diffusers scheduler class name.
#
# Disabled entries (display name -> class name):
#   "DPMSolverSDE"           -> "DpmSolverSDEScheduler"
#   "LCM"                    -> "LCMScheduler"
#   "LMSDiscrete"            -> "LMSDiscreteScheduler"
#   "DDPM"                   -> "DDPMScheduler"
#   "DDIM"                   -> "DDIMScheduler"
#   "DPMSolverMultistep"     -> "DPMSolverMultistepScheduler"
#   "KDPM2Discrete"          -> "KDPM2DiscreteScheduler"
#   "DEISMultistep"          -> "DEISMultistepScheduler"
#   "DPMSolverSinglestep"    -> "DPMSolverSingleStepScheduler"
#   "KDPM2AncestralDiscrete" -> "KDPM2AncestralDiscreteScheduler"
#   "HeunDiscrete"           -> "HeunDiscreteScheduler"
scheduler_model_map = {
    display_name: export_scheduler_model(class_name)
    for display_name, class_name in (
        ("PNDM", "PNDMScheduler"),
        ("EulerDiscrete", "EulerDiscreteScheduler"),
        ("EulerAncestralDiscrete", "EulerAncestralDiscreteScheduler"),
    )
}

View File

@@ -1,66 +0,0 @@
import numpy as np
import json
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
if isinstance(seed_input, int):
return [seed_input]
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)
def batch_seeds(seed_input: str | list | int, batch_count: int, repeatable=False):
    """Produce ``batch_count`` seeds from a seed expression.

    The input is parsed with ``parse_seed_input``, truncated or padded with
    -1 to ``batch_count`` entries, and every entry is passed through
    ``sanitize_seed`` (which replaces unsupported values such as -1 with a
    random seed).

    When ``repeatable`` is true, the module-level ``random`` generator is
    temporarily seeded from the explicitly given (non-negative) seeds so the
    generated ones are reproducible. If no seed was given at all, the first
    one is generated up front and then used for that purpose. The previous
    generator state is restored afterwards, even if seed generation raises.

    Args:
        seed_input: int, list of ints, or a JSON string of either.
        batch_count: number of seeds to return.
        repeatable: derive generated seeds deterministically from the given
            ones.

    Returns:
        list of integer seeds.

    Raises:
        TypeError: propagated from ``parse_seed_input`` for invalid input.
    """
    # turn the input into a list if possible
    seeds = parse_seed_input(seed_input)

    # slice or pad the list to be of batch_count length
    seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))

    if not repeatable:
        return [sanitize_seed(seed) for seed in seeds]

    # ``seeds`` may be empty (batch_count == 0); guard before indexing.
    if seeds and all(seed < 0 for seed in seeds):
        seeds[0] = sanitize_seed(seeds[0])

    # set seed for the rng based on what we have so far
    saved_random_state = random_getstate()
    try:
        seed_random(str([n for n in seeds if n > -1]))
        # generate any seeds that are unspecified
        return [sanitize_seed(seed) for seed in seeds]
    finally:
        # reset the rng back to normal
        random_setstate(saved_random_state)

View File

@@ -1,791 +0,0 @@
import argparse
import os
from pathlib import Path
from apps.shark_studio.modules.img_processing import resampler_list
def path_expand(s):
    """Return ``s`` as a ``Path`` with ``~`` expanded and then resolved."""
    expanded = Path(s).expanduser()
    return expanded.resolve()
def is_valid_file(arg):
    """Return ``arg`` unchanged if that path exists, otherwise ``None``."""
    return arg if os.path.exists(arg) else None
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
# Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
)
p.add_argument(
"-p",
"--prompt",
nargs="+",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smoke coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
],
help="Text of which images to be generated.",
)
p.add_argument(
"--negative_prompt",
nargs="+",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
"blurry, ugly, blur, oversaturated, cropped"
],
help="Text you don't want to see in the generated image.",
)
p.add_argument(
"--sd_init_image",
type=str,
help="Path to the image input for img2img/inpainting.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="The number of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=str,
default=-1,
help="The seed or list of seeds to use. -1 for a random one.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="The number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 1025, 8),
help="The height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 1025, 8),
help="The width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="The value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="The value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="Max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max "
"output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="The strength of change applied on the given input image for " "img2img.",
)
# Declared as an on/off flag (--use_hiresfix / --no-use_hiresfix) like the
# other boolean options in this parser. With ``type=bool`` argparse would
# call bool() on the raw string, so "--use_hiresfix False" evaluated to True.
p.add_argument(
    "--use_hiresfix",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Use Hires Fix to do higher resolution images, while trying to "
    "avoid the issues that come with it. This is accomplished by first "
    "generating an image using txt2img, then running it through img2img.",
)
p.add_argument(
"--hiresfix_height",
type=int,
default=768,
choices=range(128, 769, 8),
help="The height of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_width",
type=int,
default=768,
choices=range(128, 769, 8),
help="The width of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_strength",
type=float,
default=0.6,
help="The denoising strength to apply for the Hires Fix.",
)
p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=resampler_list,
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
##############################################################################
# Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model.",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt.",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The number of steps to train.",
)
##############################################################################
# Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting.",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture.",
)
p.add_argument(
"--inpaint_full_res_padding",
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding.",
)
p.add_argument(
"--pixels",
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting.",
)
p.add_argument(
"--mask_blur",
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting.",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend right for outpainting.",
)
p.add_argument(
"--up",
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend top for outpainting.",
)
p.add_argument(
"--down",
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend bottom for outpainting.",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) "
"(min=0.0, max=4.0).",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0).",
)
##############################################################################
# Model Config and Usage Params
##############################################################################
p.add_argument("--device", type=str, default="vulkan", help="Device to run the model.")
p.add_argument(
"--precision", type=str, default="fp16", help="Precision to run the model."
)
p.add_argument(
"--import_mlir",
default=True,
action=argparse.BooleanOptionalAction,
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
)
p.add_argument(
"--use_tuned",
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available.",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="DDIM",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
"HeunDiscrete].",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="Specify the format in which output image is save. "
"Supported options: jpg / png.",
)
p.add_argument(
"--output_dir",
type=str,
default=os.path.join(os.getcwd(), "generated_imgs"),
help="Directory path to save the output images and json.",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="Number of batches to be generated with random seeds in " "single execution.",
)
p.add_argument(
"--repeatable_seeds",
default=False,
action=argparse.BooleanOptionalAction,
help="The seed of the first batch will be used as the rng seed to "
"generate the subsequent seeds for subsequent batches in that run.",
)
p.add_argument(
"--custom_weights",
type=str,
default="",
help="Path to a .safetensors or .ckpt file for SD pipeline weights.",
)
p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
"needs to be plugged in.",
)
p.add_argument(
"--base_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption.",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
"or an integer).",
)
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble", "zoedepth"],
help="Enable the stencil feature.",
)
p.add_argument(
"--control_mode",
choices=["Prompt", "Balanced", "Controlnet"],
default="Balanced",
help="How Controlnet injection should be prioritized.",
)
p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint " "file (~3 MB).",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="Runs the quantized version of stable diffusion model. "
"This is currently in experimental phase. "
"Currently, only runs the stable-diffusion-2-1-base model in "
"int8 quantization.",
)
p.add_argument(
"--lowvram",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM.",
)
p.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
p.add_argument(
"--external_weights",
type=str,
default=None,
help="What type of externalized weights to use. Currently options are 'safetensors' and defaults to inlined weights.",
)
# The help text is built from adjacent string literals; each piece ends with
# a separator so the sentences do not run together in --help output.
p.add_argument(
    "--device_allocator_heap_key",
    type=str,
    default="",
    help="Specify heap key for device caching allocator. "
    "Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count. "
    "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan.",
)
p.add_argument(
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal.",
)
##############################################################################
# Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="Use the default scheduler precompiled into the model if available.",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. "
"If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. " "Use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help="Dispatches to return benchmark data on. "
'Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help="Directory where you want to store dispatch data "
'generated with "--dispatch_benchmarks".',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for inserting debug frames between iterations " "for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="Flag setting warmup count for CLIP and VAE [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to clear all mlir and vmfb from common locations. "
"Recompiling will take several minutes.",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save a generation information "
"json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save generation information in "
"PNG chunk text to generated images.",
)
# The help text no longer claims that import_mlir defaults to false:
# --import_mlir is declared with default=True in this parser.
p.add_argument(
    "--import_debug",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="If import_mlir is True, saves mlir via the debug option "
    "in shark importer. Does nothing if import_mlir is false.",
)
p.add_argument(
    "--compile_debug",
    default=False,
    action=argparse.BooleanOptionalAction,
    # Trailing space keeps the concatenated literals from reading
    # "theiree-compiler" in --help output.
    help="Flag to toggle debug assert/verify flags for imported IR in the "
    "iree-compiler. Default to false.",
)
p.add_argument(
"--iree_constant_folding",
default=True,
action=argparse.BooleanOptionalAction,
help="Controls constant folding in iree-compile for all SD models.",
)
p.add_argument(
"--data_tiling",
default=False,
action=argparse.BooleanOptionalAction,
help="Controls data tiling in iree-compile for all SD models.",
)
p.add_argument(
"--quantization",
type=str,
default="None",
help="Quantization to be used for api-exposed model.",
)
##############################################################################
# Web UI flags
##############################################################################
p.add_argument(
"--webui",
default=True,
action=argparse.BooleanOptionalAction,
help="controls whether the webui is launched.",
)
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the progress bar animation during " "image generation.",
)
p.add_argument(
"--tmp_dir",
type=str,
default=os.path.join(os.getcwd(), "shark_tmp"),
help="Path to tmp directory",
)
p.add_argument(
"--config_dir",
type=str,
default=os.path.join(os.getcwd(), "configs"),
help="Path to config directory",
)
p.add_argument(
"--model_dir",
type=str,
default=os.path.join(os.getcwd(), "models"),
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="One of: [api, app, web].",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for generating a public URL.",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="Flag for setting server port.",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling rest API.",
)
# action="append": every occurrence of the option adds one origin to a list;
# with no occurrence the value stays None.
p.add_argument(
    "--api_accept_origin",
    action="append",
    type=str,
    help="An origin to be accepted by the REST api for Cross Origin "
    "Resource Sharing (CORS). Use multiple times for multiple origins, "
    'or use --api_accept_origin="*" to accept all origins. If no origins '
    "are set no CORS headers will be returned by the api. Use, for "
    "instance, if you need to access the REST api from Javascript running "
    "in a web browser.",
)
p.add_argument(
"--debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling debugging log in WebUI.",
)
p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the output gallery tab, and avoid exposing "
"images under --output_dir in the UI.",
)
p.add_argument(
"--configs_path",
default=None,
type=str,
help="Path to .json config directory.",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether the output gallery tab in the UI should "
"follow symlinks when listing subdirectories under --output_dir.",
)
p.add_argument(
"--api_log",
default=False,
action=argparse.BooleanOptionalAction,
help="Enables Compatibility API logging.",
)
##############################################################################
# SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file.",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file.",
)
##############################################################################
# SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file.",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning.",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all.",
)
##############################################################################
# DocuChat Flags
##############################################################################
p.add_argument(
"--run_docuchat_web",
default=False,
action=argparse.BooleanOptionalAction,
help="Specifies whether the docuchat's web version is running or not.",
)
##############################################################################
# rocm Flags
##############################################################################
p.add_argument(
    "--iree_rocm_target_chip",
    type=str,
    default="",
    # `hipinfo` was listed twice in the original help text.
    help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
    "or `iree-run-module --dump_devices=rocm` to get desired arch name",
)
# Parse at import time. Unrecognised arguments are collected in ``unknown``
# instead of aborting, so other parsers can consume them.
cmd_opts, unknown = p.parse_known_args()

if cmd_opts.import_debug:
    # Set IREE_SAVE_TEMPS to a per-model directory under the current working
    # directory. The model id comes from --base_model_id; this parser defines
    # no ``hf_model_id`` option, so reading that attribute raised
    # AttributeError whenever --import_debug was passed.
    os.environ["IREE_SAVE_TEMPS"] = os.path.join(
        os.getcwd(), cmd_opts.base_model_id.replace("/", "_")
    )

View File

@@ -1,106 +0,0 @@
import time
import argparse
class TimerSubcategory:
    """Context manager that times one named sub-step of a ``Timer``.

    While active, the owning timer's ``base_category`` is extended with
    ``"<category>/"`` and its ``subcategory_level`` is raised by one. On exit
    both are restored and the elapsed wall-clock time is added to the
    timer's records under ``<original base><category>``.
    """

    def __init__(self, timer, category):
        self.timer = timer
        self.category = category
        self.start = None
        # Captured at construction so it can be restored on exit.
        self.original_base_category = timer.base_category

    def __enter__(self):
        self.start = time.time()
        owner = self.timer
        owner.base_category = self.original_base_category + self.category + "/"
        owner.subcategory_level += 1
        if owner.print_log:
            print(f"{' ' * owner.subcategory_level}{self.category}:")

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed_for_subcategory = time.time() - self.start
        owner = self.timer
        qualified_name = self.original_base_category + self.category

        owner.base_category = self.original_base_category
        owner.add_time_to_record(qualified_name, elapsed_for_subcategory)
        owner.subcategory_level -= 1
        # Also record under the plain category name, without logging.
        owner.record(self.category, disable_log=True)
class Timer:
    """Accumulates wall-clock time per named category.

    Attributes:
        start: timestamp of the last measurement point.
        records: dict mapping category name -> accumulated seconds.
            Sub-category entries contain "/" in their name.
        total: sum of all times added through ``record``.
        base_category: prefix prepended to category names in ``record``.
        print_log: when true, ``record`` prints each measurement.
        subcategory_level: current nesting depth, used to indent log lines.
    """

    def __init__(self, print_log=False):
        self.start = time.time()
        self.records = {}
        self.total = 0
        self.base_category = ""
        self.print_log = print_log
        self.subcategory_level = 0

    def elapsed(self):
        """Return seconds since the last measurement point and move it to now."""
        end = time.time()
        res = end - self.start
        self.start = end
        return res

    def add_time_to_record(self, category, amount):
        """Add ``amount`` seconds to ``category``, creating the entry if needed."""
        if category not in self.records:
            self.records[category] = 0

        self.records[category] += amount

    def record(self, category, extra_time=0, disable_log=False):
        """Attribute the time since the last measurement point to ``category``.

        ``extra_time`` is added on top of the measured time. The amount is
        stored under ``base_category + category`` and added to ``total``.
        """
        e = self.elapsed()

        self.add_time_to_record(self.base_category + category, e + extra_time)

        self.total += e + extra_time

        if self.print_log and not disable_log:
            print(
                f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s"
            )

    def subcategory(self, name):
        """Return a ``TimerSubcategory`` context manager for ``name``.

        The measurement point is moved to now first (the elapsed value is
        discarded).
        """
        self.elapsed()

        subcat = TimerSubcategory(self, name)
        return subcat

    def summary(self):
        """Return a one-line summary: the total, then top-level categories.

        Only categories that took at least 0.1s and have no "/" in their
        name are listed.
        """
        res = f"{self.total:.1f}s"

        additions = [
            (category, time_taken)
            for category, time_taken in self.records.items()
            if time_taken >= 0.1 and "/" not in category
        ]
        if not additions:
            return res

        res += " ("
        res += ", ".join(
            [f"{category}: {time_taken:.1f}s" for category, time_taken in additions]
        )
        res += ")"

        return res

    def dump(self):
        """Return the total and the per-category records as a dict."""
        return {"total": self.total, "records": self.records}

    def reset(self):
        """Clear all measurements and restart the clock.

        The configured ``print_log`` setting is kept. Calling ``__init__()``
        without arguments, as before, silently switched logging off.
        """
        self.__init__(print_log=self.print_log)
# Small stand-alone parser for the one flag this module needs.
# add_help=False stops it from registering its own -h/--help option.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
    "--log-startup",
    action="store_true",
    help="print a detailed log of what's happening at startup",
)
# parse_known_args returns (namespace, leftover_args); only the namespace is
# kept, so arguments this parser does not know are ignored.
args = parser.parse_known_args()[0]

# Module-level timer created at import time; logs each record when
# --log-startup was given.
startup_timer = Timer(print_log=args.log_startup)

# Initialised to None here; nothing in this file assigns it afterwards.
startup_record = None

View File

@@ -1,48 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.shark_studio.studio_imports import pathex, datas, hiddenimports
# Build description for the `nodai_shark_studio` executable.
# NOTE(review): Analysis, PYZ and EXE are not imported in this file;
# presumably PyInstaller provides them when it executes the spec - confirm.

# No extra binaries are listed explicitly.
binaries = []

block_cipher = None

# Analysis starts from the web entry script; search paths, data files and
# hidden imports come from apps.shark_studio.studio_imports.
a = Analysis(
    ['web/index.py'],
    pathex=pathex,
    binaries=binaries,
    datas=datas,
    hiddenimports=hiddenimports,
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
    module_collection_mode={
        'gradio': 'py', # Collect gradio package as source .py files
    },
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

# The scripts, binaries, zipfiles and datas from the analysis are all passed
# to EXE directly; console=True and upx=False are set explicitly.
exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.zipfiles,
    a.datas,
    [],
    name='nodai_shark_studio',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=False,
    upx_exclude=[],
    runtime_tmpdir=None,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)

View File

@@ -1,68 +0,0 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# python path for pyinstaller
pathex = ["."]

# datafiles for pyinstaller
datas = []

# Distribution metadata, collected in this order.
datas += [
    entry
    for dist_name in (
        "torch",
        "tokenizers",
        "tqdm",
        "regex",
        "requests",
        "packaging",
        "filelock",
        "numpy",
        "importlib_metadata",
        "omegaconf",
        "safetensors",
        "Pillow",
        "sentencepiece",
        "pyyaml",
        "huggingface-hub",
        "gradio",
        "scipy",
    )
    for entry in copy_metadata(dist_name)
]

# Package data files as (package, include_py_files) pairs, in this order.
datas += [
    entry
    for package, with_py_files in (
        ("torch", False),
        ("tokenizers", False),
        ("accelerate", False),
        ("diffusers", False),
        ("transformers", False),
        ("gradio", False),
        ("gradio_client", False),
        ("iree", True),
        ("shark", True),
        ("tqdm", False),
        ("tkinter", False),
        ("sentencepiece", False),
        ("jsonschema", False),
        ("jsonschema_specifications", False),
        ("cpuinfo", False),
        ("scipy", True),
    )
    for entry in collect_data_files(package, include_py_files=with_py_files)
]

# Web UI assets: (source glob, destination directory).
datas += [
    ("web/ui/css/*", "ui/css"),
    ("web/ui/js/*", "ui/js"),
    ("web/ui/logos/*", "logos"),
]

# hidden imports for pyinstaller
hiddenimports = ["shark", "apps"]

# Substrings that exclude a transformers submodule from the hidden imports.
blacklist = ["tests", "convert"]

# Each package is paired with the substrings that disqualify a submodule.
hiddenimports += [
    module_name
    for package, banned in (
        ("gradio", ["tests"]),
        ("diffusers", ["tests"]),
        ("transformers", blacklist),
        ("iree", ["test"]),
    )
    for module_name in collect_submodules(package)
    if not any(kw in module_name for kw in banned)
]
hiddenimports += ["iree._runtime"]
hiddenimports += [
    module_name
    for module_name in collect_submodules("scipy")
    if "test" not in module_name
]

View File

@@ -1,58 +0,0 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import logging
import unittest
import json
import gc
from apps.shark_studio.api.llm import LanguageModel, llm_chat_api
from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file
from apps.shark_studio.web.utils.file_utils import get_resource_path
# class SDAPITest(unittest.TestCase):
# def testSDSimple(self):
# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
# import apps.shark_studio.web.utils.globals as global_obj
# global_obj._init()
# sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
# sd_kwargs = json.loads(sd_json)
# for arg in vars(cmd_opts):
# if arg in sd_kwargs:
# sd_kwargs[arg] = getattr(cmd_opts, arg)
# for i in shark_sd_fn_dict_input(sd_kwargs):
# print(i)
class LLMAPITest(unittest.TestCase):
    """Smoke test for the LanguageModel chat API using a small model."""

    def test01_LLMSmall(self):
        """The second streamed chat output must equal the expected label."""
        lm = LanguageModel(
            "TinyPixel/small-llama2",
            hf_auth_token=None,
            device="cpu",
            precision="fp32",
            quantization="None",
            streaming_llm=True,
        )
        label = "Turkishoure Turkish"
        for position, (msg, _) in enumerate(lm.chat("hi, what are you?")):
            # skip first token output
            if position == 0:
                continue
            assert (
                msg.strip(" ") == label
            ), f"LLM API failed to return correct response, expected '{label}', received {msg}"
            # Only the second output is checked.
            break
        # Drop the model and collect garbage before the test returns.
        del lm
        gc.collect()
# When run as a script: enable DEBUG-level logging, then let unittest
# discover and run the tests in this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()

View File

@@ -1,41 +0,0 @@
import torch
from diffusers import (
UNet2DConditionModel,
)
from torch.fx.experimental.proxy_tensor import make_fx
class UnetModel(torch.nn.Module):
    """Wraps a pretrained UNet and applies the guidance combination in forward."""

    def __init__(self, hf_model_name):
        """Load the UNet from the ``unet`` subfolder of ``hf_model_name``."""
        super().__init__()
        self.unet = UNet2DConditionModel.from_pretrained(
            hf_model_name,
            subfolder="unet",
        )

    def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
        """Run the UNet on ``sample`` duplicated along dim 0 and combine halves.

        The first half of the output is treated as the unconditional
        prediction and the second half as the text-conditioned one.

        Returns:
            ``uncond + guidance_scale * (text - uncond)``.
        """
        doubled = torch.cat([sample, sample])
        outputs = self.unet.forward(
            doubled, timestep, encoder_hidden_states, return_dict=False
        )
        noise_pred_uncond, noise_pred_text = outputs[0].chunk(2)
        guided_delta = noise_pred_text - noise_pred_uncond
        return noise_pred_uncond + guidance_scale * guided_delta
# Script entry point: trace UnetModel with make_fx and print the result.
if __name__ == "__main__":
    hf_model_name = "CompVis/stable-diffusion-v1-4"
    unet = UnetModel(hf_model_name)
    # Example inputs in forward() order:
    # (sample, timestep, encoder_hidden_states, guidance_scale).
    inputs = (torch.randn(1, 4, 64, 64), 1, torch.randn(2, 77, 768), 7.5)

    # Symbolic tracing with an empty decomposition table.
    fx_g = make_fx(
        unet,
        decomposition_table={},
        tracing_mode="symbolic",
        _allow_non_fake_inputs=True,
        _allow_fake_constant=False,
    )(*inputs)

    print(fx_g)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 347 KiB

View File

@@ -1,45 +0,0 @@
import requests
from PIL import Image
import base64
from io import BytesIO
import json
def llm_chat_test(verbose=False):
    """POST one chat-completion request to a local server and report the result.

    The HTTP status line is always printed. The content of the first choice
    is printed as well when ``verbose`` is true or the status is not 200.

    Args:
        verbose: also print the returned message on success.
    """
    # Define values here
    prompt = "What is the significance of the number 42?"
    endpoint = "http://127.0.0.1:8080/v1/chat/completions"

    request_headers = {
        "User-Agent": "PythonTest",
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
    }

    message = {
        "role": "",
        "content": prompt,
    }
    payload = {
        "model": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
        "messages": [message],
        "device": "vulkan://0",
        "max_tokens": 4096,
    }

    res = requests.post(
        url=endpoint, json=payload, headers=request_headers, timeout=1000
    )

    # The body is decoded before the status is reported.
    res_dict = json.loads(res.content.decode("utf-8"))

    print(f"[chat] response from server was : {res.status_code} {res.reason}")

    if not verbose and res.status_code == 200:
        return
    print(f"\n{res_dict['choices'][0]['message']['content']}\n")
if __name__ == "__main__":
    # Exercises the chatbot REST API of Shark. Make sure Shark is running
    # in API mode on 127.0.0.1:8080 before running this script.
    llm_chat_test(verbose=True)

View File

@@ -1,286 +0,0 @@
import base64
import io
import os
import time
import datetime
import uvicorn
import ipaddress
import requests
import threading
import collections
import gradio as gr
from PIL import Image, PngImagePlugin
from threading import Lock
from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
# from sdapi_v1 import shark_sd_api
from apps.shark_studio.api.llm import llm_chat_api
def decode_base64_to_image(encoding):
    """Turn an image reference into a PIL image.

    Args:
        encoding: An ``http://``/``https://`` URL, a ``data:image/...`` URI,
            or a bare base64 string.

    Returns:
        The image opened with ``Image.open``.

    Raises:
        HTTPException: Status 500 with detail "Invalid image url" when the URL
            cannot be fetched or its content cannot be opened, or with detail
            "Invalid encoded image" when the base64 payload cannot be decoded
            or opened.
    """
    if encoding.startswith(("http://", "https://")):
        # NOTE(review): this downloads a caller-supplied URL from the server
        # side; confirm the API is only reachable by trusted clients.
        try:
            # The download is inside the try so that network failures are
            # reported the same way as undecodable content.
            response = requests.get(encoding, timeout=30, headers={})
            return Image.open(BytesIO(response.content))
        except Exception as e:
            raise HTTPException(status_code=500, detail="Invalid image url") from e

    try:
        if encoding.startswith("data:image/"):
            # The payload is everything after the first comma, however many
            # ";"-separated parameters precede it. A URI without a comma
            # raises here and is reported as an invalid image.
            encoding = encoding.split(",", 1)[1]
        return Image.open(BytesIO(base64.b64decode(encoding)))
    except Exception as e:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from e
def encode_pil_to_base64(image):
    """Save ``image`` as PNG and return the PNG bytes base64-encoded.

    Every ``image.info`` entry whose key and value are both ``str`` is written
    as a PNG text chunk; if there is no such entry, no ``pnginfo`` is passed.
    """
    with io.BytesIO() as buffer:
        text_chunks = PngImagePlugin.PngInfo()
        has_text = False
        for key, value in image.info.items():
            if not (isinstance(key, str) and isinstance(value, str)):
                continue
            text_chunks.add_text(key, value)
            has_text = True

        pnginfo = text_chunks if has_text else None
        image.save(buffer, format="PNG", pnginfo=pnginfo)
        png_bytes = buffer.getvalue()
        return base64.b64encode(png_bytes)
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
class FIFOLock(object):
    """Lock whose blocked acquirers wait on per-caller events.

    Callers that cannot take the lock immediately enqueue an event; each
    ``release`` wakes the oldest queued event before freeing the lock.
    Usable as a context manager.
    """

    def __init__(self):
        self._lock = threading.Lock()  # the lock handed to callers
        self._inner_lock = threading.Lock()  # guards the queue bookkeeping
        self._pending_threads = collections.deque()  # events of waiting callers

    def acquire(self, blocking=True):
        """Take the lock.

        Returns True once the lock is held, or False when ``blocking`` is
        false and the lock is currently taken.
        """
        self._inner_lock.acquire()
        try:
            if self._lock.acquire(False):
                return True
            if not blocking:
                return False
            wake_up = threading.Event()
            self._pending_threads.append(wake_up)
        finally:
            self._inner_lock.release()

        # Sleep until a release() call signals this caller, then take the lock.
        wake_up.wait()
        return self._lock.acquire()

    def release(self):
        """Wake the oldest waiting caller, if any, and free the lock."""
        self._inner_lock.acquire()
        try:
            if self._pending_threads:
                self._pending_threads.popleft().set()
            self._lock.release()
        finally:
            self._inner_lock.release()

    __enter__ = acquire

    def __exit__(self, t, v, tb):
        self.release()
def api_middleware(app: FastAPI):
    """Register request logging/timing and exception handling on ``app``.

    Installs two HTTP middlewares (``log_and_time`` and
    ``exception_handling``) and two exception handlers that all funnel
    errors through ``handle_exception``.
    """
    rich_available = False
    try:
        # Rich tracebacks are opt-in through the WEBUI_RICH_EXCEPTIONS
        # environment variable.
        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
            import anyio  # importing just so it can be placed on silent list
            import starlette  # importing just so it can be placed on silent list
            from rich.console import Console

            console = Console()
            rich_available = True
    except Exception:
        # Any failure while setting up rich leaves rich output disabled.
        pass

    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        """Time the request, expose it as X-Process-Time and optionally log it."""
        ts = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get("path", "err")
        # Only "/sdapi" endpoints are logged, and only when api_log is set.
        if cmd_opts.api_log and endpoint.startswith("/sdapi"):
            print(
                "API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format(
                    t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                    code=res.status_code,
                    ver=req.scope.get("http_version", "0.0"),
                    cli=req.scope.get("client", ("0:0.0.0", 0))[0],
                    prot=req.scope.get("scheme", "err"),
                    method=req.scope.get("method", "err"),
                    endpoint=endpoint,
                    duration=duration,
                )
            )
        return res

    def handle_exception(request: Request, e: Exception):
        """Report ``e`` and convert it to a JSON error response.

        The response status is the exception's ``status_code`` attribute when
        it has one, otherwise 500. For exceptions that are not
        ``HTTPException`` and when rich output is unavailable, the message is
        printed and the exception is re-raised instead of being converted.
        """
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get("detail", ""),
            "body": vars(e).get("body", ""),
            "errors": str(e),
        }
        if not isinstance(
            e, HTTPException
        ):  # do not print backtrace on known httpexceptions
            message = f"API error: {request.method}: {request.url} {err}"
            if rich_available:
                print(message)
                console.print_exception(
                    show_locals=True,
                    max_frames=2,
                    extra_lines=1,
                    suppress=[anyio, starlette],
                    word_wrap=False,
                    width=min([console.width, 200]),
                )
            else:
                # Without rich: print the message and let the error propagate.
                print(message)
                raise (e)
        return JSONResponse(
            status_code=vars(e).get("status_code", 500),
            content=jsonable_encoder(err),
        )

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        """Route exceptions raised further down the stack to handle_exception."""
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)
class ApiCompat:
    """Registers the compatibility REST routes on a FastAPI app and serves it.

    Only the chat-completion routes are active; the ``/sdapi/v1`` routes and
    their helper methods are kept below as commented-out placeholders.
    """

    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Store ``app``/``queue_lock``, install middleware and add the routes."""
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        api_middleware(self.app)
        # self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
        # self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"])
        # self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["POST"])
        # self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
        # self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
        # self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
        # self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
        # self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        # self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        # self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        # self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
        # self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        # self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
        # self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
        # self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
        # self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
        # self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
        # self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
        # self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
        # self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
        # self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
        # self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
        # self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
        # self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        # self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
        # self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
        # self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
        # self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
        # self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
        # self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
        # self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
        # self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
        # self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
        # self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
        # self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])

        # chat APIs needed for compatibility with multiple extensions using OpenAI API
        chat_paths = (
            "/v1/chat/completions",
            "/v1/completions",
            "/chat/completions",
            "/completions",
            "/v1/engines/codegen/completions",
        )
        for chat_path in chat_paths:
            self.add_api_route(chat_path, llm_chat_api, methods=["POST"])

        self.default_script_arg_txt2img = []
        self.default_script_arg_img2img = []

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Forward the route registration to the wrapped FastAPI app."""
        return self.app.add_api_route(path, endpoint, **kwargs)

    # def refresh_checkpoints(self):
    #     with self.queue_lock:
    #         studio_data.refresh_checkpoints()

    # def refresh_vae(self):
    #     with self.queue_lock:
    #         studio_data.refresh_vae_list()

    # def unloadapi(self):
    #     unload_model_weights()
    #     return {}

    # def reloadapi(self):
    #     reload_model_weights()
    #     return {}

    # def skip(self):
    #     studio.state.skip()

    def launch(self, server_name, port, root_path):
        """Include the router in the app and run it with uvicorn."""
        self.app.include_router(self.router)
        uvicorn_options = dict(host=server_name, port=port, root_path=root_path)
        uvicorn.run(self.app, **uvicorn_options)

    # def kill_studio(self):
    #     restart.stop_program()

    # def restart_studio(self):
    #     if restart.is_restartable():
    #         restart.restart_program()
    #     return Response(status_code=501)

    # def preprocess(self, args: dict):
    #     try:
    #         studio.state.begin(job="preprocess")
    #         preprocess(**args)
    #         studio.state.end()
    #         return models.PreprocessResponse(info="preprocess complete")
    #     except:
    #         studio.state.end()

    # def stop_studio(request):
    #     studio.state.server_command = "stop"
    #     return Response("Stopping.")

View File

@@ -1 +0,0 @@

View File

@@ -1,222 +0,0 @@
from multiprocessing import Process, freeze_support
freeze_support()
from PIL import Image
import os
import time
import sys
import logging
import apps.shark_studio.api.initializers as initialize
from apps.shark_studio.modules import timer
startup_timer = timer.startup_timer
startup_timer.record("launcher")
initialize.imports()
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
def create_api(app):
    """Return an ApiCompat built around ``app`` with a new FIFOLock as its queue lock."""
    from apps.shark_studio.web.api.compat import ApiCompat, FIFOLock

    return ApiCompat(app, FIFOLock())
def api_only():
    """Initialize, build a FastAPI app with the API routes and serve it.

    The server listens on 0.0.0.0 at ``cmd_opts.server_port``.
    """
    from fastapi import FastAPI
    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

    initialize.initialize()

    fastapi_app = FastAPI()
    initialize.setup_middleware(fastapi_app)
    server = create_api(fastapi_app)

    # from modules import script_callbacks
    # script_callbacks.before_ui_callback()
    # script_callbacks.app_started_callback(None, app)

    print(f"Startup time: {startup_timer.summary()}.")
    launch_options = dict(
        server_name="0.0.0.0",
        port=cmd_opts.server_port,
        root_path="",
    )
    server.launch(**launch_options)
def launch_webui(address):
    """Open ``address`` in a pywebview window sized relative to the screen.

    Args:
        address: URL that the webview window should load.
    """
    from tkinter import Tk
    import webview

    # Tk is needed only to query the screen dimensions: keep its root window
    # hidden and destroy it once the numbers are read, so no stray empty
    # window or Tcl interpreter outlives this query.
    window = Tk()
    try:
        window.withdraw()
        # get screen width and height of display and make it more reasonably
        # sized as we aren't making it full-screen or maximized
        width = int(window.winfo_screenwidth() * 0.81)
        height = int(window.winfo_screenheight() * 0.91)
    finally:
        window.destroy()

    webview.create_window(
        "SHARK AI Studio",
        url=address,
        width=width,
        height=height,
        text_select=True,
    )
    webview.start(private_mode=False, storage_path=os.getcwd())
def webui():
    """Build the Gradio UI (Stable Diffusion, Output Gallery, Chat Bot) and launch it.

    The server listens on 0.0.0.0 at ``cmd_opts.server_port`` and opens a
    browser tab; ``cmd_opts.share`` controls Gradio sharing.
    """
    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
    from apps.shark_studio.web.ui.utils import (
        amdicon_loc,
        amdlogo_loc,
    )

    # Currently only read; the API launch code below is commented out.
    launch_api = cmd_opts.api
    initialize.initialize()

    from ui.chat import chat_element
    from ui.sd import sd_element
    from ui.outputgallery import outputgallery_element

    # required to do multiprocessing in a pyinstaller freeze
    freeze_support()

    # if args.api or "api" in args.ui.split(","):
    #     from apps.shark_studio.api.llm import (
    #         chat,
    #     )
    #     from apps.shark_studio.web.api import sdapi
    #
    #     from fastapi import FastAPI, APIRouter
    #     from fastapi.middleware.cors import CORSMiddleware
    #     import uvicorn
    #
    #     # init global sd pipeline and config
    #     global_obj._init()
    #
    #     api = FastAPI()
    #     api.mount("/sdapi/", sdapi)
    #
    #     # chat APIs needed for compatibility with multiple extensions using OpenAI API
    #     api.add_api_route(
    #         "/v1/chat/completions", llm_chat_api, methods=["post"]
    #     )
    #     api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
    #     api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
    #     api.add_api_route("/completions", llm_chat_api, methods=["post"])
    #     api.add_api_route(
    #         "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
    #     )
    #     api.include_router(APIRouter())
    #
    #     # deal with CORS requests if CORS accept origins are set
    #     if args.api_accept_origin:
    #         print(
    #             f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
    #         )
    #         api.add_middleware(
    #             CORSMiddleware,
    #             allow_origins=args.api_accept_origin,
    #             allow_methods=["GET", "POST"],
    #             allow_headers=["*"],
    #         )
    #     else:
    #         print("API not configured for CORS")
    #
    #     uvicorn.run(api, host="0.0.0.0", port=args.server_port)
    #     sys.exit(0)

    import gradio as gr

    def resource_path(relative_path):
        """Get absolute path to resource, works for dev and for PyInstaller"""
        base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(base_path, relative_path)

    dark_theme = resource_path("ui/css/sd_dark_theme.css")
    gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")

    # from apps.shark_studio.web.ui import load_ui_from_script

    def register_button_click(button, selectedid, inputs, outputs):
        """On click, forward the first item's name and switch to tab ``selectedid``."""
        button.click(
            lambda x: (
                x[0]["name"] if len(x) != 0 else None,
                # Gradio 4 has no Component.update(); returning a component
                # instance from the handler is how properties are updated.
                gr.Tabs(selected=selectedid),
            ),
            inputs,
            outputs,
        )

    def register_outputgallery_button(button, selectedid, inputs, outputs):
        """On click, forward the input unchanged and switch to tab ``selectedid``."""
        button.click(
            lambda x: (
                x,
                gr.Tabs(selected=selectedid),
            ),
            inputs,
            outputs,
        )

    with gr.Blocks(
        css=dark_theme,
        js=gradio_workarounds,
        analytics_enabled=False,
        title="Shark Studio 2.0 Beta",
    ) as studio_web:
        amd_logo = Image.open(amdlogo_loc)
        gr.Image(
            value=amd_logo,
            show_label=False,
            interactive=False,
            elem_id="tab_bar_logo",
            show_download_button=False,
        )

        with gr.Tabs() as tabs:
            # NOTE: If adding, removing, or re-ordering tabs, make sure that they
            # have a unique id that doesn't clash with any of the other tabs,
            # and that the order in the code here is the order they should
            # appear in the ui, as the id value doesn't determine the order.

            # Where possible, avoid changing the id of any tab that is the
            # destination of one of the 'send to' buttons. If you do have to change
            # that id, make sure you update the relevant register_button_click calls
            # further down with the new id.
            with gr.TabItem(label="Stable Diffusion", id=0):
                sd_element.render()
            with gr.TabItem(label="Output Gallery", id=1):
                outputgallery_element.render()
            with gr.TabItem(label="Chat Bot", id=2):
                chat_element.render()

    studio_web.queue()

    # if args.ui == "app":
    #    t = Process(
    #        target=launch_app, args=[f"http://localhost:{args.server_port}"]
    #    )
    #    t.start()
    studio_web.launch(
        share=cmd_opts.share,
        inbrowser=True,
        server_name="0.0.0.0",
        server_port=cmd_opts.server_port,
        favicon_path=amdicon_loc,
    )
if __name__ == "__main__":
    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

    # The comparison with False is kept as written: only a value equal to
    # False selects API-only mode, anything else starts the web UI.
    entry_point = api_only if cmd_opts.webui == False else webui
    entry_point()

View File

@@ -1,239 +0,0 @@
import gradio as gr
import time
import os
from pathlib import Path
from datetime import datetime as dt
import json
import sys
from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
# Markers placed around text in chat prompts/outputs (the assignment was
# previously repeated three times with identical values).
B_SYS, E_SYS = "<s>", "</s>"
def user(message, history):
    """Return an empty string plus ``history`` extended by ``[message, ""]``.

    A new list is built; ``history`` itself is left untouched.
    """
    new_turn = [message, ""]
    extended_history = history + [new_turn]
    return "", extended_history
def append_bot_prompt(history, input_prompt):
    """Add ``input_prompt`` followed by two ``E_SYS`` markers to ``history``.

    The addition uses ``+=`` on ``history`` and the result is returned.
    """
    history += " ".join((f"{input_prompt}", E_SYS, E_SYS))
    return history
# Module-wide model instance; chat_fn creates it on first use and resets it to
# None when a "Clear" request arrives in streaming mode.
language_model = None


def get_default_config():
    """Always returns False."""
    return False
# model_vmfb_key = ""
def chat_fn(
    prompt_prefix,
    history,
    model,
    device,
    precision,
    download_vmfb,
    config_file,
    streaming_llm,
    cli=False,
):
    """Generator that streams a chat reply into the last entry of ``history``.

    Yields ``(history, status_text)`` pairs. The module-level
    ``language_model`` is created on first use from ``model``, ``device``,
    ``precision``, ``prompt_prefix`` and ``streaming_llm``, and reused on
    later calls. ``download_vmfb``, ``config_file`` and ``cli`` are accepted
    but not read in this function.
    """
    global language_model
    # A "Clear" request in streaming mode drops the cached model.
    # NOTE(review): this function is a generator, so the tuple below becomes
    # the StopIteration value and is never yielded to the caller — confirm
    # nothing relies on receiving "Clearing history...".
    if streaming_llm and prompt_prefix == "Clear":
        language_model = None
        return "Clearing history...", ""
    if language_model is None:
        # NOTE(review): history[-1] assumes history is non-empty — confirm
        # callers always append a turn first.
        history[-1][-1] = "Getting the model ready..."
        yield history, ""
        language_model = LanguageModel(
            model,
            device=device,
            precision=precision,
            external_weights="safetensors",
            use_system_prompt=prompt_prefix,
            streaming_llm=streaming_llm,
            hf_auth_token=cmd_opts.hf_auth_token,
        )
        history[-1][-1] = "Getting the model ready... Done"
        yield history, ""
    history[-1][-1] = ""
    token_count = 0
    total_time = 0.001  # In order to avoid divide by zero error
    prefill_time = 0
    is_first = True
    for text, exec_time in language_model.chat(history):
        # Each update overwrites the last reply with the text followed by E_SYS.
        history[-1][-1] = f"{text}{E_SYS}"
        if is_first:
            # The first item's time is reported as prefill time and is not
            # counted towards the decode rate.
            prefill_time = exec_time
            is_first = False
            yield history, f"Prefill: {prefill_time:.2f}"
        else:
            total_time += exec_time
            token_count += 1
            tokens_per_sec = token_count / total_time
            yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
def view_json_file(file_obj):
    """Return the full text of the file located at ``file_obj.name``."""
    with open(file_obj.name, "r") as handle:
        return handle.read()
# Chat tab: model/device selection, the chatbot view, the message box with
# Submit/Stop/Clear buttons, and the event wiring that drives chat_fn.
with gr.Blocks(title="Chat") as chat_element:
    with gr.Row():
        model_choices = list(llm_model_map.keys())
        model = gr.Dropdown(
            label="Select Model",
            value=model_choices[0],
            choices=model_choices,
            allow_custom_value=True,
        )
        # Drop the "sync" devices first and only then fall back to
        # "cpu-task", so the list can never be empty when its first entry is
        # used as the dropdown default below.
        supported_devices = [
            x for x in global_obj.get_device_list() if "sync" not in x
        ]
        enabled = True
        if len(supported_devices) == 0:
            supported_devices = ["cpu-task"]
        device = gr.Dropdown(
            label="Device",
            value=supported_devices[0],
            choices=supported_devices,
            interactive=enabled,
            allow_custom_value=True,
        )
        precision = gr.Radio(
            label="Precision",
            value="fp32",
            choices=[
                # "int4",
                # "int8",
                # "fp16",
                "fp32",
            ],
            visible=False,
        )
        tokens_time = gr.Textbox(label="Tokens generated per second")
        with gr.Column():
            download_vmfb = gr.Checkbox(
                label="Download vmfb from Shark tank if available",
                value=False,
                interactive=True,
                visible=False,
            )
            streaming_llm = gr.Checkbox(
                label="Run in streaming mode (requires recompilation)",
                value=True,
                interactive=False,
                visible=False,
            )
            prompt_prefix = gr.Checkbox(
                label="Add System Prompt",
                value=True,
                interactive=True,
            )

    chatbot = gr.Chatbot(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
                interactive=enabled,
                container=False,
            )
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit", interactive=enabled)
                stop = gr.Button("Stop", interactive=enabled)
                clear = gr.Button("Clear", interactive=enabled)

    with gr.Row(visible=False):
        with gr.Group():
            config_file = gr.File(label="Upload sharding configuration", visible=False)
            json_view_button = gr.Button("View as JSON", visible=False)
        json_view = gr.JSON(visible=False)
        json_view_button.click(
            fn=view_json_file, inputs=[config_file], outputs=[json_view]
        )

    # Components passed to chat_fn after its first argument; the first
    # argument is the system-prompt checkbox for chat turns and the Clear
    # button for a clear request.
    chat_fn_tail_inputs = [
        chatbot,
        model,
        device,
        precision,
        download_vmfb,
        config_file,
        streaming_llm,
    ]

    # Pressing Enter in the message box and clicking Submit run the same
    # two-step chain: record the user turn, then stream the reply.
    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        show_progress=False,
        queue=False,
    ).then(
        fn=chat_fn,
        inputs=[prompt_prefix] + chat_fn_tail_inputs,
        outputs=[chatbot, tokens_time],
        show_progress=False,
        queue=True,
    )
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        show_progress=False,
        queue=False,
    ).then(
        fn=chat_fn,
        inputs=[prompt_prefix] + chat_fn_tail_inputs,
        outputs=[chatbot, tokens_time],
        show_progress=False,
        queue=True,
    )
    # Stop cancels whichever of the two chains is running.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    # Clear calls chat_fn with the Clear button as first input, then empties
    # the chatbot.
    clear.click(
        fn=chat_fn,
        inputs=[clear] + chat_fn_tail_inputs,
        outputs=[chatbot, tokens_time],
        show_progress=False,
        queue=True,
    ).then(lambda: None, None, [chatbot], queue=False)

View File

@@ -1,67 +0,0 @@
from apps.shark_studio.web.ui.utils import (
HSLHue,
hsl_color,
)
from apps.shark_studio.modules.embeddings import get_lora_metadata
# Returns HTML showing the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_files):
    """Build one HTML snippet per entry of ``lora_files``.

    Empty names and names not ending in ``.safetensors`` get a fixed notice.
    Otherwise the result of ``get_lora_metadata`` decides between the tag
    listing, a "does not publish" notice (``None``) and an "empty or
    unparseable" notice (any other falsy value).
    """
    # tag frequency at which the position passed to hsl_color is 0, i.e. the
    # very start of the RED -> GREEN range
    TAG_COLOR_THRESHOLD = 0.55
    # tag frequency above which a tag is displayed
    TAG_DISPLAY_THRESHOLD = 0.65
    # template for the html used to display a tag
    TAG_HTML_TEMPLATE = (
        '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
    )

    def render_tag(name, frequency):
        # Rescale the frequency so that TAG_COLOR_THRESHOLD maps to 0 and 1 to 1.
        position = (frequency - TAG_COLOR_THRESHOLD) / (1 - TAG_COLOR_THRESHOLD)
        return TAG_HTML_TEMPLATE.format(
            color=hsl_color(position, start=HSLHue.RED, end=HSLHue.GREEN),
            tag=name,
        )

    output = []
    for lora_file in lora_files:
        if lora_file == "":
            output.append("<div><i>No LoRA selected</i></div>")
            continue

        if not lora_file.lower().endswith(".safetensors"):
            output.append(
                "<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
            )
            continue

        metadata = get_lora_metadata(lora_file)
        if metadata:
            frequencies = metadata["frequencies"]
            header = f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
            shown_tags = [
                render_tag(tag[0], tag[1])
                for tag in frequencies
                if tag[1] > TAG_DISPLAY_THRESHOLD
            ]
            output.append("".join([header] + shown_tags))
        elif metadata is None:
            output.append(
                "<div><i>This LoRA does not publish tag frequency metadata</i></div>"
            )
        else:
            output.append(
                "<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
            )
    return output

View File

@@ -1,49 +0,0 @@
// Workaround for gradio after 4.7 not applying any @media rules from the
// custom .css file: the rules are emulated by toggling CSS classes from
// MediaQueryList "change" events.
() => {
    console.log(`innerWidth: ${window.innerWidth}` )

    // Adds (enabled === true) or removes every class in classNames on all
    // nodes that match selector.
    const setClasses = (selector, classNames, enabled) => {
        for (const node of document.querySelectorAll(selector)) {
            for (const className of classNames) {
                node.classList.toggle(className, enabled);
            }
        }
    };

    // Listens for changes of mediaQueryText and fires one synthetic "change"
    // event, computed from window.innerWidth, so the classes are applied
    // right away.
    const emulateMinWidthRule = (mediaQueryText, minWidth, selector, classNames) => {
        const mediaQuery = window.matchMedia(mediaQueryText);
        mediaQuery.addEventListener("change", (event) => {
            setClasses(selector, classNames, event.matches);
        });
        mediaQuery.dispatchEvent(
            new MediaQueryListEvent("change", {matches: window.innerWidth >= minWidth})
        );
    };

    // 1536px rules: display in full width for desktop devices
    emulateMinWidthRule(
        '(min-width: 1536px)',
        1536,
        ".gradio-container",
        ["gradio-container-size-full"]
    );

    // 1921px rules:
    // force a 768px_height + 4px_margin_height + navbar_height for the gallery,
    // limit height to 768px_height + 2px_margin_height for the thumbnails
    emulateMinWidthRule(
        '(min-width: 1921px)',
        1921,
        "#gallery",
        ["gallery-force-height768", "gallery-limit-height768"]
    );
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.4 KiB

View File

@@ -1,777 +0,0 @@
import os
import json
import gradio as gr
import numpy as np
from inspect import signature
from PIL import Image
from pathlib import Path
from datetime import datetime as dt
from gradio.components.image_editor import (
EditorValue,
)
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_checkpoints_path,
get_checkpoints,
get_configs_path,
write_default_sd_configs,
)
from apps.shark_studio.api.sd import (
shark_sd_fn_dict_input,
cancel_sd,
unload_sd,
)
from apps.shark_studio.api.controlnet import (
cnet_preview,
)
from apps.shark_studio.modules.schedulers import (
scheduler_model_map,
)
from apps.shark_studio.modules.img_processing import (
resampler_list,
resize_stencil,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.ui.utils import (
amdlogo_loc,
none_to_str_none,
str_none_to_none,
)
from apps.shark_studio.web.utils.state import (
status_label,
)
from apps.shark_studio.web.ui.common_events import lora_changed
from apps.shark_studio.modules import logger
import apps.shark_studio.web.utils.globals as global_obj
# Model IDs offered as choices in the "Base Model" dropdown.
sd_default_models = [
    "runwayml/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2-1-base",
    "stabilityai/stable-diffusion-2-1",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/sdxl-turbo",
]
def view_json_file(file_path):
    """Return the full text of the file at ``file_path``."""
    with open(file_path, "r") as handle:
        return handle.read()
def submit_to_cnet_config(
    stencil: str,
    preprocessed_hint: str,
    cnet_strength: int,
    control_mode: str,
    curr_config: dict,
):
    """Add one controlnet entry to a config, or start a new config.

    Returns ``gr.update()`` when ``stencil`` or ``preprocessed_hint`` is
    ``None`` or empty. If ``curr_config`` already has a ``"controlnets"``
    section, that section is modified in place (mode overwritten, model, hint
    and strength appended) and ``curr_config`` is returned. Otherwise a new
    dict holding only a ``"controlnets"`` section is returned.
    """
    for required in (stencil, preprocessed_hint):
        if required in [None, ""]:
            return gr.update()

    if curr_config is not None and "controlnets" in curr_config:
        section = curr_config["controlnets"]
        section["control_mode"] = control_mode
        section["model"].append(stencil)
        section["hint"].append(preprocessed_hint)
        section["strength"].append(cnet_strength)
        return curr_config

    return {
        "controlnets": {
            "control_mode": control_mode,
            "model": [stencil],
            "hint": [preprocessed_hint],
            "strength": [cnet_strength],
        }
    }
def update_embeddings_json(embedding):
    """Wrap ``embedding`` as ``{"embeddings": [embedding]}``."""
    embeddings_cfg = {}
    embeddings_cfg["embeddings"] = [embedding]
    return embeddings_cfg
def submit_to_main_config(input_cfg: dict, main_cfg: dict):
    """Copy every top-level entry of ``input_cfg`` into ``main_cfg``.

    When ``main_cfg`` is ``None``, ``""`` or ``{}``, ``input_cfg`` itself is
    returned; otherwise ``main_cfg`` is updated in place and returned.
    """
    if main_cfg in [None, "", {}]:
        return input_cfg
    main_cfg.update(input_cfg)
    return main_cfg
def pull_sd_configs(
    prompt,
    negative_prompt,
    sd_init_image,
    height,
    width,
    steps,
    strength,
    guidance_scale,
    seed,
    batch_count,
    batch_size,
    scheduler,
    base_model_id,
    custom_weights,
    custom_vae,
    precision,
    device,
    target_triple,
    ondemand,
    compiled_pipeline,
    resample_type,
    controlnets,
    embeddings,
):
    """Collect the current UI values into a JSON config string.

    ``prompt``, ``negative_prompt`` and ``sd_init_image`` are stored as
    one-element lists. ``controlnets`` and ``embeddings`` are stored as dicts:
    a dict value is kept, a string value is parsed as JSON, anything else
    (or an unparseable string) becomes ``{}``. All other values are stored
    as they are.

    Returns:
        The config serialized with ``json.dumps``.
    """
    # locals() must be read before any other local is created so that it
    # contains exactly the parameters.
    sd_args = str_none_to_none(locals())
    sd_cfg = {}
    for arg in sd_args:
        value = sd_args[arg]
        if arg in ("prompt", "negative_prompt", "sd_init_image"):
            sd_cfg[arg] = [value]
        elif arg in ("controlnets", "embeddings"):
            # Inspect the value, not the key name: the key is always a str,
            # so testing it made these sections always come out empty.
            if isinstance(value, dict):
                sd_cfg[arg] = value
            elif isinstance(value, str) and value.strip():
                try:
                    sd_cfg[arg] = json.loads(value)
                except ValueError:
                    # Unparseable text is treated like a missing section.
                    sd_cfg[arg] = {}
            else:
                sd_cfg[arg] = {}
        else:
            sd_cfg[arg] = value
    # NOTE(review): assumes the controlnets/embeddings dicts hold only
    # JSON-serializable values — confirm against what the UI passes in.
    return json.dumps(sd_cfg)
def load_sd_cfg(sd_json: dict, load_sd_config: str):
    """Load a saved config file and return the values for the UI components.

    The file at ``load_sd_config`` is read, parsed as JSON and passed through
    ``none_to_str_none``. Its entries overwrite those of ``sd_json`` when
    ``sd_json`` is non-empty; otherwise the loaded config is used as is.

    Returns:
        A list with the individual settings in UI order, the init image as a
        one-element list of opened images or ``None`` when no entry points to
        an existing file, and the merged config dict as the last element.
    """
    new_sd_config = none_to_str_none(json.loads(view_json_file(load_sd_config)))
    if sd_json:
        for key in new_sd_config:
            sd_json[key] = new_sd_config[key]
    else:
        sd_json = new_sd_config

    # Start from "no image" so the name is always bound, even when the list
    # is empty or none of its entries is an existing file.
    sd_image = None
    for image_path in sd_json["sd_init_image"]:
        if image_path is None:
            sd_image = None
        elif os.path.isfile(image_path):
            sd_image = [Image.open(image_path, mode="r")]

    return [
        sd_json["prompt"][0],
        sd_json["negative_prompt"][0],
        sd_image,
        sd_json["height"],
        sd_json["width"],
        sd_json["steps"],
        sd_json["strength"],
        sd_json["guidance_scale"],
        sd_json["seed"],
        sd_json["batch_count"],
        sd_json["batch_size"],
        sd_json["scheduler"],
        sd_json["base_model_id"],
        sd_json["custom_weights"],
        sd_json["custom_vae"],
        sd_json["precision"],
        sd_json["device"],
        sd_json["target_triple"],
        sd_json["ondemand"],
        sd_json["compiled_pipeline"],
        sd_json["resample_type"],
        sd_json["controlnets"],
        sd_json["embeddings"],
        sd_json,
    ]
def save_sd_cfg(config: dict, save_name: str):
    """Write ``config`` as JSON to a file derived from ``save_name``.

    An existing path is used directly. Otherwise ``save_name`` is placed in
    ``cmd_opts.configs_path`` when that is set, else in the directory
    returned by ``get_configs_path()``. ``.json`` is appended when the
    resulting path does not contain it. Returns the string ``"..."``.
    """
    if os.path.exists(save_name):
        filepath = save_name
    else:
        target_dir = (
            cmd_opts.configs_path if cmd_opts.configs_path else get_configs_path()
        )
        filepath = os.path.join(target_dir, save_name)

    if ".json" not in filepath:
        filepath = filepath + ".json"

    with open(filepath, mode="w") as handle:
        serialized = json.dumps(config)
        handle.write(serialized)
    return "..."
def create_canvas(width, height):
    """Return an EditorValue whose background is a ``height`` x ``width`` image filled with 255.

    The value has no layers and no composite.
    """
    white_pixels = np.full((height, width, 3), 255, dtype=np.uint8)
    background = Image.fromarray(white_pixels)
    return EditorValue(
        {
            "background": background,
            "layers": [],
            "composite": None,
        }
    )
def import_original(original_img, width, height):
    """Return an EditorValue for ``original_img`` resized via ``resize_stencil``.

    When ``original_img`` is ``None`` the result of
    ``create_canvas(width, height)`` is returned instead.
    """
    if original_img is None:
        return create_canvas(width, height)

    background, _, _ = resize_stencil(original_img, width, height)
    return EditorValue(
        {
            "background": background,
            "layers": [],
            "composite": None,
        }
    )
def base_model_changed(base_model_id):
    """Rebuild the checkpoint dropdown after the base model selection changes.

    Choices are the checkpoints found under "checkpoints/<model basename>"
    followed by those in the generic "checkpoints" folder; the first one is
    selected, or "None" when nothing was found.
    """
    model_folder = os.path.join("checkpoints", os.path.basename(str(base_model_id)))
    model_specific = get_checkpoints(model_folder)
    shared = get_checkpoints(model_type="checkpoints")
    new_choices = model_specific + shared

    selected = new_choices[0] if new_choices else "None"
    return gr.Dropdown(value=selected, choices=["None"] + new_choices)
with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Column(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=2, min_width=600):
with gr.Accordion(
label="\U0001F4D0\U0000FE0F Device Settings", open=False
):
device = gr.Dropdown(
elem_id="device",
label="Device",
value=global_obj.get_device_list()[0],
choices=global_obj.get_device_list(),
allow_custom_value=False,
)
target_triple = gr.Textbox(
elem_id="target_triple",
label="Architecture",
value="",
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
label="Low VRAM",
interactive=True,
)
precision = gr.Radio(
label="Precision",
value=cmd_opts.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
base_model_id = gr.Dropdown(
label="\U000026F0\U0000FE0F Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2-1-base",
choices=sd_default_models,
allow_custom_value=True,
) # base_model_id
with gr.Row():
height = gr.Slider(
384,
1024,
value=cmd_opts.height,
step=8,
label="\U00002195\U0000FE0F Height",
)
width = gr.Slider(
384,
1024,
value=cmd_opts.width,
step=8,
label="\U00002194\U0000FE0F Width",
)
with gr.Accordion(
label="\U00002696\U0000FE0F Model Weights", open=False
):
with gr.Column():
# `base_model_id` is the Dropdown component here, not its value, so
# str(base_model_id) would stringify the widget. Read `.value` and look in
# the same folders base_model_changed() uses, so the initial choices match
# what the change handler later produces.
custom_weights = gr.Dropdown(
    label="Checkpoint Weights",
    info="Select or enter HF model ID",
    elem_id="custom_model",
    value="None",
    allow_custom_value=True,
    choices=["None"]
    + get_checkpoints(
        os.path.join("checkpoints", os.path.basename(str(base_model_id.value)))
    )
    + get_checkpoints(model_type="checkpoints"),
)  # custom_weights
base_model_id.change(
fn=base_model_changed,
inputs=[base_model_id],
outputs=[custom_weights],
)
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
"\\", "\n\\"
)
sd_vae_info = f"VAE Path: {sd_vae_info}"
custom_vae = gr.Dropdown(
label=f"VAE Model",
info=sd_vae_info,
elem_id="custom_model",
value=(
os.path.basename(cmd_opts.custom_vae)
if cmd_opts.custom_vae
else "None"
),
choices=["None"] + get_checkpoints("vae"),
allow_custom_value=True,
scale=1,
)
# Show the folder the choices are actually read from: get_checkpoints("lora")
# below scans the "lora" model folder, not "loras".
sd_lora_info = (str(get_checkpoints_path("lora"))).replace(
    "\\", "\n\\"
)
lora_opt = gr.Dropdown(
    allow_custom_value=True,
    label="Standalone LoRA Weights",
    info=sd_lora_info,
    elem_id="lora_weights",
    value=None,
    multiselect=True,
    choices=[] + get_checkpoints("lora"),
    scale=2,
)
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
embeddings_config = gr.JSON(
label="Embeddings Options", min_width=50, scale=1
)
gr.on(
triggers=[lora_opt.change],
fn=lora_changed,
inputs=[lora_opt],
outputs=[lora_tags],
queue=True,
show_progress=False,
).then(
fn=update_embeddings_json,
inputs=[lora_opt],
outputs=[embeddings_config],
show_progress=False,
)
with gr.Accordion(
label="\U0001F9EA\U0000FE0F Input Image Processing", open=False
):
strength = gr.Slider(
0,
1,
value=cmd_opts.strength,
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=cmd_opts.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="\U00002795\U0000FE0F Prompt",
value=cmd_opts.prompt[0],
lines=2,
elem_id="prompt_box",
show_copy_button=True,
)
negative_prompt = gr.Textbox(
label="\U00002796\U0000FE0F Negative Prompt",
value=cmd_opts.negative_prompt[0],
lines=2,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Row(equal_height=True):
seed = gr.Textbox(
value=cmd_opts.seed,
label="\U0001F331\U0000FE0F Seed",
info="An integer or a JSON list of integers, -1 for random",
show_copy_button=True,
)
scheduler = gr.Dropdown(
elem_id="scheduler",
label="\U0001F4C5\U0000FE0F Scheduler",
info="\U000E0020", # forces same height as seed
value="EulerDiscrete",
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
with gr.Row():
steps = gr.Slider(
1,
100,
value=cmd_opts.steps,
step=1,
label="\U0001F3C3\U0000FE0F Steps",
)
guidance_scale = gr.Slider(
0,
50,
value=cmd_opts.guidance_scale,
step=0.1,
label="\U0001F5C3\U0000FE0F CFG Scale",
)
with gr.Accordion(
label="Controlnet Options",
open=False,
visible=False,
):
preprocessed_hints = gr.State([])
with gr.Column():
sd_cnet_info = (
str(get_checkpoints_path("controlnet"))
).replace("\\", "\n\\")
with gr.Row():
cnet_config = gr.JSON()
with gr.Column():
clear_config = gr.ClearButton(
value="Clear Controlnet Config",
size="sm",
components=cnet_config,
)
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
with gr.Row():
with gr.Column(scale=1):
cnet_model = gr.Dropdown(
allow_custom_value=True,
label=f"Controlnet Model",
info=sd_cnet_info,
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
]
+ get_checkpoints("controlnet"),
)
cnet_strength = gr.Slider(
label="Controlnet Strength",
minimum=0,
maximum=100,
value=50,
step=1,
)
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=8,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=8,
)
make_canvas = gr.Button(
value="Make Canvas!",
)
use_input_img = gr.Button(
value="Use Original Image",
size="sm",
)
cnet_input = gr.Image(
value=None,
type="pil",
image_mode="RGB",
interactive=True,
)
with gr.Column(scale=1):
cnet_output = gr.Image(
value=None,
visible=True,
label="Preprocessed Hint",
interactive=False,
show_label=True,
)
cnet_gen = gr.Button(
value="Preprocess controlnet input",
)
use_result = gr.Button(
"Submit",
size="sm",
)
make_canvas.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[cnet_input],
queue=False,
)
cnet_gen.click(
fn=cnet_preview,
inputs=[
cnet_model,
cnet_input,
],
outputs=[
cnet_output,
preprocessed_hints,
],
)
use_result.click(
fn=submit_to_cnet_config,
inputs=[
cnet_model,
cnet_output,
cnet_strength,
control_mode,
cnet_config,
],
outputs=[
cnet_config,
],
queue=False,
)
with gr.Column(scale=3, min_width=600):
with gr.Tabs() as sd_tabs:
sd_element.load(
# Workaround for Gradio issue #7085
# TODO: revert to setting selected= in gr.Tabs declaration
# once this is resolved in Gradio
lambda: gr.Tabs(selected=101),
outputs=[sd_tabs],
)
with gr.Tab(label="Input Image", id=100) as sd_tab_init_image:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
# TODO: make this import image prompt info if it exists
sd_init_image = gr.Image(
type="pil",
interactive=True,
show_label=False,
)
use_input_img.click(
fn=import_original,
inputs=[
sd_init_image,
canvas_width,
canvas_height,
],
outputs=[cnet_input],
queue=False,
)
with gr.Tab(label="Generate Images", id=101) as sd_tab_gallery:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
sd_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=2,
object_fit="fit",
preview=True,
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=cmd_opts.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=cmd_opts.batch_size,
step=1,
label="Batch Size",
interactive=True,
visible=True,
)
compiled_pipeline = gr.Checkbox(
False,
label="Faster txt2img (SDXL only)",
)
with gr.Row():
stable_diffusion = gr.Button("Start")
unload = gr.Button("Unload Models")
unload.click(
fn=unload_sd,
queue=False,
show_progress=False,
)
stop_batch = gr.Button("Stop")
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_configs(get_configs_path())
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
)
with gr.Row():
with gr.Column(scale=3):
load_sd_config = gr.FileExplorer(
label="Load Config",
file_count="single",
root_dir=(
cmd_opts.configs_path
if cmd_opts.configs_path
else get_configs_path()
),
height=75,
)
with gr.Column(scale=1):
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
clear_sd_config = gr.ClearButton(
value="Clear Config",
size="sm",
components=sd_json,
)
with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
outputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
sd_json,
],
)
# Register the save handler exactly once; a second identical registration
# would run save_sd_cfg (and rewrite the file) twice per click.
save_sd_config.click(
    fn=save_sd_cfg,
    inputs=[sd_json, sd_config_name],
    outputs=[sd_config_name],
)
with gr.Tab(label="Log", id=103) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
# Widgets whose current values are collected into the config JSON, in the
# positional order pull_sd_configs receives them.
sd_config_inputs = [
    prompt,
    negative_prompt,
    sd_init_image,
    height,
    width,
    steps,
    strength,
    guidance_scale,
    seed,
    batch_count,
    batch_size,
    scheduler,
    base_model_id,
    custom_weights,
    custom_vae,
    precision,
    device,
    target_triple,
    ondemand,
    compiled_pipeline,
    resample_type,
    cnet_config,
    embeddings_config,
]

# Step 2 of every chain: snapshot the widgets into the JSON config.
pull_kwargs = {
    "fn": pull_sd_configs,
    "inputs": sd_config_inputs,
    "outputs": [sd_json],
}

# Step 1 of every chain: write the status label into the hidden textbox.
status_kwargs = {
    "fn": lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs),
    "inputs": [batch_count, batch_size],
    "outputs": sd_status,
}

# Step 3 (Start button only): run generation from the JSON config.
gen_kwargs = {
    "fn": shark_sd_fn_dict_input,
    "inputs": [sd_json],
    "outputs": [
        sd_gallery,
        sd_status,
    ],
}

prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs)
generate_click = (
    stable_diffusion.click(**status_kwargs).then(**pull_kwargs).then(**gen_kwargs)
)

# The Stop button cancels whichever of the three chains is in flight.
stop_batch.click(
    fn=cancel_sd,
    cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -1,43 +0,0 @@
from enum import IntEnum
import math
import sys
import os
def resource_path(relative_path):
    """Join `relative_path` onto the resource root.

    The root is `sys._MEIPASS` when that attribute exists (PyInstaller
    bundle), otherwise the directory containing this module.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    resource_root = getattr(sys, "_MEIPASS", module_dir)
    return os.path.join(resource_root, relative_path)
# Logo and icon image paths, resolved once at import time via resource_path().
amdlogo_loc = resource_path("logos/amd-logo.jpg")
amdicon_loc = resource_path("logos/amd-icon.jpg")
class HSLHue(IntEnum):
    """Named hue values, usable as the start/end arguments of hsl_color()."""

    RED = 0
    YELLOW = 60
    GREEN = 120
    CYAN = 180
    BLUE = 240
    MAGENTA = 300
def hsl_color(alpha: float, start, end):
    """Interpolate a hue between `start` and `end` and format it as CSS HSL.

    `alpha` values that are not greater than zero are treated as zero, so the
    result is then `start`. Saturation and lightness are fixed at 80% / 35%.
    """
    fraction = alpha if alpha > 0 else 0
    hue = (end - start) * fraction + start
    # Return a CSS HSL string
    return f"hsl({math.floor(hue)}, 80%, 35%)"
def none_to_str_none(props: dict):
    """Replace every None value in `props` with the string "None", in place.

    Returns the same dict object for convenience.
    """
    for key in props:
        # Identity test: `== None` would call a value's custom __eq__ and can
        # misfire on objects that compare equal to everything.
        if props[key] is None:
            props[key] = "None"
    return props
def str_none_to_none(props: dict):
    """Replace every value equal to the string "None" with None, in place.

    Returns the same dict object for convenience.
    """
    props.update(
        {key: (None if value == "None" else value) for key, value in props.items()}
    )
    return props

View File

@@ -1,12 +0,0 @@
import os
import sys
def get_available_devices():
    """Return the list of selectable devices (currently only "cpu-task")."""
    devices = ["cpu-task"]
    return devices
def get_resource_path(relative_path):
    """Join `relative_path` onto the resource root.

    The root is `sys._MEIPASS` when that attribute exists (PyInstaller
    bundle), otherwise the directory containing this module.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    resource_root = getattr(sys, "_MEIPASS", module_dir)
    return os.path.join(resource_root, relative_path)

View File

@@ -1,95 +0,0 @@
default_sd_config = r"""{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
sdxl_30steps = r"""{
"prompt": [
"a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 1024,
"width": 1024,
"steps": 30,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-xl-base-1.0",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
sdxl_turbo = r"""{
"prompt": [
"A cat wearing a hat that says 'TURBO' on it. The cat is sitting on a skateboard."
],
"negative_prompt": [
""
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 2,
"strength": 0.8,
"guidance_scale": 0,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerAncestralDiscrete",
"base_model_id": "stabilityai/sdxl-turbo",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
# Maps each default config's file name to its JSON text; the keys are used as
# file names when the defaults are written to disk.
default_sd_configs = {
    "default_sd_config.json": default_sd_config,
    "sdxl-30steps.json": sdxl_30steps,
    "sdxl-turbo.json": sdxl_turbo,
}

View File

@@ -1,102 +0,0 @@
import os
import sys
import glob
from datetime import datetime as dt
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
checkpoints_filetypes = (
"*.ckpt",
"*.safetensors",
)
from apps.shark_studio.web.utils.default_configs import default_sd_configs
def write_default_sd_configs(path):
    """Write each bundled default SD config into directory `path`.

    Files are named after the keys of `default_sd_configs`; existing files
    with those names are overwritten.
    """
    for filename, contents in default_sd_configs.items():
        with open(os.path.join(path, filename), "w") as config_file:
            config_file.write(contents)
def safe_name(name):
    """Return the part of `name` after the last "/", with "-" turned into "_"."""
    _, _, tail = name.rpartition("/")
    return tail.replace("-", "_")
def get_path_stem(path):
    """Return the final path component of `path` without its last suffix."""
    return Path(path).stem
def get_resource_path(path):
    """Resolve `path` against the resource root.

    Absolute paths are returned unchanged (same object that was passed in).
    Relative paths are joined onto `sys._MEIPASS` when it exists (PyInstaller
    bundle) or onto this module's directory, and returned as a resolved Path.
    """
    if os.path.isabs(path):
        return path

    module_dir = os.path.dirname(os.path.abspath(__file__))
    resource_root = getattr(sys, "_MEIPASS", module_dir)
    return Path(os.path.join(resource_root, path)).resolve(strict=False)
def get_configs_path() -> Path:
    """Return the configs directory (`cmd_opts.config_dir`), creating it if absent."""
    configs = get_resource_path(cmd_opts.config_dir)
    if not os.path.exists(configs):
        # makedirs (not mkdir) so a nested path with missing parents works,
        # and exist_ok tolerates another caller creating it first.
        os.makedirs(configs, exist_ok=True)
    return Path(configs)
def get_generated_imgs_path() -> Path:
    """Return the image output directory (`cmd_opts.output_dir`), creating it if absent."""
    outputs = get_resource_path(cmd_opts.output_dir)
    if not os.path.exists(outputs):
        # makedirs (not mkdir) so a nested path with missing parents works,
        # and exist_ok tolerates another caller creating it first.
        os.makedirs(outputs, exist_ok=True)
    return Path(outputs)
def get_tmp_path() -> Path:
    """Return the directory derived from `cmd_opts.model_dir`, creating it if absent."""
    # NOTE(review): this reads model_dir although the function is named
    # "tmp" -- confirm that is intended.
    tmpdir = get_resource_path(cmd_opts.model_dir)
    if not os.path.exists(tmpdir):
        # makedirs (not mkdir) so a nested path with missing parents works,
        # and exist_ok tolerates another caller creating it first.
        os.makedirs(tmpdir, exist_ok=True)
    return Path(tmpdir)
def get_generated_imgs_todays_subdir() -> str:
    """Return today's local date formatted as YYYYMMDD."""
    return f"{dt.now():%Y%m%d}"
def create_model_folders():
    """Ensure the model directory and its standard subfolders exist.

    Exits the process with an explanatory message when `cmd_opts.model_dir`
    does not exist and cannot be created.
    """
    if not os.path.isdir(cmd_opts.model_dir):
        try:
            os.makedirs(cmd_opts.model_dir)
        except OSError:
            sys.exit(
                f"Invalid --model_dir argument, "
                f"{cmd_opts.model_dir} folder does not exist, and cannot be created."
            )

    for subfolder in ("checkpoints", "vae", "lora", "vmfb"):
        Path(get_checkpoints_path(subfolder)).mkdir(parents=True, exist_ok=True)
def get_checkpoints_path(model_type=""):
    """Return the resolved path of the `model_type` folder under `cmd_opts.model_dir`."""
    relative = os.path.join(cmd_opts.model_dir, model_type)
    return get_resource_path(relative)
def get_checkpoints(model_type="checkpoints"):
    """List checkpoint file names found in the `model_type` folder.

    Matches the patterns in `checkpoints_filetypes`, plus "*.pt" and "*.bin"
    when `model_type` is "lora". Returns base names sorted case-insensitively.
    """
    patterns = checkpoints_filetypes
    if model_type == "lora":
        patterns = patterns + ("*.pt", "*.bin")

    found = []
    for pattern in patterns:
        search = os.path.join(get_checkpoints_path(model_type), pattern)
        found.extend(os.path.basename(hit) for hit in glob.glob(search))
    return sorted(found, key=str.casefold)
def get_checkpoint_pathfile(checkpoint_name, model_type="checkpoints"):
    """Return the path of `checkpoint_name` inside the `model_type` folder."""
    folder = get_checkpoints_path(model_type)
    return os.path.join(folder, checkpoint_name)

View File

@@ -1,134 +0,0 @@
import gc
from ...api.utils import get_available_devices
"""
The global objects include SD pipeline and config.
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
Also we could avoid memory leak when switching models by clearing the cache.
"""
def _init():
    """Create every module-level global as None, then populate the device list."""
    global _sd_obj, _llm_obj, _devices
    global _pipe_kwargs, _prep_kwargs, _gen_kwargs, _schedulers

    _sd_obj = _llm_obj = _devices = None
    _pipe_kwargs = _prep_kwargs = _gen_kwargs = _schedulers = None
    set_devices()
def set_sd_obj(value):
    """Store the SD object and drop the LLM object."""
    global _sd_obj, _llm_obj
    _llm_obj, _sd_obj = None, value


def set_llm_obj(value):
    """Store the LLM object and drop the SD object."""
    global _sd_obj, _llm_obj
    _llm_obj, _sd_obj = value, None


def set_devices():
    """Refresh the cached device list from get_available_devices()."""
    global _devices
    _devices = get_available_devices()


def set_sd_scheduler(key):
    """Attach the stored scheduler registered under `key` to the SD object."""
    _sd_obj.scheduler = _schedulers[key]


def set_sd_status(value):
    """Set the `status` attribute of the SD object."""
    _sd_obj.status = value


def set_pipe_kwargs(value):
    """Store the pipeline kwargs."""
    global _pipe_kwargs
    _pipe_kwargs = value


def set_prep_kwargs(value):
    """Store the preparation kwargs."""
    global _prep_kwargs
    _prep_kwargs = value


def set_gen_kwargs(value):
    """Store the generation kwargs."""
    global _gen_kwargs
    _gen_kwargs = value


def set_schedulers(value):
    """Store the scheduler mapping."""
    global _schedulers
    _schedulers = value
def get_sd_obj():
    """Return the stored SD object."""
    return _sd_obj


def get_llm_obj():
    """Return the stored LLM object."""
    return _llm_obj


def get_device_list():
    """Return the cached device list."""
    return _devices


def get_sd_status():
    """Return the `status` attribute of the SD object."""
    return _sd_obj.status


def get_pipe_kwargs():
    """Return the stored pipeline kwargs."""
    return _pipe_kwargs


def get_prep_kwargs():
    """Return the stored preparation kwargs."""
    return _prep_kwargs


def get_gen_kwargs():
    """Return the stored generation kwargs."""
    return _gen_kwargs


def get_scheduler(key):
    """Return the stored scheduler registered under `key`."""
    return _schedulers[key]
def clear_cache():
    """Drop the cached SD/LLM objects, kwargs and schedulers, then collect garbage.

    Safe to call even if the globals were never initialised: the names are
    rebound to None instead of being deleted, so no NameError can occur and
    the names are never left undefined.
    """
    global _sd_obj
    global _llm_obj
    global _pipe_kwargs
    global _prep_kwargs
    global _gen_kwargs
    global _schedulers

    # Rebinding releases the references just as `del` would.
    _sd_obj = None
    _llm_obj = None
    _pipe_kwargs = None
    _prep_kwargs = None
    _gen_kwargs = None
    _schedulers = None
    gc.collect()

View File

@@ -1,43 +0,0 @@
import csv
import os
from .format import humanize, humanizable
def csv_path(image_filename: str):
    """Return the path of "imgs_details.csv" in the same directory as the image."""
    folder = os.path.dirname(image_filename)
    return os.path.join(folder, "imgs_details.csv")
def has_csv(image_filename: str) -> bool:
    """Report whether the details CSV next to `image_filename` exists."""
    details_file = csv_path(image_filename)
    return os.path.exists(details_file)
def matching_filename(image_filename: str, row):
    """Report whether `row` refers to the image named by `image_filename`."""
    # For a list row we assume the final column holds the original filename
    # with full path; otherwise we assume a dict and take the value of the
    # OUTPUT key. The image's base name must occur within that value.
    wanted = os.path.basename(image_filename)
    recorded = row[-1] if isinstance(row, list) else row["OUTPUT"]
    return wanted in recorded
def parse_csv(image_filename: str):
    """Return the humanized CSV row describing `image_filename`.

    Reads the details CSV next to the image and returns the first matching
    row after `humanize`, or an empty dict when no row matches.
    """
    csv_filename = csv_path(image_filename)

    with open(csv_filename, "r", newline="") as csv_file:
        # We use a reader or DictReader here for images_details.csv depending
        # on whether we think it has headers or not. Having headers means
        # less guessing of the format.
        has_header = csv.Sniffer().has_header(csv_file.read(2048))
        csv_file.seek(0)

        if has_header:
            reader = csv.DictReader(csv_file)
        else:
            reader = csv.reader(csv_file)

        matches = []
        for row in reader:
            if not row:
                continue
            if not (has_header or humanizable(row)):
                continue
            if not matching_filename(image_filename, row):
                continue
            # we rely on humanize and humanizable to work out the parsing of
            # the individual .csv rows
            matches.append(humanize(row))

    return matches[0] if matches else {}

View File

@@ -1,216 +0,0 @@
import re
from pathlib import Path
from apps.shark_studio.web.utils.file_utils import (
get_checkpoint_pathfile,
)
from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map
from apps.shark_studio.modules.schedulers import (
scheduler_model_map,
)
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
def parse_generation_parameters(x: str):
    """Parse a PNG "parameters" text blob into a dict of generation settings.

    All lines but the last form the prompt, with everything from a line
    starting "Negative prompt:" onwards forming the negative prompt. The last
    line is read as comma-separated "key: value" pairs; if it yields fewer
    than three pairs it is treated as prompt text instead. Values of the form
    "<W>x<H>" are split into "<key>-1" and "<key>-2".

    Returns:
        Dict with "Prompt", "Negative prompt", the parsed pairs, and defaults
        for "Clip skip" and "Hires resize-1"/"-2" when absent.
    """
    res = {}

    prompt = ""
    negative_prompt = ""

    done_with_prompt = False

    *lines, lastline = x.strip().split("\n")
    if len(re_param.findall(lastline)) < 3:
        lines.append(lastline)
        lastline = ""

    for line in lines:
        line = line.strip()
        if line.startswith("Negative prompt:"):
            done_with_prompt = True
            line = line[16:].strip()

        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
        # The value group can match the empty string (e.g. "Sampler: ,"), so
        # use startswith/endswith rather than indexing v[0] / v[-1].
        if v.startswith('"') and v.endswith('"'):
            v = v[1:-1]
        m = re_imagesize.match(v)
        if m is not None:
            res[k + "-1"] = m.group(1)
            res[k + "-2"] = m.group(2)
        else:
            res[k] = v

    # Missing CLIP skip means it was set to 1 (the default)
    if "Clip skip" not in res:
        res["Clip skip"] = "1"

    hypernet = res.get("Hypernet", None)
    if hypernet is not None:
        res[
            "Prompt"
        ] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

    if "Hires resize-1" not in res:
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    return res
def try_find_model_base_from_png_metadata(file: str, folder: str = "models") -> str:
    """Find a local checkpoint file matching the model name from PNG metadata.

    A trailing ".safetensors" or ".ckpt" extension is dropped first. Returns
    "<name>.safetensors" if that file exists in `folder`, else "<name>.ckpt"
    if that exists, else "".
    """
    # Remove extension from file info
    if file.endswith((".safetensors", ".ckpt")):
        file = Path(file).stem

    # Check for the file name match with one of the local ckpt or safetensors
    # files; ".safetensors" is tried last so it wins when both exist.
    custom = ""
    for extension in (".ckpt", ".safetensors"):
        candidate = file + extension
        if Path(get_checkpoint_pathfile(candidate, folder)).is_file():
            custom = candidate
    return custom
def find_model_from_png_metadata(
    key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
    """Resolve the model named under `key` in the PNG metadata.

    Returns:
        `(png_custom, png_hf_id)`: a local checkpoint file or known model
        name, and/or a "vendor/model" style id. Both are "" when `key` is
        absent or nothing matched.
    """
    png_hf_id = ""
    png_custom = ""

    if key not in metadata:
        return png_custom, png_hf_id

    model_file = metadata[key]
    png_custom = try_find_model_base_from_png_metadata(model_file)
    # Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
    if model_file in sd_model_map:
        png_custom = model_file
    # If nothing had matched, check vendor/hf_model_id
    if not png_custom and model_file.count("/"):
        png_hf_id = model_file
    # No matching model was found
    if not (png_custom or png_hf_id):
        print(
            "Import PNG info: Unable to find a matching model for %s" % model_file
        )

    return png_custom, png_hf_id
def find_vae_from_png_metadata(key: str, metadata: dict[str, str | int]) -> str:
    """Return the local VAE file matching `metadata[key]`, or "" if none."""
    # VAE input is optional, should not print or throw an error if missing
    if key not in metadata:
        return ""
    return try_find_model_base_from_png_metadata(metadata[key], "vae")
def find_lora_from_png_metadata(
    key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
    """Resolve the LoRA named under `key` in the PNG metadata.

    Returns:
        `(lora_custom, lora_hf_id)`: a local LoRA file name and/or a
        "vendor/model" style id; both "" when `key` is absent or unmatched.
    """
    lora_hf_id = ""
    lora_custom = ""

    # LoRA input is optional, should not print or throw an error if missing
    if key not in metadata:
        return lora_custom, lora_hf_id

    lora_file = metadata[key]
    lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
    # If nothing had matched, check vendor/hf_model_id
    if not lora_custom and lora_file.count("/"):
        lora_hf_id = lora_file

    return lora_custom, lora_hf_id
def import_png_metadata(
    pil_data,
    prompt,
    negative_prompt,
    steps,
    sampler,
    cfg_scale,
    seed,
    width,
    height,
    custom_model,
    custom_lora,
    hf_lora_id,
    custom_vae,
):
    """Overlay generation settings stored in a PNG onto the current UI values.

    Each argument after `pil_data` is the current value of a setting; values
    successfully read from `pil_data.info["parameters"]` replace them. Any
    exception stops the import at that point, keeping whatever was already
    updated, and is printed only when the image actually carried parameters.

    Returns:
        A 13-tuple: None followed by the twelve settings in argument order.
    """
    try:
        metadata = parse_generation_parameters(pil_data.info["parameters"])

        png_custom_model, png_hf_model_id = find_model_from_png_metadata(
            "Model", metadata
        )
        lora_custom_model, lora_hf_model_id = find_lora_from_png_metadata(
            "LoRA", metadata
        )
        vae_custom_model = find_vae_from_png_metadata("VAE", metadata)

        # These conversions run in order; a failure leaves later ones untouched.
        negative_prompt = metadata["Negative prompt"]
        steps = int(metadata["Steps"])
        cfg_scale = float(metadata["CFG scale"])
        seed = int(metadata["Seed"])
        width = float(metadata["Size-1"])
        height = float(metadata["Size-2"])

        if "Model" in metadata:
            if png_custom_model:
                custom_model = png_custom_model
            elif png_hf_model_id:
                custom_model = png_hf_model_id

        if "LoRA" in metadata:
            if lora_custom_model:
                custom_lora = lora_custom_model
                hf_lora_id = ""
            if lora_hf_model_id:
                custom_lora = "None"
                hf_lora_id = lora_hf_model_id

        if "VAE" in metadata and vae_custom_model:
            custom_vae = vae_custom_model

        if "Prompt" in metadata:
            prompt = metadata["Prompt"]

        if "Sampler" in metadata:
            if metadata["Sampler"] in scheduler_model_map:
                sampler = metadata["Sampler"]
            else:
                print(
                    "Import PNG info: Unable to find a scheduler for %s"
                    % metadata["Sampler"]
                )

    except Exception as ex:
        # Stay quiet for images that simply have no parameters to import.
        if pil_data and pil_data.info.get("parameters"):
            print("import_png_metadata failed with %s" % ex)

    return (
        None,
        prompt,
        negative_prompt,
        steps,
        sampler,
        cfg_scale,
        seed,
        width,
        height,
        custom_model,
        custom_lora,
        hf_lora_id,
        custom_vae,
    )

View File

@@ -1,39 +0,0 @@
import apps.shark_studio.web.utils.globals as global_obj
import gc
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
    """Build the progress text shown for a tab.

    While `batch_index` is below `batch_count` the label reads
    "<tab> generating <index+1>/<count>", with an "x<size>" suffix when
    `batch_size` exceeds 1; otherwise it reads "<tab> complete".
    """
    if not batch_index < batch_count:
        return f"{tab_name} complete"

    bs = f"x{batch_size}" if batch_size > 1 else ""
    return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
def get_generation_text_info(seeds, device):
    """Format the current generation config as a multi-line text summary.

    Args:
        seeds: Seed value(s) to print.
        device: Device name to print.

    Returns:
        The summary string (prompts, model, scheduler, steps, size, batch).
    """
    # NOTE(review): relies on global_obj.get_config_dict() returning a mapping
    # of config names to values -- confirm the globals module provides it.
    config = global_obj.get_config_dict()
    # Copy the *values*; storing each key as its own value would print the
    # key names instead of the settings.
    cfg_dump = {cfg: config[cfg] for cfg in config}

    text_output = f"prompt={cfg_dump['prompts']}"
    text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}"
    text_output += (
        f"\nmodel_id={cfg_dump['hf_model_id']}, " f"ckpt_loc={cfg_dump['ckpt_loc']}"
    )
    text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}"
    text_output += (
        f"\nsteps={cfg_dump['steps']}, "
        f"guidance_scale={cfg_dump['guidance_scale']}, "
        f"seed={seeds}"
    )
    # cfg_dump is a dict, so the flag must be looked up by key; attribute
    # access would always raise AttributeError. Absent flag means "off".
    text_output += (
        f"\nsize={cfg_dump['height']}x{cfg_dump['width']}, "
        if not cfg_dump.get("use_hiresfix", False)
        else f"\nsize={cfg_dump['hiresfix_height']}x{cfg_dump['hiresfix_width']}, "
    )
    text_output += (
        f"batch_count={cfg_dump['batch_count']}, "
        f"batch_size={cfg_dump['batch_size']}, "
        f"max_length={cfg_dump['max_length']}"
    )

    return text_output

View File

@@ -1,75 +0,0 @@
import os
import shutil
from time import time
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
shark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "shark_tmp/")
def clear_tmp_mlir():
    """Delete every regular file ending in ".mlir" directly inside `shark_tmp`."""
    cleanup_start = time()
    print("Clearing .mlir temporary files from a prior run. This may take some time...")

    for filename in os.listdir(shark_tmp):
        candidate = os.path.join(shark_tmp, filename)
        if os.path.isfile(candidate) and filename.endswith(".mlir"):
            os.remove(candidate)

    print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.")
def clear_tmp_imgs():
    """Point gradio at a temp image directory and clear images from a prior run.

    Sets GRADIO_TEMP_DIR to "<shark_tmp>/gradio" unless it is already set.
    If that directory exists it is removed wholesale; otherwise leftover
    "tmp*.png" files directly inside `shark_tmp` are deleted.
    """
    # tell gradio to use a directory under shark_tmp for its temporary
    # image files unless somewhere else has been set
    if "GRADIO_TEMP_DIR" not in os.environ:
        os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")

    print(
        f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
        + "You may change this by setting the GRADIO_TEMP_DIR environment variable."
    )

    # Clear all gradio tmp images from the last session
    if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
        cleanup_start = time()
        print(
            "Clearing gradio UI temporary image files from a prior run. This may take some time..."
        )
        shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
        print(
            f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
        )

    # older SHARK versions had to workaround gradio bugs and stored things differently
    else:
        image_files = [
            filename
            for filename in os.listdir(shark_tmp)
            if os.path.isfile(os.path.join(shark_tmp, filename))
            and filename.startswith("tmp")
            and filename.endswith(".png")
        ]
        if len(image_files) > 0:
            print(
                "Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
            )
            cleanup_start = time()
            for filename in image_files:
                # Join with a separator: plain concatenation only works when
                # shark_tmp happens to end with a slash.
                os.remove(os.path.join(shark_tmp, filename))
            print(
                f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
            )
        else:
            print("No temporary images files to clear.")
def config_tmp():
    """Ensure `shark_tmp` exists, then clear stale .mlir and image files from it."""
    # create shark_tmp if it does not exist
    tmp_missing = not os.path.exists(shark_tmp)
    if tmp_missing:
        os.mkdir(shark_tmp)

    clear_tmp_mlir()
    clear_tmp_imgs()

View File

@@ -0,0 +1,87 @@
<details>
<summary>Compile / Run Instructions</summary>
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
```
Run / Benchmark Command (FP32 - NCHW):
(A batch size of 2 is required, because classifier-free guidance runs two forward passes through the UNet.)
```shell
## Vulkan AMD:
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CUDA:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CPU:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
Run via vulkan_gui for RGP Profiling:
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
```shell
./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
</details>
<details>
<summary>Debug Commands</summary>
## Debug commands and other advanced usage follow.
```shell
python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
```
## dump all dispatch .spv and isa using amdllpc
```shell
python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
```
## Compile and save the .vmfb (using vulkan fp16 as an example):
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
```
## Capture an RGP trace
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
```
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
```shell
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
```
## Run the unet module with iree-benchmark-module (same config as above):
```shell
## If you want to use .npz inputs:
unzip ~/.local/shark_tank/<your unet>/inputs.npz
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy
```
</details>

View File

@@ -0,0 +1 @@
from apps.stable_diffusion.scripts.train_lora_word import lora_train

View File

@@ -0,0 +1,126 @@
import sys
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
StencilPipeline,
resize_stencil,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
def main():
    """Run one img2img generation driven by the parsed command-line ``args``.

    Opens ``args.img_path``, builds either a ``StencilPipeline`` (when
    ``args.use_stencil`` is set) or an ``Image2ImagePipeline``, generates
    images, saves the first generated image together with the strength that
    was used, and prints a textual summary of the run.

    Side effects: ``args.import_mlir`` is forced to ``True``; depending on
    the branch taken, ``args.scheduler``, ``args.hf_model_id``,
    ``args.width`` and ``args.height`` are overwritten as well.

    Exits the process with status 1 when ``--img_path`` was not supplied.
    """
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        # sys.exit() instead of exit(): the latter is a helper installed by
        # the `site` module for interactive use and is not guaranteed to
        # exist in every runtime. A non-zero status reports the failure.
        sys.exit(1)

    image = Image.open(args.img_path).convert("RGB")

    # When the models get uploaded, it should be default to False.
    args.import_mlir = True

    use_stencil = args.use_stencil
    if use_stencil:
        # The stencil path pins the scheduler and the base model and adopts
        # the dimensions returned by resize_stencil().
        args.scheduler = "DDIM"
        args.hf_model_id = "runwayml/stable-diffusion-v1-5"
        image, args.width, args.height = resize_stencil(image)
    elif "Shark" in args.scheduler:
        print(
            "Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
        )
        args.scheduler = "EulerDiscrete"

    # After the normalisation above the scheduler name never starts with
    # "Shark", so this evaluates to True; the expression is kept so the flag
    # keeps tracking the scheduler name if the branches above change.
    cpu_scheduling = not args.scheduler.startswith("Shark")
    dtype = torch.float32 if args.precision == "fp32" else torch.half
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = utils.sanitize_seed(args.seed)

    # Both pipelines take the same leading positional arguments and the same
    # keyword options; only the stencil pipeline additionally receives
    # `use_stencil`. The values are read here, after set_init_device_flags(),
    # exactly as the two separate calls did before.
    pipeline_args = (
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
    )
    pipeline_kwargs = dict(
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        debug=args.import_debug if args.import_mlir else False,
        use_lora=args.use_lora,
        ondemand=args.ondemand,
    )
    if use_stencil:
        img2img_obj = StencilPipeline.from_pretrained(
            *pipeline_args, use_stencil=use_stencil, **pipeline_kwargs
        )
    else:
        img2img_obj = Image2ImagePipeline.from_pretrained(
            *pipeline_args, **pipeline_kwargs
        )

    start_time = time.time()
    generated_imgs = img2img_obj.generate_images(
        args.prompts,
        args.negative_prompts,
        image,
        args.batch_size,
        args.height,
        args.width,
        args.steps,
        args.strength,
        args.guidance_scale,
        seed,
        args.max_length,
        dtype,
        args.use_base_vae,
        cpu_scheduling,
        use_stencil=use_stencil,
    )
    total_time = time.time() - start_time

    # Human-readable summary of the run, printed after the image is saved.
    text_output = f"prompt={args.prompts}"
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
    text_output += f"\nscheduler={args.scheduler}, device={args.device}"
    text_output += f"\nsteps={args.steps}, strength={args.strength}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
    text_output += (
        f", batch size={args.batch_size}, max_length={args.max_length}"
    )
    text_output += img2img_obj.log
    text_output += f"\nTotal image generation time: {total_time:.4f}sec"

    # Only the first generated image is written out.
    extra_info = {"STRENGTH": args.strength}
    save_output_img(generated_imgs[0], seed, extra_info)
    print(text_output)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,104 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
def main():
    """Run inpainting ``args.batch_count`` times from the parsed CLI ``args``.

    Requires ``args.img_path`` and ``args.mask_path``. For every batch the
    first generated image is saved and a textual summary is printed. The
    first batch uses ``args.seed``; later batches pass ``-1`` through
    ``utils.sanitize_seed``.

    Exits the process with status 1 when a required flag is missing.
    """
    if args.clear_all:
        clear_all()

    # `raise SystemExit(1)` instead of exit(): exit() is an interactive
    # helper installed by the `site` module, and it terminates with status 0,
    # which hides the failure from the caller.
    if args.img_path is None:
        print("Flag --img_path is required.")
        raise SystemExit(1)
    if args.mask_path is None:
        print("Flag --mask_path is required.")
        raise SystemExit(1)

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()

    # Schedulers are looked up for an inpainting model: the requested model
    # when its id contains "inpaint", otherwise the stock SD2 inpainting one.
    # NOTE(review): the pipeline below is still built from args.hf_model_id
    # rather than model_id — confirm that this difference is intended.
    model_id = (
        args.hf_model_id
        if "inpaint" in args.hf_model_id
        else "stabilityai/stable-diffusion-2-inpainting"
    )
    schedulers = get_schedulers(model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    image = Image.open(args.img_path)
    mask_image = Image.open(args.mask_path)

    inpaint_obj = InpaintPipeline.from_pretrained(
        scheduler=scheduler_obj,
        import_mlir=args.import_mlir,
        model_id=args.hf_model_id,
        ckpt_loc=args.ckpt_loc,
        custom_vae=args.custom_vae,
        precision=args.precision,
        max_length=args.max_length,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        use_base_vae=args.use_base_vae,
        use_tuned=args.use_tuned,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        debug=args.import_debug if args.import_mlir else False,
        use_lora=args.use_lora,
        ondemand=args.ondemand,
    )

    for current_batch in range(args.batch_count):
        # Only the first batch honours the user-supplied seed.
        if current_batch > 0:
            seed = -1
        seed = utils.sanitize_seed(seed)

        start_time = time.time()
        generated_imgs = inpaint_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            image,
            mask_image,
            args.batch_size,
            args.height,
            args.width,
            args.inpaint_full_res,
            args.inpaint_full_res_padding,
            args.steps,
            args.guidance_scale,
            seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
        )
        total_time = time.time() - start_time

        # Human-readable summary of this batch.
        text_output = f"prompt={args.prompts}"
        text_output += f"\nnegative prompt={args.negative_prompts}"
        text_output += (
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
        )
        text_output += f"\nscheduler={args.scheduler}, device={args.device}"
        text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        text_output += inpaint_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"

        # Only the first generated image of the batch is written out.
        save_output_img(generated_imgs[0], seed)
        print(text_output)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,19 @@
from apps.stable_diffusion.src import args
from apps.stable_diffusion.scripts import (
img2img,
txt2img,
# inpaint,
# outpaint,
)
if __name__ == "__main__":
    # Entry points keyed by the value of --app. The inpaint and outpaint
    # scripts are currently disabled, matching the commented-out imports.
    entry_points = {
        "txt2img": txt2img.main,
        "img2img": img2img.main,
        # "inpaint": inpaint.main,
        # "outpaint": outpaint.main,
    }
    run_app = entry_points.get(args.app)
    if run_app is None:
        print(f"args.app value is {args.app} but this isn't supported")
    else:
        run_app()

View File

@@ -0,0 +1,119 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
    """Run outpainting ``args.batch_count`` times from the parsed CLI ``args``.

    Requires ``args.img_path``. For every batch the first generated image is
    saved together with the outpainting settings (pixels, mask blur, selected
    directions, noise_q, color variation) and a textual summary is printed.
    The first batch uses ``args.seed``; later batches pass ``-1`` through
    ``utils.sanitize_seed``.

    Exits the process with status 1 when ``--img_path`` is missing.
    """
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        # `raise SystemExit(1)` instead of exit(): exit() is an interactive
        # helper installed by the `site` module, and it terminates with
        # status 0, which hides the failure from the caller.
        raise SystemExit(1)

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()

    # Schedulers are looked up for an inpainting model: the requested model
    # when its id contains "inpaint", otherwise the stock SD2 inpainting one.
    # NOTE(review): the pipeline below is still built from args.hf_model_id
    # rather than model_id — confirm that this difference is intended.
    model_id = (
        args.hf_model_id
        if "inpaint" in args.hf_model_id
        else "stabilityai/stable-diffusion-2-inpainting"
    )
    schedulers = get_schedulers(model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    image = Image.open(args.img_path)

    outpaint_obj = OutpaintPipeline.from_pretrained(
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
        use_lora=args.use_lora,
        ondemand=args.ondemand,
    )

    for current_batch in range(args.batch_count):
        # Only the first batch honours the user-supplied seed.
        if current_batch > 0:
            seed = -1
        seed = utils.sanitize_seed(seed)

        start_time = time.time()
        generated_imgs = outpaint_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            image,
            args.pixels,
            args.mask_blur,
            args.left,
            args.right,
            args.top,
            args.bottom,
            args.noise_q,
            args.color_variation,
            args.batch_size,
            args.height,
            args.width,
            args.steps,
            args.guidance_scale,
            seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
        )
        total_time = time.time() - start_time

        # Human-readable summary of this batch.
        text_output = f"prompt={args.prompts}"
        text_output += f"\nnegative prompt={args.negative_prompts}"
        text_output += (
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
        )
        text_output += f"\nscheduler={args.scheduler}, device={args.device}"
        text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        text_output += outpaint_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"

        # save this information as metadata of output generated image.
        # Flag -> label recorded in the metadata (top is "up", bottom "down").
        direction_flags = (
            (args.left, "left"),
            (args.right, "right"),
            (args.top, "up"),
            (args.bottom, "down"),
        )
        directions = [label for enabled, label in direction_flags if enabled]
        extra_info = {
            "PIXELS": args.pixels,
            "MASK_BLUR": args.mask_blur,
            "DIRECTIONS": directions,
            "NOISE_Q": args.noise_q,
            "COLOR_VARIATION": args.color_variation,
        }
        # Only the first generated image of the batch is written out.
        save_output_img(generated_imgs[0], seed, extra_info)
        print(text_output)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,240 @@
import logging
import os
from models.stable_diffusion.main import stable_diff_inf
from models.stable_diffusion.utils import get_available_devices
from dotenv import load_dotenv
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram import BotCommand
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
from io import BytesIO
import random
# Module-level logger used by every handler in this bot.
log = logging.getLogger("TG.Bot")
logging.basicConfig()
log.warning("Start")
# Presumably picks up TG_TOKEN from a local .env file — confirm deployment.
load_dotenv()
os.environ["AMD_ENABLE_LLPC"] = "0"
# Telegram bot token; os.getenv() returns None when TG_TOKEN is not set.
TG_TOKEN = os.getenv("TG_TOKEN")
# Mutable generation settings. The command and callback handlers below
# rebind these globals, so they are shared by every chat the bot serves.
SELECTED_MODEL = "stablediffusion"
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
STEPS = 30
NEGATIVE_PROMPT = (
    "Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
    " limbs,Gross proportions,Missing arms,Mutated hands,Long"
    " neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
    " anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
    " fingers,blurry, lowres, text, error, cropped, worst quality, low"
    " quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
    " poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
    " face, malformed limbs, missing arms, missing legs, extra arms, extra"
    " legs, fused fingers, too many fingers"
)
GUIDANCE_SCALE = 6
# Queried once at import time; generate_image() uses the first entry.
available_devices = get_available_devices()
# Model variants offered by /select_model. The names double as callback
# payloads that button() matches against.
models_list = [
    "stablediffusion",
    "anythingv3",
    "analogdiffusion",
    "openjourney",
    "dreamlike",
]
# Scheduler names offered by /select_scheduler (also used as callback
# payloads in button()).
sheds_list = [
    "DDIM",
    "PNDM",
    "LMSDiscrete",
    "DPMSolverMultistep",
    "EulerDiscrete",
    "EulerAncestralDiscrete",
    "SharkEulerDiscrete",
]
def image_to_bytes(image):
    """Serialize ``image`` as JPEG into an in-memory buffer.

    Returns a ``BytesIO`` named ``image.jpeg`` whose position is rewound to
    the start, ready to be read by the caller.
    """
    buffer = BytesIO()
    buffer.name = "image.jpeg"
    image.save(buffer, "JPEG")
    # Rewind so the consumer reads the encoded image from the first byte.
    buffer.seek(0)
    return buffer
def get_try_again_markup():
    """Build the inline keyboard holding a single "Try again" button.

    The button carries the callback payload ``"TRYAGAIN"``.
    """
    retry_button = InlineKeyboardButton("Try again", callback_data="TRYAGAIN")
    return InlineKeyboardMarkup([[retry_button]])
def generate_image(prompt):
    """Generate one image for ``prompt`` with the current global settings.

    A random seed in ``[1, 10000]`` is drawn for every call and the first
    available device is used. Returns ``(image, seed)``; the second value
    returned by ``stable_diff_inf`` is discarded.
    """
    chosen_seed = random.randint(1, 10000)
    log.warning(SELECTED_MODEL)
    log.warning(STEPS)
    inference_options = dict(
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        steps=STEPS,
        guidance_scale=GUIDANCE_SCALE,
        seed=chosen_seed,
        scheduler_key=SELECTED_SCHEDULER,
        variant=SELECTED_MODEL,
        device_key=available_devices[0],
    )
    picture, _details = stable_diff_inf(**inference_options)
    return picture, chosen_seed
async def generate_and_send_photo(
    update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
    """Generate an image from the incoming text message and send it back.

    Posts a temporary "Generating image..." reply, removes it once the image
    is ready, then sends the photo to the user with the prompt and seed in
    the caption and a "Try again" button attached.
    """
    message = update.message
    placeholder = await message.reply_text(
        "Generating image...", reply_to_message_id=message.message_id
    )
    picture, used_seed = generate_image(prompt=message.text)

    # Remove the progress message before delivering the result.
    await context.bot.delete_message(
        chat_id=placeholder.chat_id, message_id=placeholder.message_id
    )
    await context.bot.send_photo(
        update.effective_user.id,
        image_to_bytes(picture),
        caption=f'"{update.message.text}" (Seed: {used_seed})',
        reply_markup=get_try_again_markup(),
        reply_to_message_id=update.message.message_id,
    )
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle every inline-keyboard callback of the bot.

    * payload in ``models_list``: select that model and confirm;
    * payload in ``sheds_list``: select that scheduler and confirm;
    * payload ``"TRYAGAIN"``: regenerate an image for the prompt of the
      message the photo replied to, and send it with a new seed.

    Any other payload, or a "Try again" whose original prompt message is no
    longer available, is acknowledged and otherwise ignored. Previously such
    callbacks reached the send path with no prompt or image to send.
    """
    global SELECTED_MODEL, SELECTED_SCHEDULER

    query = update.callback_query
    if query.data in models_list:
        SELECTED_MODEL = query.data
        await query.answer()
        await query.edit_message_text(text=f"Selected model: {query.data}")
        return
    if query.data in sheds_list:
        SELECTED_SCHEDULER = query.data
        await query.answer()
        await query.edit_message_text(text=f"Selected scheduler: {query.data}")
        return

    replied_message = query.message.reply_to_message
    # Always acknowledge the callback, even when nothing will be generated.
    await query.answer()
    if query.data != "TRYAGAIN" or replied_message is None:
        return

    progress_msg = await query.message.reply_text(
        "Generating image...", reply_to_message_id=replied_message.message_id
    )
    prompt = replied_message.text
    im, seed = generate_image(prompt)

    # Remove the progress message before delivering the result.
    await context.bot.delete_message(
        chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
    )
    await context.bot.send_photo(
        update.effective_user.id,
        image_to_bytes(im),
        caption=f'"{prompt}" (Seed: {seed})',
        reply_markup=get_try_again_markup(),
        reply_to_message_id=replied_message.message_id,
    )
async def select_model_handler(update, context):
    """Reply with an inline keyboard listing every entry of ``models_list``.

    Each model gets its own row, and the model name is used both as the
    button label and as the callback payload.
    """
    rows = [
        [InlineKeyboardButton(text=model_name, callback_data=model_name)]
        for model_name in models_list
    ]
    await update.message.reply_text(
        text="Select model", reply_markup=InlineKeyboardMarkup(rows)
    )
async def select_scheduler_handler(update, context):
    """Reply with an inline keyboard listing every entry of ``sheds_list``.

    Each scheduler gets its own row, and the scheduler name is used both as
    the button label and as the callback payload.
    """
    rows = [
        [InlineKeyboardButton(text=scheduler_name, callback_data=scheduler_name)]
        for scheduler_name in sheds_list
    ]
    await update.message.reply_text(
        text="Select schedule", reply_markup=InlineKeyboardMarkup(rows)
    )
async def set_steps_handler(update, context):
    """Handle ``/set_steps <int>``: update the global ``STEPS``.

    Replies with the accepted argument on success, or with a usage hint when
    the argument is missing or is not an integer (``STEPS`` is then left
    unchanged).
    """
    global STEPS

    command_text = update.message.text
    log.warning(command_text)
    try:
        # Everything after the command name is the argument; a missing
        # argument raises IndexError and falls through to the usage hint.
        reply = command_text.split("/set_steps ")[1]
        STEPS = int(reply)
    except Exception:
        reply = (
            "Invalid parameter for command. Correct command looks like\n"
            " /set_steps 30"
        )
    await update.message.reply_text(reply)
async def set_negative_prompt_handler(update, context):
    """Handle ``/set_negative_prompt <text>``: update ``NEGATIVE_PROMPT``.

    Replies with the accepted text on success, or with a usage hint when no
    text follows the command (``NEGATIVE_PROMPT`` is then left unchanged).
    """
    global NEGATIVE_PROMPT

    command_text = update.message.text
    log.warning(command_text)
    try:
        # A missing argument raises IndexError and yields the usage hint.
        reply = command_text.split("/set_negative_prompt ")[1]
        NEGATIVE_PROMPT = reply
    except Exception:
        reply = (
            "Invalid parameter for command. Correct command looks like\n"
            " /set_negative_prompt ugly, bad art, mutated"
        )
    await update.message.reply_text(reply)
async def set_guidance_scale_handler(update, context):
    """Handle ``/set_guidance_scale <number>``: update ``GUIDANCE_SCALE``.

    Integer arguments are stored as ``int`` (as before); fractional values
    such as ``7.5`` are now accepted and stored as ``float``. Non-numeric,
    non-finite (``nan``/``inf``) or missing arguments leave the setting
    unchanged and produce a usage hint.

    Replies with the accepted argument on success.
    """
    import math

    global GUIDANCE_SCALE

    input_mex = update.message.text
    log.warning(input_mex)
    try:
        input_args = input_mex.split("/set_guidance_scale ")[1]
        try:
            # Keep whole numbers as int so existing behaviour is unchanged.
            scale = int(input_args)
        except ValueError:
            scale = float(input_args)
            if not math.isfinite(scale):
                raise ValueError("guidance scale must be a finite number")
        GUIDANCE_SCALE = scale
    except Exception:
        input_args = (
            "Invalid parameter for command. Correct command looks like\n"
            " /set_guidance_scale 7"
        )
    await update.message.reply_text(input_args)
async def setup_bot_commands(application: Application) -> None:
    """Register the bot's command list (name and description) with Telegram.

    Passed to ``post_init`` when the application is built.
    """
    command_descriptions = (
        ("select_model", "to select model"),
        ("select_scheduler", "to select scheduler"),
        ("set_steps", "to set steps"),
        ("set_guidance_scale", "to set guidance scale"),
        ("set_negative_prompt", "to set negative prompt"),
    )
    commands = [
        BotCommand(name, description)
        for name, description in command_descriptions
    ]
    await application.bot.set_my_commands(commands)
# Build the application; setup_bot_commands runs as the post-init hook.
app = ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()

# Slash commands, registered in the same order as before.
for _command_name, _command_callback in (
    ("select_model", select_model_handler),
    ("select_scheduler", select_scheduler_handler),
    ("set_steps", set_steps_handler),
    ("set_guidance_scale", set_guidance_scale_handler),
    ("set_negative_prompt", set_negative_prompt_handler),
):
    app.add_handler(CommandHandler(_command_name, _command_callback))

# Any plain (non-command) text message is treated as a prompt.
plain_text_filter = filters.TEXT & ~filters.COMMAND
app.add_handler(MessageHandler(plain_text_filter, generate_and_send_photo))

# Inline-keyboard presses (model/scheduler selection and "Try again").
app.add_handler(CallbackQueryHandler(button))

log.warning("Start bot")
app.run_polling()

View File

@@ -0,0 +1,692 @@
# Install the required libs
# pip install -U git+https://github.com/huggingface/diffusers.git
# pip install accelerate transformers ftfy
# HuggingFace Token (never commit a real token; read it from the environment)
# YOUR_TOKEN = "<your Hugging Face access token>"
# Import required libraries
import itertools
import math
import os
from typing import List
import random
import torch_mlir
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
import logging
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from shark.shark_inference import SharkInference
torch._dynamo.config.verbose = True
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
)
from io import BytesIO
from dataclasses import dataclass
from apps.stable_diffusion.src import (
args,
get_schedulers,
set_init_device_flags,
clear_all,
)
from apps.stable_diffusion.src.utils import update_lora_weight
# Setup the dataset
class LoraDataset(Dataset):
    """Dataset pairing every image in a directory with one fixed prompt.

    Each item is a dict with:

    * ``"input_ids"``: the tokenized ``prompt`` (first row of the tokenizer
      output, padded/truncated to ``tokenizer.model_max_length``);
    * ``"pixel_values"``: the image converted to RGB, optionally
      center-cropped to a square, resized to ``size`` x ``size``, scaled
      from ``[0, 255]`` to ``[-1, 1]`` and permuted to channel-first order.

    Args:
        data_root: Directory whose entries are all treated as image files.
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        size: Edge length of the square output image.
        repeats: For ``set == "train"`` the reported length is the number of
            images multiplied by this factor.
        interpolation: One of ``"linear"``, ``"bilinear"``, ``"bicubic"``,
            ``"lanczos"``; any other value raises ``KeyError``.
        set: Split name; only ``"train"`` applies ``repeats``.
        prompt: Text tokenized for every example.
        center_crop: Crop to the largest centered square before resizing.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        size=512,
        repeats=100,
        interpolation="bicubic",
        set="train",
        prompt="myloraprompt",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop
        self.prompt = prompt

        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
        ]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        # "linear" maps to BILINEAR: Image.LINEAR was only an alias of it and
        # was removed in Pillow 10. Because this table is built eagerly, the
        # old reference broke construction for every interpolation choice.
        self.interpolation = {
            "linear": PIL.Image.BILINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]

    def __len__(self):
        """Return the number of examples (image count, times repeats for train)."""
        return self._length

    def __getitem__(self, i):
        """Return the example for index ``i`` (wrapping over the images)."""
        example = {}
        # Indices beyond the image count wrap around (used by `repeats`).
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        example["input_ids"] = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            # Largest centered square of the image.
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[
                (h - crop) // 2 : (h + crop) // 2,
                (w - crop) // 2 : (w + crop) // 2,
            ]

        image = Image.fromarray(img)
        image = image.resize(
            (self.size, self.size), resample=self.interpolation
        )

        image = np.array(image).astype(np.uint8)
        # Map [0, 255] to [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example
def torch_device(device):
    """Translate a device selector string into ``type`` or ``type:index`` form.

    When the string contains ``"=>"`` the segment after the first ``"=>"`` is
    used, otherwise the whole string; surrounding whitespace is stripped. A
    ``"type://index"`` selector becomes ``"type:index"``, and the type
    ``"metal"`` is reported as ``"vulkan"``.
    """
    segments = device.split("=>")
    selector = (segments[0] if len(segments) == 1 else segments[1]).strip()

    kind, *rest = selector.split("://")
    if kind == "metal":
        kind = "vulkan"

    return f"{kind}:{rest[0]}" if rest else kind
########## Setting up the model ##########
def lora_train(
    prompt: str,
    height: int,
    width: int,
    steps: int,
    guidance_scale: float,
    seed: int,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    custom_model: str,
    hf_model_id: str,
    precision: str,
    device: str,
    max_length: int,
    training_images_dir: str,
    lora_save_dir: str,
    use_lora: str,
):
    """Train LoRA attention processors for a UNet and save them.

    The function copies its arguments onto the global ``args`` namespace,
    loads the text encoder, VAE and UNet for ``args.hf_model_id``, freezes
    their parameters, attaches LoRA attention processors (or loads existing
    ones when ``use_lora`` is non-empty), trains them on the images found in
    ``training_images_dir`` captioned with ``prompt``, and finally writes the
    weights with ``unet.save_attn_procs(lora_save_dir)``.

    Args:
        prompt: Caption used for every training image.
        height, width: Stored on ``args``.
        steps: Maximum number of training steps.
        guidance_scale, batch_count, scheduler: Accepted for interface
            compatibility; not used by this function.
        seed: Seed passed to ``torch.manual_seed``.
        batch_size: Training batch size (also scales the learning rate).
        custom_model: ``"None"``, a path containing ``.ckpt`` or
            ``.safetensors`` (stored as ``args.ckpt_loc``), or a model id.
        hf_model_id: Model id used when ``custom_model`` is ``"None"``.
        precision, max_length: Stored on ``args``.
        device: Device selector, normalised with ``torch_device``.
        training_images_dir: Directory holding the training images.
        lora_save_dir: Destination passed to ``unet.save_attn_procs``.
        use_lora: Existing LoRA weights to start from, or ``""``.

    Returns:
        ``(None, message)`` when neither a custom model nor a model id was
        given; otherwise ``None`` once training has finished.
    """
    # Imported for parity with the original module; unused in this function.
    from apps.stable_diffusion.web.ui.utils import (
        get_custom_model_pathfile,
        Config,
    )
    import apps.stable_diffusion.web.utils.global_obj as global_obj

    print(
        "Note LoRA training is not compatible with the latest torch-mlir branch"
    )
    print(
        "To run LoRA training you'll need this to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
    )
    torch.manual_seed(seed)
    args.prompts = [prompt]
    args.steps = steps

    # set ckpt_loc and hf_model_id.
    checkpoint_suffixes = (
        ".ckpt",
        ".safetensors",
    )  # the tuple of file types
    args.ckpt_loc = ""
    args.hf_model_id = ""
    if custom_model == "None":
        if not hf_model_id:
            return (
                None,
                "Please provide either custom model or huggingface model ID, both must not be empty",
            )
        args.hf_model_id = hf_model_id
    elif any(suffix in custom_model for suffix in checkpoint_suffixes):
        args.ckpt_loc = custom_model
    else:
        args.hf_model_id = custom_model

    args.training_images_dir = training_images_dir
    args.lora_save_dir = lora_save_dir

    args.precision = precision
    args.batch_size = batch_size
    args.max_length = max_length
    args.height = height
    args.width = width
    args.device = torch_device(device)
    args.use_lora = use_lora

    # Load the Stable Diffusion model
    text_encoder = CLIPTextModel.from_pretrained(
        args.hf_model_id, subfolder="text_encoder"
    )
    vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(
        args.hf_model_id, subfolder="unet"
    )

    def freeze_params(params):
        """Disable gradient tracking for every parameter in ``params``."""
        for param in params:
            param.requires_grad = False

    # Freeze everything but LoRA
    freeze_params(vae.parameters())
    freeze_params(unet.parameters())
    freeze_params(text_encoder.parameters())

    # Move vae and unet to device
    vae.to(args.device)
    unet.to(args.device)
    text_encoder.to(args.device)

    if use_lora != "":
        update_lora_weight(unet, args.use_lora, "unet")
    else:
        lora_attn_procs = {}
        for name in unet.attn_processors.keys():
            # attn1 processors get no cross-attention dimension.
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            # The hidden size comes from the block the processor sits in;
            # the block index is the character right after the prefix.
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[
                    block_id
                ]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRACrossAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
            )

        unet.set_attn_processor(lora_attn_procs)
    lora_layers = AttnProcsLayers(unet.attn_processors)

    class VaeModel(torch.nn.Module):
        """Wrapper returning the first element of ``vae.encode``."""

        def __init__(self):
            super().__init__()
            self.vae = vae

        def forward(self, input):
            x = self.vae.encode(input, return_dict=False)[0]
            return x

    class UnetModel(torch.nn.Module):
        """Wrapper returning the first element of the UNet forward pass."""

        def __init__(self):
            super().__init__()
            self.unet = unet

        def forward(self, x, y, z):
            return self.unet.forward(x, y, z, return_dict=False)[0]

    shark_vae = VaeModel()
    shark_unet = UnetModel()

    ####### Creating our training data ########

    tokenizer = CLIPTokenizer.from_pretrained(
        args.hf_model_id,
        subfolder="tokenizer",
    )

    # Let's create the Dataset and Dataloader
    train_dataset = LoraDataset(
        data_root=args.training_images_dir,
        tokenizer=tokenizer,
        size=vae.sample_size,
        prompt=args.prompts[0],
        repeats=100,
        center_crop=False,
        set="train",
    )

    def create_dataloader(train_batch_size=1):
        """Return a shuffling DataLoader over ``train_dataset``."""
        return torch.utils.data.DataLoader(
            train_dataset, batch_size=train_batch_size, shuffle=True
        )

    # Create noise_scheduler for training
    noise_scheduler = DDPMScheduler.from_config(
        args.hf_model_id, subfolder="scheduler"
    )

    ######## Training ###########

    # Define hyperparameters for our training. If you are not happy with your results,
    # you can tune the `learning_rate` and the `max_train_steps`

    # Setting up all training args
    hyperparameters = {
        "learning_rate": 5e-04,
        "scale_lr": True,
        "max_train_steps": steps,
        "train_batch_size": batch_size,
        "gradient_accumulation_steps": 1,
        "gradient_checkpointing": True,
        "mixed_precision": "fp16",
        "seed": 42,
        "output_dir": "sd-concept-output",
    }

    # creating output directory
    # A single makedirs(exist_ok=True) replaces the former
    # `while not os.path.exists(...)` retry loop, which never terminated
    # when the directory could not be created (e.g. read-only cwd).
    out_dir = os.path.join(os.getcwd(), hyperparameters["output_dir"])
    try:
        os.makedirs(out_dir, exist_ok=True)
    except OSError:
        print("Output directory not created")

    ###### Torch-MLIR Compilation ######

    def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
        """Drop ``None`` entries from the graph output; return their indexes."""
        removed_indexes = []
        for node in fx_g.graph.nodes:
            if node.op == "output":
                assert (
                    len(node.args) == 1
                ), "Output node must have a single argument"
                node_arg = node.args[0]
                if isinstance(node_arg, (list, tuple)):
                    node_arg = list(node_arg)
                    node_args_len = len(node_arg)
                    # Walk backwards so popping does not shift later indexes.
                    for i in range(node_args_len):
                        curr_index = node_args_len - (i + 1)
                        if node_arg[curr_index] is None:
                            removed_indexes.append(curr_index)
                            node_arg.pop(curr_index)
                    node.args = (tuple(node_arg),)
                    break

        if len(removed_indexes) > 0:
            fx_g.graph.lint()
            fx_g.graph.eliminate_dead_code()
            fx_g.recompile()
        removed_indexes.sort()
        return removed_indexes

    def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
        """
        Replace tuple with tuple element in functions that return one-element tuples.
        Returns true if an unwrapping took place, and false otherwise.
        """
        unwrapped_tuple = False
        for node in fx_g.graph.nodes:
            if node.op == "output":
                assert (
                    len(node.args) == 1
                ), "Output node must have a single argument"
                node_arg = node.args[0]
                if isinstance(node_arg, tuple):
                    if len(node_arg) == 1:
                        node.args = (node_arg[0],)
                        unwrapped_tuple = True
                        break

        if unwrapped_tuple:
            fx_g.graph.lint()
            fx_g.recompile()
        return unwrapped_tuple

    def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
        """Return True when the graph's output is an empty tuple."""
        for node in fx_g.graph.nodes:
            if node.op == "output":
                assert (
                    len(node.args) == 1
                ), "Output node must have a single argument"
                node_arg = node.args[0]
                if isinstance(node_arg, tuple):
                    return len(node_arg) == 0
        return False

    def transform_fx(fx_g):
        """Follow every ``aten.empty`` call with ``aten.zero_`` on its result."""
        for node in fx_g.graph.nodes:
            if node.op == "call_function":
                # aten.empty should be filled with zeros.
                if node.target in [torch.ops.aten.empty]:
                    with fx_g.graph.inserting_after(node):
                        new_node = fx_g.graph.call_function(
                            torch.ops.aten.zero_,
                            args=(node,),
                        )
                        node.append(new_node)
                        node.replace_all_uses_with(new_node)
                        # replace_all_uses_with also rewired new_node's own
                        # argument; point it back at the empty node.
                        new_node.args = (node,)
        fx_g.graph.lint()

    @make_simple_dynamo_backend
    def refbackend_torchdynamo_backend(
        fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ):
        """Dynamo backend: compile ``fx_graph`` via torch-mlir and SHARK."""
        # handling usage of empty tensor without initializing
        transform_fx(fx_graph)
        fx_graph.recompile()
        if _returns_nothing(fx_graph):
            return fx_graph
        removed_none_indexes = _remove_nones(fx_graph)
        was_unwrapped = _unwrap_single_tuple_return(fx_graph)

        mlir_module = torch_mlir.compile(
            fx_graph, example_inputs, output_type="linalg-on-tensors"
        )

        bytecode_stream = BytesIO()
        mlir_module.operation.write_bytecode(bytecode_stream)
        bytecode = bytecode_stream.getvalue()

        shark_module = SharkInference(
            mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
        )
        shark_module.compile()

        def compiled_callable(*inputs):
            """Run the compiled module and restore the original output shape."""
            inputs = [x.numpy() for x in inputs]
            result = shark_module("forward", inputs)
            if was_unwrapped:
                result = [
                    result,
                ]
            if not isinstance(result, list):
                result = torch.from_numpy(result)
            else:
                result = tuple(torch.from_numpy(x) for x in result)
                result = list(result)
                # Re-insert the None outputs removed before compilation.
                for removed_index in removed_none_indexes:
                    result.insert(removed_index, None)
                result = tuple(result)
            return result

        return compiled_callable

    def predictions(torch_func, jit_func, batchA, batchB):
        """Call ``jit_func`` on the numpy views of both batches."""
        res = jit_func(batchA.numpy(), batchB.numpy())
        if res is not None:
            # prediction = torch.from_numpy(res)
            prediction = res
        else:
            prediction = None
        return prediction

    logger = logging.getLogger(__name__)

    train_batch_size = hyperparameters["train_batch_size"]
    gradient_accumulation_steps = hyperparameters[
        "gradient_accumulation_steps"
    ]
    learning_rate = hyperparameters["learning_rate"]
    if hyperparameters["scale_lr"]:
        learning_rate = (
            learning_rate
            * gradient_accumulation_steps
            * train_batch_size
            # * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        lora_layers.parameters(),  # only optimize the LoRA layers
        lr=learning_rate,
    )

    # Training function
    def train_func(batch_pixel_values, batch_input_ids):
        """Run one optimisation step and return the loss."""
        # Convert images to latent space
        latents = shark_vae(batch_pixel_values).sample().detach()
        latents = latents * 0.18215

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            noise_scheduler.num_train_timesteps,
            (bsz,),
            device=latents.device,
        ).long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(batch_input_ids)[0]

        # Predict the noise residual
        noise_pred = shark_unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states,
        )

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(
                f"Unknown prediction type {noise_scheduler.config.prediction_type}"
            )

        loss = (
            F.mse_loss(noise_pred, target, reduction="none")
            .mean([1, 2, 3])
            .mean()
        )
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        return loss

    def training_function():
        """Iterate over the dataloader until ``max_train_steps`` is reached."""
        max_train_steps = hyperparameters["max_train_steps"]

        train_dataloader = create_dataloader(train_batch_size)

        # We need to recalculate our total training steps as the size of the training dataloader may have changed.
        num_update_steps_per_epoch = math.ceil(
            len(train_dataloader) / gradient_accumulation_steps
        )
        num_train_epochs = math.ceil(
            max_train_steps / num_update_steps_per_epoch
        )

        # Train!
        total_batch_size = (
            train_batch_size
            * gradient_accumulation_steps
            # train_batch_size * accelerator.num_processes * gradient_accumulation_steps
        )

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(train_dataset)}")
        logger.info(
            f"  Instantaneous batch size per device = {train_batch_size}"
        )
        logger.info(
            f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
        )
        logger.info(
            f"  Gradient Accumulation steps = {gradient_accumulation_steps}"
        )
        logger.info(f"  Total optimization steps = {max_train_steps}")
        # Only show the progress bar once on each machine.
        progress_bar = tqdm(
            # range(max_train_steps), disable=not accelerator.is_local_main_process
            range(max_train_steps)
        )
        progress_bar.set_description("Steps")
        global_step = 0

        # Loop-invariant: wrap the training step once instead of re-wrapping
        # it on every batch.
        dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(
            train_func
        )
        lam_func = lambda x, y: dynamo_callable(
            torch.from_numpy(x), torch.from_numpy(y)
        )

        for epoch in range(num_train_epochs):
            unet.train()
            for step, batch in enumerate(train_dataloader):
                loss = predictions(
                    train_func,
                    lam_func,
                    batch["pixel_values"],
                    batch["input_ids"],
                )

                progress_bar.update(1)
                global_step += 1

                logs = {"loss": loss.detach().item()}
                progress_bar.set_postfix(**logs)

                if global_step >= max_train_steps:
                    break

    training_function()

    # Save the lora weights
    unet.save_attn_procs(args.lora_save_dir)

    for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
        if param.grad is not None:
            del param.grad  # free some memory
    # One cache flush after all gradients are released is sufficient.
    torch.cuda.empty_cache()
if __name__ == "__main__":
    if args.clear_all:
        clear_all()

    # NOTE(review): dtype, cpu_scheduling, scheduler_obj and seed are not
    # passed to lora_train(); the calls are kept because
    # set_init_device_flags() and the scheduler lookup may be relied on for
    # their side effects / validation — confirm before removing.
    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed

    if len(args.prompts) != 1:
        print("Need exactly one prompt for the LoRA word")
        # Abort instead of continuing: with no prompt args.prompts[0] raises
        # IndexError, and with several the extras were silently dropped.
        raise SystemExit(1)

    lora_train(
        args.prompts[0],
        args.height,
        args.width,
        args.training_steps,
        args.guidance_scale,
        args.seed,
        args.batch_count,
        args.batch_size,
        args.scheduler,
        "None",
        args.hf_model_id,
        args.precision,
        args.device,
        args.max_length,
        args.training_images_dir,
        args.lora_save_dir,
        args.use_lora,
    )

View File

@@ -0,0 +1,126 @@
import os
from pathlib import Path
from shark_tuner.codegen_tuner import SharkCodegenTuner
from shark_tuner.iree_utils import (
dump_dispatches,
create_context,
export_module_to_mlir_file,
)
from shark_tuner.model_annotation import model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.utils import set_init_device_flags
from apps.stable_diffusion.src.utils.sd_annotation import (
get_device_args,
load_winograd_configs,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
    """Build the Stable Diffusion wrapper in MLIR mode and fetch one sub-model.

    The wrapper is constructed from the command-line ``args`` with tuning
    disabled and ``return_mlir=True``; the sub-model to fetch is chosen by
    ``args.annotation_model``.

    Returns:
        A ``(mlir_module, model_name)`` tuple for the selected sub-model.

    Raises:
        ValueError: if ``args.annotation_model`` is neither "unet" nor "vae".
    """
    sd_model = SharkifyStableDiffusionModel(
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        max_len=args.max_length,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        use_base_vae=args.use_base_vae,
        use_tuned=False,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        return_mlir=True,
    )

    target = args.annotation_model
    if target == "unet":
        build = sd_model.unet
    elif target == "vae":
        build = sd_model.vae
    else:
        raise ValueError(
            f"{args.annotation_model} is not supported for tuning."
        )

    # Build first, then read the name, in the same order as before.
    mlir_module = build()
    model_name = sd_model.model_name[target]
    return mlir_module, model_name
def main():
    """Annotate the selected SD sub-model, dump its dispatches and tune each one.

    Steps:
      1. Force ``args.use_tuned`` off and load the MLIR module.
      2. Derive the device spec (and the Vulkan target triple, if applicable)
         from the last device-specific argument.
      3. Annotate the module (with winograd configs on Vulkan), export it to
         ``~/tmp/temp.mlir`` and dump its dispatches.
      4. Run ``SharkCodegenTuner`` on every ``.mlir`` file in the dispatch dir.
    """
    args.use_tuned = False
    set_init_device_flags()
    mlir_module, model_name = load_mlir_module()

    # Get device and device specific arguments
    device, device_spec_args = get_device_args()
    device_spec = ""
    vulkan_target_triple = ""
    if device_spec_args:
        # Take the value after the last "=" of the last argument.
        device_spec = device_spec_args[-1].split("=")[-1].strip()
        if device == "vulkan":
            vulkan_target_triple = device_spec
            device_spec = device_spec.split("-")[0]

    # Add winograd annotation for vulkan device. The condition is computed
    # once and reused for both the flag and the config path.
    use_winograd = device == "vulkan" and args.annotation_model in [
        "unet",
        "vae",
    ]
    winograd_config = load_winograd_configs() if use_winograd else ""

    # Dump model dispatches
    generates_dir = Path.home() / "tmp"
    # Idempotent creation: no window between an existence check and makedirs.
    generates_dir.mkdir(parents=True, exist_ok=True)
    dump_mlir = generates_dir / "temp.mlir"
    dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches"

    with create_context() as ctx:
        input_module = model_annotation(
            ctx,
            input_contents=mlir_module,
            config_path=winograd_config,
            search_op="conv",
            winograd=use_winograd,
        )
        # Export while the context that produced the module is still active.
        export_module_to_mlir_file(input_module, dump_mlir)

    dump_dispatches(
        dump_mlir,
        device,
        dispatch_dir,
        vulkan_target_triple,
        use_winograd=use_winograd,
    )

    # Tune each dispatch
    dtype = "f16" if args.precision == "fp16" else "f32"
    config_filename = f"{model_name}_{device_spec}_configs.json"

    for f_path in os.listdir(dispatch_dir):
        if not f_path.endswith(".mlir"):
            continue

        model_dir = os.path.join(dispatch_dir, f_path)
        tuner = SharkCodegenTuner(
            model_dir,
            device,
            "random",
            args.num_iters,
            args.tuned_config_dir,
            dtype,
            args.search_op,
            batch_size=1,
            config_filename=config_filename,
            use_dispatch=True,
            vulkan_target_triple=vulkan_target_triple,
        )
        tuner.tune()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,85 @@
import torch
import transformers
import time
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
    """Generate ``args.batch_count`` batches of images from the CLI prompts.

    Builds a ``Text2ImagePipeline`` from the command-line ``args`` once, then
    for every batch sanitises the seed, generates images, saves the first one
    and prints a textual summary of the run.
    """
    if args.clear_all:
        clear_all()

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    scheduler_obj = get_schedulers(args.hf_model_id)[args.scheduler]
    seed = args.seed

    txt2img_obj = Text2ImagePipeline.from_pretrained(
        scheduler=scheduler_obj,
        import_mlir=args.import_mlir,
        model_id=args.hf_model_id,
        ckpt_loc=args.ckpt_loc,
        precision=args.precision,
        max_length=args.max_length,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        use_base_vae=args.use_base_vae,
        use_tuned=args.use_tuned,
        custom_vae=args.custom_vae,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        debug=args.import_debug if args.import_mlir else False,
        use_lora=args.use_lora,
        use_quantize=args.use_quantize,
        ondemand=args.ondemand,
    )

    for current_batch in range(args.batch_count):
        # Only the first batch uses the requested seed; later batches pass -1
        # to the sanitiser instead.
        seed = utils.sanitize_seed(seed if current_batch == 0 else -1)

        start_time = time.time()
        generated_imgs = txt2img_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            args.batch_size,
            args.height,
            args.width,
            args.steps,
            args.guidance_scale,
            seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
        )
        total_time = time.time() - start_time

        # TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
        report_parts = [
            f"prompt={args.prompts}",
            f"\nnegative prompt={args.negative_prompts}",
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}",
            f"\nscheduler={args.scheduler}, device={args.device}",
            f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}",
            f", batch size={args.batch_size}, max_length={args.max_length}",
            txt2img_obj.log,
            f"\nTotal image generation time: {total_time:.4f}sec",
        ]
        text_output = "".join(report_parts)

        save_output_img(generated_imgs[0], seed)
        print(text_output)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,91 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
if __name__ == "__main__":
    # Command-line entry point: upscale the image given by --img_path with the
    # UpscalerPipeline, save the first result and print a run summary.
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        # `exit()` is a helper injected by the `site` module (it may be absent
        # in frozen builds) and exits with status 0; raise SystemExit directly
        # and report failure to the caller.
        raise SystemExit(1)

    # When the models get uploaded, it should be default to False.
    args.import_mlir = True

    cpu_scheduling = not args.scheduler.startswith("Shark")
    dtype = torch.float32 if args.precision == "fp32" else torch.half
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)
    scheduler_obj = schedulers[args.scheduler]

    # PIL's Image.resize expects (width, height); the previous
    # (height, width) order transposed the target size for non-square inputs.
    image = (
        Image.open(args.img_path)
        .convert("RGB")
        .resize((args.width, args.height))
    )
    seed = utils.sanitize_seed(args.seed)

    # Adjust for height and width based on model
    upscaler_obj = UpscalerPipeline.from_pretrained(
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        use_lora=args.use_lora,
        ddpm_scheduler=schedulers["DDPM"],
        ondemand=args.ondemand,
    )

    start_time = time.time()
    generated_imgs = upscaler_obj.generate_images(
        args.prompts,
        args.negative_prompts,
        image,
        args.batch_size,
        args.height,
        args.width,
        args.steps,
        args.noise_level,
        args.guidance_scale,
        seed,
        args.max_length,
        dtype,
        args.use_base_vae,
        cpu_scheduling,
    )
    total_time = time.time() - start_time

    # Human-readable summary of the run parameters and timing.
    text_output = f"prompt={args.prompts}"
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
    text_output += f"\nscheduler={args.scheduler}, device={args.device}"
    text_output += f"\nsteps={args.steps}, noise_level={args.noise_level}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
    text_output += (
        f", batch size={args.batch_size}, max_length={args.max_length}"
    )
    text_output += upscaler_obj.log
    text_output += f"\nTotal image generation time: {total_time:.4f}sec"

    extra_info = {"NOISE LEVEL": args.noise_level}
    save_output_img(generated_imgs[0], seed, extra_info)
    print(text_output)

View File

@@ -0,0 +1,90 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# Data files and distribution metadata bundled next to the executable.
datas = []
datas += collect_data_files('torch')

# Distributions whose metadata is copied into the bundle.
for dist_name in (
    'torch',
    'tqdm',
    'regex',
    'requests',
    'packaging',
    'filelock',
    'numpy',
    'tokenizers',
    'importlib_metadata',
    'torch-mlir',
    'omegaconf',
    'safetensors',
):
    datas += copy_metadata(dist_name)

# Packages whose data files are collected.
# NOTE(review): 'opencv-python' and 'google-cloud-storage' look like
# distribution names rather than importable package names; confirm that
# collect_data_files actually finds anything for them.
for pkg_name in (
    'diffusers',
    'transformers',
    'pytorch_lightning',
    'opencv-python',
    'skimage',
    'gradio',
    'gradio_client',
    'iree',
    'google-cloud-storage',
    'shark',
    'tkinter',
    'webview',
    'sentencepiece',
):
    datas += collect_data_files(pkg_name)

# Project resources and web UI assets, as (source, destination dir) pairs.
datas += [
    ( 'src/utils/resources/prompts.json', 'resources' ),
    ( 'src/utils/resources/model_db.json', 'resources' ),
    ( 'src/utils/resources/opt_flags.json', 'resources' ),
    ( 'src/utils/resources/base_model.json', 'resources' ),
    ( 'web/ui/css/*', 'ui/css' ),
    ( 'web/ui/logos/*', 'logos' )
]

binaries = []

block_cipher = None

# Modules listed explicitly for the analysis; submodules whose dotted name
# contains "tests" are left out.
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
for hidden_pkg in ("skimage", "iree", "PIL"):
    hiddenimports += [x for x in collect_submodules(hidden_pkg) if "tests" not in x]

a = Analysis(
    ['web/index.py'],
    pathex=['.'],
    binaries=binaries,
    datas=datas,
    hiddenimports=hiddenimports,
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
)

pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

# Single-file console executable named "shark_sd".
exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.zipfiles,
    a.datas,
    [],
    name='shark_sd',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    upx_exclude=[],
    runtime_tmpdir=None,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)

View File

@@ -0,0 +1,84 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# Data files and distribution metadata bundled next to the executable.
datas = []
datas += collect_data_files('torch')

# Distributions whose metadata is copied into the bundle.
for dist_name in (
    'torch',
    'tqdm',
    'regex',
    'requests',
    'packaging',
    'filelock',
    'numpy',
    'tokenizers',
    'importlib_metadata',
    'torch-mlir',
    'omegaconf',
    'safetensors',
):
    datas += copy_metadata(dist_name)

# Packages whose data files are collected.
# NOTE(review): 'opencv-python' and 'google-cloud-storage' look like
# distribution names rather than importable package names; confirm that
# collect_data_files actually finds anything for them.
for pkg_name in (
    'diffusers',
    'transformers',
    'opencv-python',
    'pytorch_lightning',
    'skimage',
    'gradio',
    'gradio_client',
    'iree',
    'google-cloud-storage',
    'shark',
):
    datas += collect_data_files(pkg_name)

# Project resources, as (source, destination dir) pairs.
datas += [
    ( 'src/utils/resources/prompts.json', 'resources' ),
    ( 'src/utils/resources/model_db.json', 'resources' ),
    ( 'src/utils/resources/opt_flags.json', 'resources' ),
    ( 'src/utils/resources/base_model.json', 'resources' ),
]

binaries = []

block_cipher = None

# Modules listed explicitly for the analysis; submodules whose dotted name
# contains "tests" are left out.
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
for hidden_pkg in ("skimage", "iree"):
    hiddenimports += [x for x in collect_submodules(hidden_pkg) if "tests" not in x]

a = Analysis(
    ['scripts/main.py'],
    pathex=['.'],
    binaries=binaries,
    datas=datas,
    hiddenimports=hiddenimports,
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
)

pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

# Single-file console executable named "shark_sd_cli".
exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.zipfiles,
    a.datas,
    [],
    name='shark_sd_cli',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    upx_exclude=[],
    runtime_tmpdir=None,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.utils import (
args,
set_init_device_flags,
prompt_examples,
get_available_devices,
clear_all,
save_output_img,
resize_stencil,
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,
StencilPipeline,
UpscalerPipeline,
)
from apps.stable_diffusion.src.schedulers import get_schedulers

View File

@@ -0,0 +1,12 @@
from apps.stable_diffusion.src.models.model_wrappers import (
SharkifyStableDiffusionModel,
)
from apps.stable_diffusion.src.models.opt_params import (
get_vae_encode,
get_vae,
get_unet,
get_clip,
get_tokenizer,
get_params,
get_variant_version,
)

View File

@@ -0,0 +1,686 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from collections import defaultdict
from pathlib import Path
import torch
import safetensors.torch
import traceback
import subprocess
import sys
import os
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
base_models,
args,
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
get_stencil_model_id,
update_lora_weight,
)
# These shapes are parameter dependent.
def replace_shape_str(shape, max_len, width, height, batch_size):
    """Resolve the symbolic entries of a shape template to concrete values.

    Each entry of ``shape`` is handled as follows:

    * non-string entries are copied through unchanged;
    * ``"max_len"``, ``"height"``, ``"width"`` and ``"batch_size"`` are
      replaced by the corresponding argument;
    * ``"<k>*<symbol>"`` becomes ``k * symbol``;
    * ``"<symbol>/<k>"`` becomes ``ceil(symbol / k)``.

    Args:
        shape: sequence of ints and/or symbolic strings.
        max_len, width, height, batch_size: values substituted for the symbols.

    Returns:
        A new list with one resolved entry per input entry.

    Raises:
        ValueError: if a string entry cannot be resolved. Previously such
            entries were silently dropped, which produced a shape with fewer
            dimensions than the template.
    """
    import math

    symbols = {
        "max_len": max_len,
        "height": height,
        "width": width,
        "batch_size": batch_size,
    }
    # Lookup order for scaled expressions; batch_size/height/width keep the
    # precedence of the original substring checks.
    scaled_lookup_order = ("batch_size", "height", "width", "max_len")

    def _symbol_value(expr):
        # Return the value of the first known symbol mentioned in `expr`.
        for name in scaled_lookup_order:
            if name in expr:
                return symbols[name]
        raise ValueError(
            f"Unrecognized symbol in shape entry {expr!r} of shape {shape!r}"
        )

    new_shape = []
    for dim in shape:
        if not isinstance(dim, str):
            new_shape.append(dim)
        elif dim in symbols:
            new_shape.append(symbols[dim])
        elif "*" in dim:
            factor = int(dim.split("*")[0])
            new_shape.append(_symbol_value(dim) * factor)
        elif "/" in dim:
            divisor = int(dim.split("/")[1])
            new_shape.append(math.ceil(_symbol_value(dim) / divisor))
        else:
            raise ValueError(
                f"Unrecognized shape entry {dim!r} in shape {shape!r}"
            )
    return new_shape
def check_compilation(model, model_name):
    """Raise if ``model`` is falsy, i.e. compilation produced nothing usable.

    Args:
        model: the compilation result to check.
        model_name: human-readable name used in the error message.

    Raises:
        Exception: when ``model`` is falsy.
    """
    if model:
        return
    raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
custom_vae: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
low_cpu_mem_usage: bool = False,
debug: bool = False,
sharktank_dir: str = "",
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
if not os.path.isfile(weights_path):
subprocess.run(["wget", custom_weights, "-O", weights_path])
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
self.model_id = "stabilityai/stable-diffusion-2-1-base"
self.custom_vae = custom_vae
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
"_"
+ str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
print(f'use_tuned? sharkify: {use_tuned}')
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
print(self.model_name)
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
self.sharktank_dir = sharktank_dir
self.generate_vmfb = generate_vmfb
self.inputs = dict()
self.model_to_run = ""
if self.custom_weights != "":
self.model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
else:
self.model_to_run = args.hf_model_id
self.custom_vae = self.process_custom_vae()
self.base_model_id = fetch_and_update_base_model_id(self.model_to_run)
if self.base_model_id != "" and args.ckpt_loc != "":
args.hf_model_id = self.base_model_id
self.return_mlir = return_mlir
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
index = 0
for model in sub_model_list:
sub_model = model
model_config = self.model_name
if "vae" == model:
if self.custom_vae != "":
model_config = model_config + get_path_stem(self.custom_vae)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 128):
sys.exit("width should be greater than 128 and multiple of 8")
if not (height % 8 == 0 and height >= 128):
sys.exit("height should be greater than 128 and multiple of 8")
# Get the input info for a model i.e. "unet", "clip", "vae", etc.
def get_input_info_for(self, model_info):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = []
for inp in model_info:
shape = model_info[inp]["shape"]
dtype = dtype_config[model_info[inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, self.max_len, self.width, self.height, self.batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
latents = self.vae.encode(input).latent_dist.sample()
return 0.18215 * latents
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
shark_vae_encode, vae_encode_mlir = compile_through_fx(
vae_encode,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae_encode"],
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae_encode",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae_encode, vae_encode_mlir
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
super().__init__()
self.vae = None
if custom_vae == "":
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae
def forward(self, input):
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0
return x.round()
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae, vae_mlir
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
control2, control3, control4, control5, control6, control7,
control8, control9, control10, control11, control12, control13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
mb_res_samples = control13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
down_block_additional_residuals=db_res_samples,
mid_block_additional_residual=mb_res_samples,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = ControlledUnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["stencil_unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.cnet.in_channels
self.train(False)
def forward(
self,
latent,
timestep,
text_embedding,
stencil_image_input,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
# TODO: guidance NOT NEEDED change in `get_input_info` later
latents = torch.cat(
[latent] * 2
) # needs to be same as controlledUNET latents
stencil_image = torch.cat(
[stencil_image_input] * 2
) # needs to be same as controlledUNET latents
down_block_res_samples, mid_block_res_sample = self.cnet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
controlnet_cond=stencil_image,
return_dict=False,
)
return tuple(list(down_block_res_samples) + [mid_block_res_sample])
scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
input_mask = [True, True, True, True]
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name["stencil_adaptor"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_adaptor",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_cnet, cnet_mlir
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
if(args.attention_slicing.isdigit()):
self.unet.set_attention_slice(int(args.attention_slicing))
else:
self.unet.set_attention_slice(args.attention_slicing)
# TODO: Instead of flattening the `control` try to use the list.
def forward(
self, latent, timestep, text_embedding, guidance_scale,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_unet_upscaler(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, latent, timestep, text_embedding, noise_level):
unet_out = self.unet.forward(
latent,
timestep,
text_embedding,
noise_level,
return_dict=False,
)[0]
return unet_out
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
extended_model_name=self.model_name["clip"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip, clip_mlir
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
return self.custom_vae
try:
preprocessCKPT(self.custom_vae)
return get_path_to_diffusers_checkpoint(self.custom_vae)
except:
print("Processing standalone Vae checkpoint")
vae_checkpoint = None
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
if custom_vae.endswith(".ckpt"):
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
else:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
try:
vae_checkpoint = convert_original_vae(vae_checkpoint)
finally:
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
def compile_unet_variants(self, model):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
return get_unet()
else:
return self.get_unet()
else:
return self.get_controlled_unet()
def vae_encode(self):
try:
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
if self.return_mlir:
return vae_encode_mlir
return compiled_vae_encode
except Exception as e:
sys.exit(e)
def clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(base_models["clip"])
compiled_clip, clip_mlir = self.get_clip()
check_compilation(compiled_clip, "Clip")
if self.return_mlir:
return clip_mlir
return compiled_clip
except Exception as e:
sys.exit(e)
def unet(self):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(self.model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
break
check_compilation(compiled_unet, "Unet")
if self.return_mlir:
return unet_mlir
return compiled_unet
except Exception as e:
sys.exit(e)
def vae(self):
    """Build the VAE and return it as MLIR or in compiled form.

    The input description is taken from ``base_models["vae"]`` under the key
    ``"vae_upscaler"`` when ``self.is_upscaler`` is truthy and ``"vae"``
    otherwise. For the upscaler, ``self.base_vae`` is forced to ``True`` only
    for the duration of the ``self.get_vae()`` call and then restored.

    Returns:
        The MLIR element of the pair returned by ``self.get_vae()`` when
        ``self.return_mlir`` is truthy, otherwise the compiled element.

    Any ``Exception`` raised on the way ends the process via ``sys.exit``.
    """
    try:
        if self.is_upscaler:
            vae_input = base_models["vae"]["vae_upscaler"]
        else:
            vae_input = base_models["vae"]["vae"]
        self.inputs["vae"] = self.get_input_info_for(vae_input)

        saved_base_vae = self.base_vae
        if self.is_upscaler:
            self.base_vae = True
        compiled, mlir = self.get_vae()
        self.base_vae = saved_base_vae

        check_compilation(compiled, "Vae")
        return mlir if self.return_mlir else compiled
    except Exception as err:
        sys.exit(err)
def controlnet(self):
    """Build the stencil adaptor (ControlNet) and return MLIR or compiled form.

    Stores the adaptor's input description under
    ``self.inputs["stencil_adaptor"]``, builds it with
    ``self.get_control_net()`` and validates the compiled result with
    ``check_compilation``.

    Returns:
        The MLIR element of the pair returned by ``self.get_control_net()``
        when ``self.return_mlir`` is truthy, otherwise the compiled element.

    Any ``Exception`` raised on the way ends the process via ``sys.exit``.
    """
    try:
        self.inputs["stencil_adaptor"] = self.get_input_info_for(
            base_models["stencil_adaptor"]
        )
        compiled, mlir = self.get_control_net()
        check_compilation(compiled, "Stencil")
        return mlir if self.return_mlir else compiled
    except Exception as err:
        sys.exit(err)

View File

@@ -0,0 +1,123 @@
import sys
from transformers import CLIPTokenizer
from apps.stable_diffusion.src.utils import (
models_db,
args,
get_shark_model,
get_opt_flags,
)
# Lookup table keyed by Hugging Face model id. Each value is a two-item list
# that get_variant_version() returns unchanged and that its callers unpack as
# (variant, version) when building bucket and model keys.
hf_model_variant_map = {
"Linaqruf/anything-v3.0": ["anythingv3", "v1_4"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"],
"prompthero/openjourney": ["openjourney", "v1_4"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
}
# TODO: Add the quantized model to model_db.json.
# The quantized path is currently experimental.
def get_quantize_model():
    """Return where the int8 UNet lives and the flags to compile it with.

    Returns:
        A ``(bucket_key, model_key, iree_flags)`` tuple, where ``iree_flags``
        comes from ``get_opt_flags("unet", precision="fp16")``.

    Exits the process via ``sys.exit`` when ``args.height`` or ``args.width``
    is not 512, or ``args.max_length`` is not 77.
    """
    bucket_key = "gs://shark_tank/prashant_nod"
    model_key = "unet_int8"
    iree_flags = get_opt_flags("unet", precision="fp16")
    # Each of the three settings is required on its own, so one mismatch is
    # enough to reject. Joining the checks with `and` (as before) only
    # rejected when all three were wrong at once.
    if args.height != 512 or args.width != 512 or args.max_length != 77:
        sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
    return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
    """Return the ``[variant, version]`` entry registered for ``hf_model_id``.

    Raises:
        KeyError: if ``hf_model_id`` is not a key of ``hf_model_variant_map``.
    """
    variant_and_version = hf_model_variant_map[hf_model_id]
    return variant_and_version
def get_params(bucket_key, model_key, model, is_tuned, precision):
    """Resolve database keys to a bucket, a model name and compile flags.

    Args:
        bucket_key: Key looked up in ``models_db[0]``.
        model_key: Key looked up in ``models_db[1]``.
        model: Model identifier forwarded to ``get_opt_flags``.
        is_tuned: Not read by this function.
        precision: Not read by this function.

    Returns:
        A ``(bucket, model_name, iree_flags)`` tuple.

    Raises:
        Exception: if either key is missing from the models database.

    NOTE(review): the flags are always requested with ``precision="fp16"``
    regardless of the ``precision`` argument — confirm this is intended.
    """
    try:
        bucket, model_name = (
            models_db[0][bucket_key],
            models_db[1][model_key],
        )
    except KeyError:
        raise Exception(
            f"{bucket_key}/{model_key} is not present in the models database"
        )
    return bucket, model_name, get_opt_flags(model, precision="fp16")
def get_unet():
    """Fetch the UNet selected by the global ``args``.

    With ``args.use_quantize == "int8"`` the model described by
    ``get_quantize_model()`` is fetched. Otherwise bucket and model keys are
    built from the variant/version of ``args.hf_model_id``, the precision,
    the max length and the tuned/untuned state; the device is appended only
    for tuned, non-vulkan runs.

    Returns:
        Whatever ``get_shark_model`` returns for the resolved model.
    """
    variant, version = get_variant_version(args.hf_model_id)
    # Tuned model is present only for `fp16` precision.
    is_tuned = "tuned" if args.use_tuned else "untuned"

    # TODO: Get the quantize model from model_db.json
    if args.use_quantize == "int8":
        return get_shark_model(*get_quantize_model())

    device_specific = "vulkan" not in args.device and args.use_tuned
    if not device_specific:
        bucket_key = f"{variant}/{is_tuned}"
        model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
    else:
        bucket_key = f"{variant}/{is_tuned}/{args.device}"
        model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"

    resolved = get_params(
        bucket_key, model_key, "unet", is_tuned, args.precision
    )
    return get_shark_model(*resolved)
def get_vae_encode():
    """Fetch the VAE encoder selected by the global ``args``.

    Keys are built from the variant/version of ``args.hf_model_id``, the
    precision, a fixed ``length_77`` segment and the tuned/untuned state; the
    device is appended only for tuned, non-vulkan runs. Compile flags are
    requested from ``get_params`` under the model name ``"vae"``.

    Returns:
        Whatever ``get_shark_model`` returns for the resolved model.
    """
    variant, version = get_variant_version(args.hf_model_id)
    # Tuned model is present only for `fp16` precision.
    is_tuned = "tuned" if args.use_tuned else "untuned"

    device_specific = "vulkan" not in args.device and args.use_tuned
    if not device_specific:
        bucket_key = f"{variant}/{is_tuned}"
        model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}"
    else:
        bucket_key = f"{variant}/{is_tuned}/{args.device}"
        model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}"

    resolved = get_params(
        bucket_key, model_key, "vae", is_tuned, args.precision
    )
    return get_shark_model(*resolved)
def get_vae():
    """Fetch the VAE decoder selected by the global ``args``.

    Keys are built like those of the encoder, with ``"/base"`` inserted after
    the tuned/untuned segment of the model key when ``args.use_base_vae`` is
    truthy. The device is appended only for tuned, non-vulkan runs.

    Returns:
        Whatever ``get_shark_model`` returns for the resolved model.
    """
    variant, version = get_variant_version(args.hf_model_id)
    # Tuned model is present only for `fp16` precision.
    is_tuned = "tuned" if args.use_tuned else "untuned"
    is_base = "/base" if args.use_base_vae else ""

    device_specific = "vulkan" not in args.device and args.use_tuned
    if not device_specific:
        bucket_key = f"{variant}/{is_tuned}"
        model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
    else:
        bucket_key = f"{variant}/{is_tuned}/{args.device}"
        model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"

    resolved = get_params(
        bucket_key, model_key, "vae", is_tuned, args.precision
    )
    return get_shark_model(*resolved)
def get_clip():
    """Fetch the CLIP model for ``args.hf_model_id``.

    The keys always use the ``untuned`` bucket and an ``fp32`` model path,
    parameterised only by variant, version and ``args.max_length``.

    Returns:
        Whatever ``get_shark_model`` returns for the resolved model.
    """
    variant, version = get_variant_version(args.hf_model_id)
    model_key = (
        f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
    )
    resolved = get_params(
        f"{variant}/untuned", model_key, "clip", "untuned", "fp32"
    )
    return get_shark_model(*resolved)
def get_tokenizer():
    """Load the CLIP tokenizer from the ``tokenizer`` subfolder of ``args.hf_model_id``."""
    return CLIPTokenizer.from_pretrained(
        args.hf_model_id, subfolder="tokenizer"
    )

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
InpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import (
OutpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import (
StencilPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_upscaler import (
UpscalerPipeline,
)

View File

@@ -0,0 +1,200 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class Image2ImagePipeline(StableDiffusionPipeline):
    """Pipeline that denoises starting from the VAE encoding of an input image.

    The VAE encoder is kept in ``self.vae_encode`` and loaded on first use.
    """

    def __init__(
        self,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
            DEISMultistepScheduler,
        ],
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
        ondemand: bool,
    ):
        """Forward every argument to the base pipeline; no encoder is loaded yet."""
        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
        self.vae_encode = None

    def load_vae_encode(self):
        """Populate ``self.vae_encode`` unless it is already set.

        With ``import_mlir`` or ``use_lora`` truthy the encoder comes from
        ``self.sd_model.vae_encode()``. Otherwise ``get_vae_encode()`` is
        tried first and the ``sd_model`` route is the fallback.
        """
        if self.vae_encode is not None:
            return
        if not (self.import_mlir or self.use_lora):
            try:
                self.vae_encode = get_vae_encode()
            # NOTE(review): a bare except also intercepts KeyboardInterrupt
            # and SystemExit — narrow it if that is not intended.
            except:
                print("download pipeline failed, falling back to import_mlir")
                self.vae_encode = self.sd_model.vae_encode()
            return
        self.vae_encode = self.sd_model.vae_encode()

    def unload_vae_encode(self):
        """Drop the current encoder and reset the attribute to ``None``."""
        del self.vae_encode
        self.vae_encode = None

    def prepare_image_latents(
        self,
        image,
        batch_size,
        height,
        width,
        generator,
        num_inference_steps,
        strength,
        dtype,
    ):
        """Encode ``image`` and noise it to the first timestep that will run.

        ``strength`` shortens the schedule: only the last
        ``min(int(num_inference_steps * strength), num_inference_steps)``
        timesteps are kept. ``batch_size`` is accepted but not read here.

        Returns:
            ``(latents, timesteps)``: the noised latents and the shortened
            timestep sequence.
        """
        # TODO: process with variable HxW combos
        # Resize, add a leading axis, divide by 255 and map through
        # 2 * (x - 0.5).
        resized = image.resize((width, height))
        image_arr = np.stack([np.array(resized)], axis=0)
        image_arr = image_arr / 255.0
        image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
        image_arr = 2 * (image_arr - 0.5)

        # Configure the scheduler, then keep only the tail of the schedule.
        self.scheduler.set_timesteps(num_inference_steps)
        init_timestep = min(
            int(num_inference_steps * strength), num_inference_steps
        )
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]
        # The effective number of steps is num_inference_steps - t_start.

        # Encode the image, then add noise for the first remaining timestep.
        latents = torch.from_numpy(self.encode_image((image_arr,))).to(dtype)
        noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
        latents = self.scheduler.add_noise(
            latents, noise, timesteps[0].repeat(1)
        )
        return latents, timesteps

    def encode_image(self, input_image):
        """Run the VAE encoder on ``input_image`` and log the elapsed time.

        The encoder is loaded first if needed and unloaded afterwards when
        ``self.ondemand`` is truthy.
        """
        self.load_vae_encode()
        vae_encode_start = time.time()
        latents = self.vae_encode("forward", input_image)
        vae_inf_time = (time.time() - vae_encode_start) * 1000
        if self.ondemand:
            self.unload_vae_encode()
        self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
        return latents

    def generate_images(
        self,
        prompts,
        neg_prompts,
        image,
        batch_size,
        height,
        width,
        num_inference_steps,
        strength,
        guidance_scale,
        seed,
        max_length,
        dtype,
        use_base_vae,
        cpu_scheduling,
        use_stencil,
    ):
        """Generate images from ``image`` guided by the prompts.

        A seed outside the uint32 range checked below is replaced by a random
        one. ``use_stencil`` is accepted but not read here.

        Returns:
            A list with the decoded images of every batch.
        """
        # Prompts and negative prompts must be lists, one entry per batch item.
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(neg_prompts, str):
            neg_prompts = [neg_prompts]
        prompts = prompts * batch_size
        neg_prompts = neg_prompts * batch_size

        # Seed the generator for the initial latent noise, replacing
        # out-of-range seeds.
        seed_limits = np.iinfo(np.uint32)
        lowest, highest = seed_limits.min, seed_limits.max
        if seed < lowest or seed >= highest:
            seed = randint(lowest, highest)
        generator = torch.manual_seed(seed)

        # Text embeddings with weight emphasis from the prompts.
        text_embeddings = self.encode_prompts_weight(
            prompts, neg_prompts, max_length
        )

        # Guidance scale as a float32 tensor.
        guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

        # Noised latents of the input image plus the timesteps to run.
        image_latents, final_timesteps = self.prepare_image_latents(
            image=image,
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
            num_inference_steps=num_inference_steps,
            strength=strength,
            dtype=dtype,
        )

        latents = self.produce_img_latents(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            total_timesteps=final_timesteps,
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
        )

        # Decode the latents batch by batch.
        all_imgs = []
        self.load_vae()
        for start in tqdm(range(0, latents.shape[0], batch_size)):
            all_imgs.extend(
                self.decode_latents(
                    latents=latents[start : start + batch_size],
                    use_base_vae=use_base_vae,
                    cpu_scheduling=cpu_scheduling,
                )
            )
        if self.ondemand:
            self.unload_vae()
        return all_imgs

View File

@@ -0,0 +1,473 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from PIL import Image, ImageOps
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class InpaintPipeline(StableDiffusionPipeline):
    """Pipeline that regenerates the masked part of an image.

    The VAE encoder is kept in ``self.vae_encode`` and loaded on first use.
    """

    def __init__(
        self,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
            DEISMultistepScheduler,
        ],
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
        ondemand: bool,
    ):
        """Forward every argument to the base pipeline; no encoder is loaded yet."""
        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
        self.vae_encode = None

    def load_vae_encode(self):
        """Populate ``self.vae_encode`` unless it is already set.

        With ``import_mlir`` or ``use_lora`` truthy the encoder comes from
        ``self.sd_model.vae_encode()``. Otherwise ``get_vae_encode()`` is
        tried first and the ``sd_model`` route is the fallback.
        """
        if self.vae_encode is not None:
            return
        if not (self.import_mlir or self.use_lora):
            try:
                self.vae_encode = get_vae_encode()
            # NOTE(review): a bare except also intercepts KeyboardInterrupt
            # and SystemExit — narrow it if that is not intended.
            except:
                print("download pipeline failed, falling back to import_mlir")
                self.vae_encode = self.sd_model.vae_encode()
            return
        self.vae_encode = self.sd_model.vae_encode()

    def unload_vae_encode(self):
        """Drop the current encoder and reset the attribute to ``None``."""
        del self.vae_encode
        self.vae_encode = None

    def prepare_latents(
        self,
        batch_size,
        height,
        width,
        generator,
        num_inference_steps,
        dtype,
    ):
        """Draw starting noise of shape ``(batch_size, 4, height // 8, width // 8)``.

        The noise is sampled in float32, cast to ``dtype`` and multiplied by
        the scheduler's ``init_noise_sigma`` after the timesteps are set.
        """
        latent_shape = (batch_size, 4, height // 8, width // 8)
        noise = torch.randn(
            latent_shape, generator=generator, dtype=torch.float32
        ).to(dtype)
        self.scheduler.set_timesteps(num_inference_steps)
        return noise * self.scheduler.init_noise_sigma

    def get_crop_region(self, mask, pad=0):
        """Return ``(x1, y1, x2, y2)`` around the non-zero part of a 2-D mask.

        All-zero columns and rows are counted inward from each edge; the
        resulting box is widened by ``pad`` and clamped to the mask size.
        """
        h, w = mask.shape

        def blank_run(indices, pick):
            # Count consecutive all-zero lines before the first non-zero one.
            count = 0
            for idx in indices:
                if not (pick(idx) == 0).all():
                    break
                count += 1
            return count

        def column(idx):
            return mask[:, idx]

        def row(idx):
            return mask[idx]

        crop_left = blank_run(range(w), column)
        crop_right = blank_run(reversed(range(w)), column)
        crop_top = blank_run(range(h), row)
        crop_bottom = blank_run(reversed(range(h)), row)

        return (
            int(max(crop_left - pad, 0)),
            int(max(crop_top - pad, 0)),
            int(min(w - crop_right + pad, w)),
            int(min(h - crop_bottom + pad, h)),
        )

    def expand_crop_region(
        self,
        crop_region,
        processing_width,
        processing_height,
        image_width,
        image_height,
    ):
        """Grow ``crop_region`` along one axis to match the processing aspect ratio.

        The shorter axis (relative to the processing ratio) is extended about
        its centre, then shifted and clamped to stay within the image.

        Returns:
            The adjusted ``(x1, y1, x2, y2)``.
        """

        def fit(lo, hi, limit):
            # Slide the interval back inside [0, limit], then clamp the top.
            if hi >= limit:
                overshoot = hi - limit
                hi -= overshoot
                lo -= overshoot
            if lo < 0:
                hi -= lo
                lo -= lo
            if hi >= limit:
                hi = limit
            return lo, hi

        x1, y1, x2, y2 = crop_region
        ratio_crop_region = (x2 - x1) / (y2 - y1)
        ratio_processing = processing_width / processing_height

        if ratio_crop_region > ratio_processing:
            # Region is too wide for the target ratio: add height.
            desired_height = (x2 - x1) / ratio_processing
            extra = int(desired_height - (y2 - y1))
            y1 -= extra // 2
            y2 += extra - extra // 2
            y1, y2 = fit(y1, y2, image_height)
        else:
            # Region is too tall (or exactly right): add width.
            desired_width = (y2 - y1) * ratio_processing
            extra = int(desired_width - (x2 - x1))
            x1 -= extra // 2
            x2 += extra - extra // 2
            x1, x2 = fit(x1, x2, image_width)

        return x1, y1, x2, y2

    def resize_image(self, resize_mode, im, width, height):
        """Resize ``im`` onto a new ``width`` x ``height`` RGB canvas.

        resize_mode:
            0: Scale to fill the canvas keeping the aspect ratio, centre the
               result and let the excess be cropped.
            otherwise: Scale to fit inside the canvas keeping the aspect
               ratio, centre the result and fill the empty bands with
               stretched edge data from the resized image.
        """
        ratio = width / height
        src_ratio = im.width / im.height

        if resize_mode == 0:
            src_w = (
                width if ratio > src_ratio else im.width * height // im.height
            )
            src_h = (
                height if ratio <= src_ratio else im.height * width // im.width
            )
        else:
            src_w = (
                width if ratio < src_ratio else im.width * height // im.height
            )
            src_h = (
                height if ratio >= src_ratio else im.height * width // im.width
            )

        resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
        res = Image.new("RGB", (width, height))
        res.paste(
            resized,
            box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
        )
        if resize_mode == 0:
            return res

        if ratio < src_ratio:
            # Empty bands above and below the pasted image.
            fill_height = height // 2 - src_h // 2
            top_band = resized.resize(
                (width, fill_height), box=(0, 0, width, 0)
            )
            res.paste(top_band, box=(0, 0))
            bottom_band = resized.resize(
                (width, fill_height),
                box=(0, resized.height, width, resized.height),
            )
            res.paste(bottom_band, box=(0, fill_height + src_h))
        elif ratio > src_ratio:
            # Empty bands left and right of the pasted image.
            fill_width = width // 2 - src_w // 2
            left_band = resized.resize(
                (fill_width, height), box=(0, 0, 0, height)
            )
            res.paste(left_band, box=(0, 0))
            right_band = resized.resize(
                (fill_width, height),
                box=(resized.width, 0, resized.width, height),
            )
            res.paste(right_band, box=(fill_width + src_w, 0))
        return res

    def prepare_mask_and_masked_image(
        self,
        image,
        mask,
        height,
        width,
        inpaint_full_res,
        inpaint_full_res_padding,
    ):
        """Turn the image and mask into tensors, optionally cropping to the mask.

        With ``inpaint_full_res`` truthy, an overlay holding the unmasked
        pixels is prepared and both inputs are cropped to the (expanded) mask
        region and resized back to ``width`` x ``height``.

        Returns:
            ``(mask, masked_image, paste_to, overlay_image)``. ``paste_to`` is
            ``(x, y, w, h)`` of the crop, or ``()`` without full-res
            inpainting; ``overlay_image`` is ``None`` in that case.
        """
        image = image.resize((width, height))
        mask = mask.resize((width, height))

        paste_to = ()
        overlay_image = None
        if inpaint_full_res:
            # Overlay: the original image wherever the mask is not set.
            overlay_image = Image.new("RGB", (image.width, image.height))
            overlay_image.paste(
                image.convert("RGB"),
                mask=ImageOps.invert(mask.convert("L")),
            )

            # Crop the mask to the region of interest.
            mask = mask.convert("L")
            crop_region = self.get_crop_region(
                np.array(mask), inpaint_full_res_padding
            )
            crop_region = self.expand_crop_region(
                crop_region, width, height, mask.width, mask.height
            )
            x1, y1, x2, y2 = crop_region
            mask = mask.crop(crop_region)
            mask = self.resize_image(1, mask, width, height)
            paste_to = (x1, y1, x2 - x1, y2 - y1)

            # Crop the image the same way.
            image = image.crop(crop_region)
            image = self.resize_image(1, image, width, height)

        # Image -> float tensor computed as value / 127.5 - 1.0.
        if isinstance(image, (Image.Image, np.ndarray)):
            image = [image]
        if isinstance(image, list):
            if isinstance(image[0], Image.Image):
                image = np.concatenate(
                    [np.array(im.convert("RGB"))[None, :] for im in image],
                    axis=0,
                )
            elif isinstance(image[0], np.ndarray):
                image = np.concatenate([im[None, :] for im in image], axis=0)
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # Mask -> tensor of zeros and ones, thresholded at 0.5.
        if isinstance(mask, (Image.Image, np.ndarray)):
            mask = [mask]
        if isinstance(mask, list):
            if isinstance(mask[0], Image.Image):
                mask = np.concatenate(
                    [np.array(m.convert("L"))[None, None, :] for m in mask],
                    axis=0,
                )
                mask = mask.astype(np.float32) / 255.0
            elif isinstance(mask[0], np.ndarray):
                mask = np.concatenate(
                    [m[None, None, :] for m in mask], axis=0
                )
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

        masked_image = image * (mask < 0.5)
        return mask, masked_image, paste_to, overlay_image

    def prepare_mask_latents(
        self,
        mask,
        masked_image,
        batch_size,
        height,
        width,
        dtype,
    ):
        """Downscale the mask and VAE-encode the masked image.

        Both results are repeated along the first axis up to ``batch_size``.

        Raises:
            ValueError: if ``batch_size`` is not a multiple of the number of
                masks or encoded images.
        """
        mask = torch.nn.functional.interpolate(
            mask, size=(height // 8, width // 8)
        )
        mask = mask.to(dtype)

        self.load_vae_encode()
        masked_image = masked_image.to(dtype)
        masked_image_latents = torch.from_numpy(
            self.vae_encode("forward", (masked_image,))
        )
        if self.ondemand:
            self.unload_vae_encode()

        # Duplicate mask and masked_image_latents for each generation per
        # prompt, using an mps-friendly method.
        if mask.shape[0] < batch_size:
            if batch_size % mask.shape[0] != 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        if masked_image_latents.shape[0] < batch_size:
            if batch_size % masked_image_latents.shape[0] != 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(
                batch_size // masked_image_latents.shape[0], 1, 1, 1
            )
        return mask, masked_image_latents

    def apply_overlay(self, image, paste_loc, overlay):
        """Resize ``image`` to the ``paste_loc`` size, paste it onto ``overlay`` and return it."""
        x, y, w, h = paste_loc
        overlay.paste(self.resize_image(0, image, w, h), (x, y))
        return overlay

    def generate_images(
        self,
        prompts,
        neg_prompts,
        image,
        mask_image,
        batch_size,
        height,
        width,
        inpaint_full_res,
        inpaint_full_res_padding,
        num_inference_steps,
        guidance_scale,
        seed,
        max_length,
        dtype,
        use_base_vae,
        cpu_scheduling,
    ):
        """Inpaint ``image`` where ``mask_image`` is set, guided by the prompts.

        A seed outside the uint32 range checked below is replaced by a random
        one.

        Returns:
            The list of decoded images. With ``inpaint_full_res`` truthy, a
            one-element list holding the first decoded image pasted onto the
            overlay.
        """
        # Prompts and negative prompts must be lists, one entry per batch item.
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(neg_prompts, str):
            neg_prompts = [neg_prompts]
        prompts = prompts * batch_size
        neg_prompts = neg_prompts * batch_size

        # Seed the generator for the initial latent noise, replacing
        # out-of-range seeds.
        seed_limits = np.iinfo(np.uint32)
        lowest, highest = seed_limits.min, seed_limits.max
        if seed < lowest or seed >= highest:
            seed = randint(lowest, highest)
        generator = torch.manual_seed(seed)

        init_latents = self.prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
            num_inference_steps=num_inference_steps,
            dtype=dtype,
        )

        # Text embeddings with weight emphasis from the prompts.
        text_embeddings = self.encode_prompts_weight(
            prompts, neg_prompts, max_length
        )

        # Guidance scale as a float32 tensor.
        guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

        prepared = self.prepare_mask_and_masked_image(
            image,
            mask_image,
            height,
            width,
            inpaint_full_res,
            inpaint_full_res_padding,
        )
        mask, masked_image, paste_to, overlay_image = prepared

        mask, masked_image_latents = self.prepare_mask_latents(
            mask=mask,
            masked_image=masked_image,
            batch_size=batch_size,
            height=height,
            width=width,
            dtype=dtype,
        )

        latents = self.produce_img_latents(
            latents=init_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            total_timesteps=self.scheduler.timesteps,
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
            mask=mask,
            masked_image_latents=masked_image_latents,
        )

        # Decode the latents batch by batch.
        all_imgs = []
        self.load_vae()
        for start in tqdm(range(0, latents.shape[0], batch_size)):
            all_imgs.extend(
                self.decode_latents(
                    latents=latents[start : start + batch_size],
                    use_base_vae=use_base_vae,
                    cpu_scheduling=cpu_scheduling,
                )
            )
        if self.ondemand:
            self.unload_vae()

        if inpaint_full_res:
            return [self.apply_overlay(all_imgs[0], paste_to, overlay_image)]
        return all_imgs

View File

@@ -0,0 +1,567 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from PIL import Image, ImageDraw, ImageFilter
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
import math
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class OutpaintPipeline(StableDiffusionPipeline):
def __init__(
    self,
    scheduler: Union[
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
        SharkEulerDiscreteScheduler,
        DEISMultistepScheduler,
    ],
    sd_model: SharkifyStableDiffusionModel,
    import_mlir: bool,
    use_lora: str,
    ondemand: bool,
):
    """Forward every argument to the base pipeline; no VAE encoder is loaded yet.

    ``self.vae_encode`` starts as ``None`` and is filled in by
    ``load_vae_encode``.
    """
    super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
    self.vae_encode = None
def load_vae_encode(self):
    """Populate ``self.vae_encode`` unless it is already set.

    With ``import_mlir`` or ``use_lora`` truthy the encoder comes from
    ``self.sd_model.vae_encode()``. Otherwise ``get_vae_encode()`` is tried
    first and the ``sd_model`` route is the fallback.
    """
    if self.vae_encode is not None:
        return
    if not (self.import_mlir or self.use_lora):
        try:
            self.vae_encode = get_vae_encode()
        # NOTE(review): a bare except also intercepts KeyboardInterrupt and
        # SystemExit — narrow it if that is not intended.
        except:
            print("download pipeline failed, falling back to import_mlir")
            self.vae_encode = self.sd_model.vae_encode()
        return
    self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
    """Drop the current VAE encoder and reset the attribute to ``None``."""
    # Delete first, then re-create the attribute so later
    # ``is not None`` checks keep working.
    del self.vae_encode
    self.vae_encode = None
def prepare_latents(
    self,
    batch_size,
    height,
    width,
    generator,
    num_inference_steps,
    dtype,
):
    """Draw starting noise of shape ``(batch_size, 4, height // 8, width // 8)``.

    The noise is sampled in float32 from ``generator``, cast to ``dtype`` and
    multiplied by the scheduler's ``init_noise_sigma`` after the scheduler's
    timesteps have been set to ``num_inference_steps``.
    """
    latent_shape = (batch_size, 4, height // 8, width // 8)
    noise = torch.randn(
        latent_shape, generator=generator, dtype=torch.float32
    ).to(dtype)
    self.scheduler.set_timesteps(num_inference_steps)
    return noise * self.scheduler.init_noise_sigma
def prepare_mask_and_masked_image(
    self, image, mask, mask_blur, width, height
):
    """Turn the image and mask into tensors and blank out the masked pixels.

    The mask is Gaussian-blurred first when ``mask_blur`` is positive, then
    both inputs are resized to ``width`` x ``height``.

    Returns:
        ``(mask, masked_image)``: the mask thresholded at 0.5 into zeros and
        ones, and the image tensor multiplied by ``mask < 0.5``.
    """
    if mask_blur > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(mask_blur))
    image = image.resize((width, height))
    mask = mask.resize((width, height))

    # Image -> float tensor computed as value / 127.5 - 1.0.
    if isinstance(image, (Image.Image, np.ndarray)):
        image = [image]
    if isinstance(image, list):
        if isinstance(image[0], Image.Image):
            image = np.concatenate(
                [np.array(im.convert("RGB"))[None, :] for im in image],
                axis=0,
            )
        elif isinstance(image[0], np.ndarray):
            image = np.concatenate([im[None, :] for im in image], axis=0)
    image = image.transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    # Mask -> tensor of zeros and ones.
    if isinstance(mask, (Image.Image, np.ndarray)):
        mask = [mask]
    if isinstance(mask, list):
        if isinstance(mask[0], Image.Image):
            mask = np.concatenate(
                [np.array(m.convert("L"))[None, None, :] for m in mask],
                axis=0,
            )
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    masked_image = image * (mask < 0.5)
    return mask, masked_image
def prepare_mask_latents(
    self,
    mask,
    masked_image,
    batch_size,
    height,
    width,
    dtype,
):
    """Downscale the mask and VAE-encode the masked image.

    The mask is interpolated to ``(height // 8, width // 8)``. The encoder is
    loaded on demand and unloaded again when ``self.ondemand`` is truthy. Both
    results are repeated along the first axis up to ``batch_size``.

    Returns:
        ``(mask, masked_image_latents)``.

    Raises:
        ValueError: if ``batch_size`` is not a multiple of the number of
            masks or encoded images.
    """
    mask = torch.nn.functional.interpolate(
        mask, size=(height // 8, width // 8)
    )
    mask = mask.to(dtype)

    self.load_vae_encode()
    masked_image = masked_image.to(dtype)
    masked_image_latents = torch.from_numpy(
        self.vae_encode("forward", (masked_image,))
    )
    if self.ondemand:
        self.unload_vae_encode()

    # Duplicate mask and masked_image_latents for each generation per prompt,
    # using an mps-friendly method.
    if mask.shape[0] < batch_size:
        if batch_size % mask.shape[0] != 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
    if masked_image_latents.shape[0] < batch_size:
        if batch_size % masked_image_latents.shape[0] != 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        masked_image_latents = masked_image_latents.repeat(
            batch_size // masked_image_latents.shape[0], 1, 1, 1
        )
    return mask, masked_image_latents
def get_matched_noise(
    self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
):
    """Fill the masked part of an image with noise shaped by the image's spectrum.

    Args:
        _np_src_image: 3-D array; this code reads its first two axes as
            ``(width, height)`` and the third as channels.
        np_mask_rgb: Array combined element-wise with the source; its
            channel mean marks the region to fill.
        noise_q: Exponent applied to the source spectrum magnitude.
        color_variation: Share of per-channel noise kept; the remainder is
            replaced by the grey average of the noise.

    Returns:
        Source pixels blended with the shaped noise by ``np_mask_rgb``,
        clipped to ``[0.0, 1.0]``.

    The random generator is seeded with a fixed 0, so the output depends only
    on the arguments.
    """

    # FFT helpers that keep ortho normalization and shift before and after
    # the transform.
    def _centered(transform, plane):
        return np.fft.ifftshift(
            transform(np.fft.fftshift(plane), norm="ortho")
        )

    def _apply(transform, data):
        if data.ndim > 2:  # has channels
            result = np.zeros(
                (data.shape[0], data.shape[1], data.shape[2]),
                dtype=np.complex128,
            )
            for c in range(data.shape[2]):
                result[:, :, c] = _centered(transform, data[:, :, c])
        else:  # one channel
            result = np.zeros(
                (data.shape[0], data.shape[1]), dtype=np.complex128
            )
            result[:, :] = _centered(transform, data)
        return result

    def _fft2(data):
        return _apply(np.fft.fft2, data)

    def _ifft2(data):
        return _apply(np.fft.ifft2, data)

    def _get_gaussian_window(width, height, std=3.14, mode=0):
        shortest = min(width, height)
        window_scale_x = float(width / shortest)
        window_scale_y = float(height / shortest)
        window = np.zeros((width, height))
        x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
        for y in range(height):
            fy = (y / height * 2.0 - 1.0) * window_scale_y
            if mode == 0:
                column = np.exp(-(x**2 + fy**2) * std)
            else:
                column = (1 / ((x**2 + 1.0) * (fy**2 + 1.0))) ** (
                    std / 3.14
                )
            window[:, y] = column
        return window

    def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
        if hardness != 1.0:
            hardened = np_mask_grey[:] ** hardness
        else:
            hardened = np_mask_grey[:]
        rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
        for c in range(3):
            rgb[:, :, c] = hardened[:]
        return rgb

    def _match_cumulative_cdf(source, template):
        _, src_inverse, src_counts = np.unique(
            source.ravel(), return_inverse=True, return_counts=True
        )
        tmpl_values, tmpl_counts = np.unique(
            template.ravel(), return_counts=True
        )
        # Normalized quantiles for each array.
        src_quantiles = np.cumsum(src_counts) / source.size
        tmpl_quantiles = np.cumsum(tmpl_counts) / template.size
        interpolated = np.interp(src_quantiles, tmpl_quantiles, tmpl_values)
        return interpolated[src_inverse].reshape(source.shape)

    def _match_histograms(image, reference):
        if image.ndim != reference.ndim:
            raise ValueError(
                "Image and reference must have the same number of channels."
            )
        if image.shape[-1] != reference.shape[-1]:
            raise ValueError(
                "Number of channels in the input image and reference image must match!"
            )
        matched = np.empty(image.shape, dtype=image.dtype)
        for channel in range(image.shape[-1]):
            matched[..., channel] = _match_cumulative_cdf(
                image[..., channel], reference[..., channel]
            )
        return matched.astype(np.float64, copy=False)

    width, height, num_channels = (
        _np_src_image.shape[0],
        _np_src_image.shape[1],
        _np_src_image.shape[2],
    )

    # Source with the masked area zeroed; reused for the final blend.
    unmasked_src = _np_src_image[:] * (1.0 - np_mask_rgb)
    np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
    img_mask = np_mask_grey > 1e-6
    ref_mask = np_mask_grey < 1e-3

    # Rather than leave the masked area black, fill it with the average
    # source colour; this gives the FFT better statistics.
    windowed_image = _np_src_image * (
        1.0 - _get_masked_window_rgb(np_mask_grey)
    )
    windowed_image /= np.max(windowed_image)
    windowed_image += np.average(_np_src_image) * np_mask_rgb

    # Feature statistics of the masked source image.
    src_fft = _fft2(windowed_image)
    src_dist = np.absolute(src_fft)
    src_phase = src_fft / src_dist

    # Static seed: outpainting noise is deterministic and only the global
    # seed varies the result.
    rng = np.random.default_rng(0)

    noise_window = _get_gaussian_window(width, height, mode=1)
    noise_rgb = rng.random((width, height, num_channels))
    noise_grey = np.sum(noise_rgb, axis=2) / 3.0
    # Blend the colourfulness of the starting noise towards greyscale.
    noise_rgb *= color_variation
    for c in range(num_channels):
        noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey

    noise_fft = _fft2(noise_rgb)
    for c in range(num_channels):
        noise_fft[:, :, c] *= noise_window
    noise_rgb = np.real(_ifft2(noise_fft))

    # Shape the noise with the source magnitude and phase.
    shaped_noise_fft = _fft2(noise_rgb)
    shaped_noise_fft[:, :, :] = (
        np.absolute(shaped_noise_fft[:, :, :]) ** 2
        * (src_dist**noise_q)
        * src_phase
    )

    brightness_variation = 0.0
    contrast_adjusted_np_src = (
        _np_src_image[:] * (brightness_variation + 1.0)
        - brightness_variation * 2.0
    )

    shaped_noise = np.real(_ifft2(shaped_noise_fft))
    shaped_noise -= np.min(shaped_noise)
    shaped_noise /= np.max(shaped_noise)
    # Match the noise histogram in the fill area to the unmasked source.
    shaped_noise[img_mask, :] = _match_histograms(
        shaped_noise[img_mask, :] ** 1.0,
        contrast_adjusted_np_src[ref_mask, :],
    )
    shaped_noise = unmasked_src + shaped_noise * np_mask_rgb

    return np.clip(shaped_noise[:], 0.0, 1.0)
def generate_images(
    self,
    prompts,
    neg_prompts,
    image,
    pixels,
    mask_blur,
    is_left,
    is_right,
    is_top,
    is_bottom,
    noise_q,
    color_variation,
    batch_size,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    max_length,
    dtype,
    use_base_vae,
    cpu_scheduling,
):
    """Outpaint `image` on the selected sides and return the result.

    The image is first resized to (width, height) and then grown one side
    at a time (left, right, top, bottom). Every expansion fills the new
    area with noise matched to the existing picture and runs one masked
    denoising pass over it.

    Args:
        prompts / neg_prompts: a string or list of strings; each list is
            repeated `batch_size` times.
        image: source PIL image.
        pixels: requested growth, in pixels, for every enabled side.
        mask_blur: margin added to the kept (black) mask rectangle; also
            forwarded to `prepare_mask_and_masked_image`.
        is_left / is_right / is_top / is_bottom: which sides to extend.
        noise_q / color_variation: forwarded to `get_matched_noise`.
        seed: values outside [0, uint32 max) are replaced by a random seed.
        Remaining arguments are forwarded to the latent / text-encoder /
        UNet / VAE helpers unchanged.

    Returns:
        A one-element list holding the final PIL image.
    """
    # prompts and negative prompts must be a list.
    if isinstance(prompts, str):
        prompts = [prompts]
    if isinstance(neg_prompts, str):
        neg_prompts = [neg_prompts]

    prompts = prompts * batch_size
    neg_prompts = neg_prompts * batch_size

    # Seed the generator for the initial latent noise; out-of-range seeds
    # are swapped for a random one.
    uint32_info = np.iinfo(np.uint32)
    uint32_min, uint32_max = uint32_info.min, uint32_info.max
    if seed < uint32_min or seed >= uint32_max:
        seed = randint(uint32_min, uint32_max)
    generator = torch.manual_seed(seed)

    # Get initial latents (shared by every expansion pass below).
    init_latents = self.prepare_latents(
        batch_size=batch_size,
        height=height,
        width=width,
        generator=generator,
        num_inference_steps=num_inference_steps,
        dtype=dtype,
    )

    # Get text embeddings with weight emphasis from prompts
    text_embeddings = self.encode_prompts_weight(
        prompts, neg_prompts, max_length
    )

    # guidance scale as a float32 tensor.
    guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

    left = pixels if is_left else 0
    right = pixels if is_right else 0
    up = pixels if is_top else 0
    down = pixels if is_bottom else 0

    # Round the final canvas up to a multiple of 64 and distribute the
    # rounding slack over the enabled sides.
    target_w = math.ceil((image.width + left + right) / 64) * 64
    target_h = math.ceil((image.height + up + down) / 64) * 64

    if left > 0:
        left = left * (target_w - image.width) // (left + right)
    if right > 0:
        right = target_w - image.width - left
    if up > 0:
        up = up * (target_h - image.height) // (up + down)
    if down > 0:
        down = target_h - image.height - up

    def expand(
        init_img,
        expand_pixels,
        is_left=False,
        is_right=False,
        is_top=False,
        is_bottom=False,
    ):
        """Grow `init_img` by `expand_pixels` on one side and inpaint it."""
        is_horiz = is_left or is_right
        is_vert = is_top or is_bottom
        pixels_horiz = expand_pixels if is_horiz else 0
        pixels_vert = expand_pixels if is_vert else 0

        res_w = init_img.width + pixels_horiz
        res_h = init_img.height + pixels_vert
        # The working canvas is padded up to a multiple of 64.
        process_res_w = math.ceil(res_w / 64) * 64
        process_res_h = math.ceil(res_h / 64) * 64

        img = Image.new("RGB", (process_res_w, process_res_h))
        img.paste(
            init_img,
            (pixels_horiz if is_left else 0, pixels_vert if is_top else 0),
        )

        # White = area to generate, black rectangle = area to keep.
        msk = Image.new("RGB", (process_res_w, process_res_h), "white")
        draw = ImageDraw.Draw(msk)
        draw.rectangle(
            (
                expand_pixels + mask_blur if is_left else 0,
                expand_pixels + mask_blur if is_top else 0,
                msk.width - expand_pixels - mask_blur
                if is_right
                else res_w,
                msk.height - expand_pixels - mask_blur
                if is_bottom
                else res_h,
            ),
            fill="black",
        )

        np_image = (np.asarray(img) / 255.0).astype(np.float64)
        np_mask = (np.asarray(msk) / 255.0).astype(np.float64)
        noised = self.get_matched_noise(
            np_image, np_mask, noise_q, color_variation
        )
        output_image = Image.fromarray(
            np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8),
            mode="RGB",
        )

        # Only a window next to the new edge is sent through the model.
        target_width = (
            min(width, init_img.width + pixels_horiz)
            if is_horiz
            else img.width
        )
        target_height = (
            min(height, init_img.height + pixels_vert)
            if is_vert
            else img.height
        )
        crop_region = (
            0 if is_left else output_image.width - target_width,
            0 if is_top else output_image.height - target_height,
            target_width if is_left else output_image.width,
            target_height if is_top else output_image.height,
        )
        mask_to_process = msk.crop(crop_region)
        image_to_process = output_image.crop(crop_region)

        # Preprocess mask and image
        mask, masked_image = self.prepare_mask_and_masked_image(
            image_to_process, mask_to_process, mask_blur, width, height
        )

        # Prepare mask latent variables
        mask, masked_image_latents = self.prepare_mask_latents(
            mask=mask,
            masked_image=masked_image,
            batch_size=batch_size,
            height=height,
            width=width,
            dtype=dtype,
        )

        # Get Image latents
        latents = self.produce_img_latents(
            latents=init_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            total_timesteps=self.scheduler.timesteps,
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
            mask=mask,
            masked_image_latents=masked_image_latents,
        )

        # Img latents -> PIL images
        all_imgs = []
        self.load_vae()
        for i in tqdm(range(0, latents.shape[0], batch_size)):
            imgs = self.decode_latents(
                latents=latents[i : i + batch_size],
                use_base_vae=use_base_vae,
                cpu_scheduling=cpu_scheduling,
            )
            all_imgs.extend(imgs)
        # Release the VAE in on-demand mode, as the sibling pipelines do
        # after decoding; load_vae() reloads it on the next pass.
        if self.ondemand:
            self.unload_vae()

        # Only the first decoded image is pasted back over the window.
        res_img = all_imgs[0].resize(
            (image_to_process.width, image_to_process.height)
        )
        output_image.paste(
            res_img,
            (
                0 if is_left else output_image.width - res_img.width,
                0 if is_top else output_image.height - res_img.height,
            ),
        )
        # Drop the multiple-of-64 padding again.
        output_image = output_image.crop((0, 0, res_w, res_h))

        return output_image

    img = image.resize((width, height))
    if left > 0:
        img = expand(img, left, is_left=True)
    if right > 0:
        img = expand(img, right, is_right=True)
    if up > 0:
        img = expand(img, up, is_top=True)
    if down > 0:
        img = expand(img, down, is_bottom=True)

    return [img]

View File

@@ -0,0 +1,274 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class StencilPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline whose UNet receives ControlNet outputs.

    On every denoising step the ControlNet module is run on the scaled
    latents together with a hint derived from the input image, and its 13
    outputs are passed to the UNet as additional inputs.
    """

    def __init__(
        self,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
        ],
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
        ondemand: bool,
    ):
        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
        # Loaded lazily by load_controlnet().
        self.controlnet = None

    def load_controlnet(self):
        """Build the ControlNet module unless it is already loaded."""
        if self.controlnet is not None:
            return
        self.controlnet = self.sd_model.controlnet()

    def unload_controlnet(self):
        """Drop the ControlNet module so it can be reloaded later."""
        del self.controlnet
        self.controlnet = None

    def prepare_latents(
        self,
        batch_size,
        height,
        width,
        generator,
        num_inference_steps,
        dtype,
    ):
        """Sample initial latents and configure the scheduler timesteps.

        Returns a (batch_size, 4, height // 8, width // 8) tensor of
        gaussian noise cast to `dtype` and scaled by the scheduler's
        `init_noise_sigma`.
        """
        latents = torch.randn(
            (
                batch_size,
                4,
                height // 8,
                width // 8,
            ),
            generator=generator,
            dtype=torch.float32,
        ).to(dtype)

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.is_scale_input_called = True
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def produce_stencil_latents(
        self,
        latents,
        text_embeddings,
        guidance_scale,
        total_timesteps,
        dtype,
        cpu_scheduling,
        controlnet_hint=None,
        controlnet_conditioning_scale: float = 1.0,
        mask=None,
        masked_image_latents=None,
        return_all_latents=False,
    ):
        """Run the ControlNet-guided denoising loop.

        Args:
            latents: initial latents.
            text_embeddings: numpy array of prompt embeddings.
            guidance_scale: forwarded to the UNet module.
            total_timesteps: iterable of scheduler timesteps.
            dtype: torch dtype used for model inputs.
            cpu_scheduling: when true the UNet output is copied to host
                and `scheduler.step(...).prev_sample` is used.
            controlnet_hint: hint forwarded to the ControlNet module.
            controlnet_conditioning_scale: accepted but not used here.
            mask / masked_image_latents: when both are given they are
                concatenated to the latents along dim 1.
            return_all_latents: return every intermediate latent
                (concatenated along dim 0) instead of only the last.

        Returns:
            The final latents, or all latents when `return_all_latents`.
        """
        step_time_sum = 0
        steps_run = 0
        latent_history = [latents]
        text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
        text_embeddings_numpy = text_embeddings.detach().numpy()
        self.load_unet()
        self.load_controlnet()
        for i, t in tqdm(enumerate(total_timesteps)):
            step_start_time = time.time()
            timestep = torch.tensor([t]).to(dtype)
            latent_model_input = self.scheduler.scale_model_input(latents, t)
            if mask is not None and masked_image_latents is not None:
                latent_model_input = torch.cat(
                    [
                        torch.from_numpy(np.asarray(latent_model_input)),
                        mask,
                        masked_image_latents,
                    ],
                    dim=1,
                ).to(dtype)
            if cpu_scheduling:
                latent_model_input = latent_model_input.detach().numpy()

            # The ControlNet call gets a torch tensor even when the UNet
            # input has just been converted to numpy.
            if not torch.is_tensor(latent_model_input):
                latent_model_input_1 = torch.from_numpy(
                    np.asarray(latent_model_input)
                ).to(dtype)
            else:
                latent_model_input_1 = latent_model_input
            control = self.controlnet(
                "forward",
                (
                    latent_model_input_1,
                    timestep,
                    text_embeddings,
                    controlnet_hint,
                ),
                send_to_host=False,
            )
            timestep = timestep.detach().numpy()
            # Profiling Unet.
            profile_device = start_profiling(file_path="unet.rdc")
            # TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
            noise_pred = self.unet(
                "forward",
                (
                    latent_model_input,
                    timestep,
                    text_embeddings_numpy,
                    guidance_scale,
                    # The UNet takes the 13 ControlNet outputs as
                    # separate positional inputs.
                    *(control[idx] for idx in range(13)),
                ),
                send_to_host=False,
            )
            end_profiling(profile_device)

            if cpu_scheduling:
                noise_pred = torch.from_numpy(noise_pred.to_host())
                latents = self.scheduler.step(
                    noise_pred, t, latents
                ).prev_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents)

            latent_history.append(latents)
            step_time = (time.time() - step_start_time) * 1000
            #  self.log += (
            #      f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
            #  )
            step_time_sum += step_time
            steps_run += 1

        if self.ondemand:
            self.unload_unet()
            self.unload_controlnet()
        # Guard against an empty timestep schedule (no division by zero).
        avg_step_time = step_time_sum / steps_run if steps_run else 0.0
        self.log += f"\nAverage step time: {avg_step_time}ms/it"

        if not return_all_latents:
            return latents
        all_latents = torch.cat(latent_history, dim=0)
        return all_latents

    def generate_images(
        self,
        prompts,
        neg_prompts,
        image,
        batch_size,
        height,
        width,
        num_inference_steps,
        strength,
        guidance_scale,
        seed,
        max_length,
        dtype,
        use_base_vae,
        cpu_scheduling,
        use_stencil,
    ):
        """Generate images guided by a ControlNet hint built from `image`.

        `strength` is accepted but not used by this method. Seeds outside
        [0, uint32 max) are replaced by a random seed.

        Returns:
            A list of decoded PIL images.
        """
        # Control Embedding check & conversion
        # TODO: 1. Change `num_images_per_prompt`.
        controlnet_hint = controlnet_hint_conversion(
            image, use_stencil, height, width, dtype, num_images_per_prompt=1
        )
        # prompts and negative prompts must be a list.
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(neg_prompts, str):
            neg_prompts = [neg_prompts]

        prompts = prompts * batch_size
        neg_prompts = neg_prompts * batch_size

        # Seed the generator for the initial latent noise; out-of-range
        # seeds are swapped for a random one.
        uint32_info = np.iinfo(np.uint32)
        uint32_min, uint32_max = uint32_info.min, uint32_info.max
        if seed < uint32_min or seed >= uint32_max:
            seed = randint(uint32_min, uint32_max)
        generator = torch.manual_seed(seed)

        # Get text embeddings with weight emphasis from prompts
        text_embeddings = self.encode_prompts_weight(
            prompts, neg_prompts, max_length
        )

        # guidance scale as a float32 tensor.
        guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

        # Prepare initial latent.
        init_latents = self.prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
            num_inference_steps=num_inference_steps,
            dtype=dtype,
        )
        final_timesteps = self.scheduler.timesteps

        # Get Image latents
        latents = self.produce_stencil_latents(
            latents=init_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            total_timesteps=final_timesteps,
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
            controlnet_hint=controlnet_hint,
        )

        # Img latents -> PIL images
        all_imgs = []
        self.load_vae()
        for i in tqdm(range(0, latents.shape[0], batch_size)):
            imgs = self.decode_latents(
                latents=latents[i : i + batch_size],
                use_base_vae=use_base_vae,
                cpu_scheduling=cpu_scheduling,
            )
            all_imgs.extend(imgs)
        if self.ondemand:
            self.unload_vae()

        return all_imgs

View File

@@ -0,0 +1,144 @@
import torch
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class Text2ImagePipeline(StableDiffusionPipeline):
    """Plain text-to-image pipeline built on StableDiffusionPipeline."""

    def __init__(
        self,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            KDPM2DiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
            DEISMultistepScheduler,
        ],
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
        ondemand: bool,
    ):
        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)

    def prepare_latents(
        self,
        batch_size,
        height,
        width,
        generator,
        num_inference_steps,
        dtype,
    ):
        """Draw the starting noise and set up the scheduler.

        The noise has shape (batch_size, 4, height // 8, width // 8), is
        sampled in float32, cast to `dtype`, and multiplied by the
        scheduler's `init_noise_sigma`.
        """
        latent_shape = (batch_size, 4, height // 8, width // 8)
        noise = torch.randn(
            latent_shape, generator=generator, dtype=torch.float32
        ).to(dtype)

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.is_scale_input_called = True
        return noise * self.scheduler.init_noise_sigma

    def generate_images(
        self,
        prompts,
        neg_prompts,
        batch_size,
        height,
        width,
        num_inference_steps,
        guidance_scale,
        seed,
        max_length,
        dtype,
        use_base_vae,
        cpu_scheduling,
    ):
        """Generate images from text prompts and return them as a list.

        Prompts given as a single string are wrapped in a list, and both
        prompt lists are repeated `batch_size` times. A seed outside
        [0, uint32 max) is replaced with a random one.
        """
        # Both prompt arguments are normalised to lists.
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(neg_prompts, str):
            neg_prompts = [neg_prompts]
        prompts = prompts * batch_size
        neg_prompts = neg_prompts * batch_size

        # Seed generator for the initial latent noise; also handles
        # out-of-range seeds.
        # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
        seed_limits = np.iinfo(np.uint32)
        if seed < seed_limits.min or seed >= seed_limits.max:
            seed = randint(seed_limits.min, seed_limits.max)
        generator = torch.manual_seed(seed)

        # Initial latents.
        start_latents = self.prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
            num_inference_steps=num_inference_steps,
            dtype=dtype,
        )

        # Text embeddings with weight emphasis from the prompts.
        embeddings = self.encode_prompts_weight(
            prompts, neg_prompts, max_length
        )

        # Guidance scale as a float32 tensor.
        guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

        # Denoise.
        latents = self.produce_img_latents(
            latents=start_latents,
            text_embeddings=embeddings,
            guidance_scale=guidance_scale,
            total_timesteps=self.scheduler.timesteps,
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
        )

        # Latents -> PIL images, one batch at a time.
        self.load_vae()
        decoded = []
        for start in range(0, latents.shape[0], batch_size):
            decoded.extend(
                self.decode_latents(
                    latents=latents[start : start + batch_size],
                    use_base_vae=use_base_vae,
                    cpu_scheduling=cpu_scheduling,
                )
            )
        if self.ondemand:
            self.unload_vae()

        return decoded

View File

@@ -0,0 +1,326 @@
import inspect
import torch
import time
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_IDLE,
SD_STATE_CANCEL,
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from PIL import Image
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def preprocess(image):
    """Convert the input image(s) into a single torch tensor batch.

    * A tensor is returned unchanged.
    * A PIL image, or a list of PIL images, is resized so that width and
      height are multiples of 64 (rounded down, based on the first
      image), stacked, scaled to [-1, 1] and returned in
      (batch, channel, height, width) layout as float32.
    * A list of tensors is concatenated along dim 0.
    * Anything else is returned as received.
    """
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, Image.Image):
        width, height = first.size
        # round both sides down to an integer multiple of 64
        width -= width % 64
        height -= height % 64

        frames = [
            np.array(pic.resize((width, height)))[None, :] for pic in image
        ]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        return torch.from_numpy(2.0 * batch - 1.0)

    if isinstance(first, torch.Tensor):
        return torch.cat(image, dim=0)
    return image
class UpscalerPipeline(StableDiffusionPipeline):
    """Upscaling pipeline: denoises latents conditioned on a noised image.

    Besides the main scheduler it keeps a second `low_res_scheduler`
    whose `add_noise` is applied to the input image before denoising.
    """

    def __init__(
        self,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
            DEISMultistepScheduler,
        ],
        low_res_scheduler: Union[
            DDIMScheduler,
            DDPMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
            DEISMultistepScheduler,
        ],
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
        ondemand: bool,
    ):
        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
        self.low_res_scheduler = low_res_scheduler
        self.status = SD_STATE_IDLE

    def prepare_extra_step_kwargs(self, generator, eta):
        """Return the optional kwargs accepted by `self.scheduler.step`.

        `eta` and `generator` are included only when the scheduler's
        `step` signature has a parameter of that name.
        """
        # Inspect the signature once and reuse it for both checks.
        step_params = set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def decode_latents(self, latents, use_base_vae, cpu_scheduling):
        """Decode latents with the VAE module into a list of PIL images.

        `use_base_vae` is accepted for interface compatibility but is not
        consulted here. The VAE output is multiplied by 255, rounded and
        converted to uint8 in (batch, height, width, channel) layout.
        """
        latents = 1 / 0.08333 * (latents.float())
        latents_numpy = latents
        if cpu_scheduling:
            latents_numpy = latents.detach().numpy()

        profile_device = start_profiling(file_path="vae.rdc")
        vae_start = time.time()
        images = self.vae("forward", (latents_numpy,))
        vae_inf_time = (time.time() - vae_start) * 1000
        end_profiling(profile_device)
        self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"

        images = torch.from_numpy(images)
        images = (images.detach().cpu() * 255.0).numpy()
        images = images.round()

        images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
        pil_images = [Image.fromarray(image) for image in images.numpy()]
        return pil_images

    def prepare_latents(
        self,
        batch_size,
        height,
        width,
        generator,
        num_inference_steps,
        dtype,
    ):
        """Sample initial latents and configure the scheduler timesteps.

        Unlike the other pipelines in this package, the latent shape is
        (batch_size, 4, height, width) with no division by 8.
        """
        latents = torch.randn(
            (
                batch_size,
                4,
                height,
                width,
            ),
            generator=generator,
            dtype=torch.float32,
        ).to(dtype)

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.is_scale_input_called = True
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def produce_img_latents(
        self,
        latents,
        image,
        text_embeddings,
        guidance_scale,
        noise_level,
        total_timesteps,
        dtype,
        cpu_scheduling,
        extra_step_kwargs,
        return_all_latents=False,
    ):
        """Run the denoising loop with classifier-free guidance.

        The latents are duplicated, scaled by the scheduler and
        concatenated with `image` along dim 1 before being passed to the
        UNet. The UNet output is split into two chunks (unconditional,
        text) that are combined using `guidance_scale`. The loop stops
        early when `self.status` becomes SD_STATE_CANCEL.

        Returns:
            The final latents, or every intermediate latent concatenated
            along dim 0 when `return_all_latents` is true.
        """
        step_time_sum = 0
        steps_run = 0
        latent_history = [latents]
        text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
        text_embeddings_numpy = text_embeddings.detach().numpy()
        self.status = SD_STATE_IDLE
        self.load_unet()
        for i, t in tqdm(enumerate(total_timesteps)):
            step_start_time = time.time()
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t
            )
            latent_model_input = torch.cat([latent_model_input, image], dim=1)
            timestep = torch.tensor([t]).to(dtype).detach().numpy()
            if cpu_scheduling:
                latent_model_input = latent_model_input.detach().numpy()

            # Profiling Unet.
            profile_device = start_profiling(file_path="unet.rdc")
            noise_pred = self.unet(
                "forward",
                (
                    latent_model_input,
                    timestep,
                    text_embeddings_numpy,
                    noise_level,
                ),
            )
            end_profiling(profile_device)
            noise_pred = torch.from_numpy(noise_pred)
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            if cpu_scheduling:
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                ).prev_sample
            else:
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                )

            latent_history.append(latents)
            step_time = (time.time() - step_start_time) * 1000
            #  self.log += (
            #      f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
            #  )
            step_time_sum += step_time
            steps_run += 1

            if self.status == SD_STATE_CANCEL:
                break

        if self.ondemand:
            self.unload_unet()
        # Average over the steps that actually ran: the loop may stop
        # early on cancel, and the schedule may be empty.
        avg_step_time = step_time_sum / steps_run if steps_run else 0.0
        self.log += f"\nAverage step time: {avg_step_time}ms/it"

        if not return_all_latents:
            return latents
        all_latents = torch.cat(latent_history, dim=0)
        return all_latents

    def generate_images(
        self,
        prompts,
        neg_prompts,
        image,
        batch_size,
        height,
        width,
        num_inference_steps,
        noise_level,
        guidance_scale,
        seed,
        max_length,
        dtype,
        use_base_vae,
        cpu_scheduling,
    ):
        """Upscale `image` guided by the prompts; returns PIL images.

        The `height` / `width` arguments are overwritten with the spatial
        size of the preprocessed image. Seeds outside [0, uint32 max) are
        replaced by a random seed.
        """
        # prompts and negative prompts must be a list.
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(neg_prompts, str):
            neg_prompts = [neg_prompts]

        prompts = prompts * batch_size
        neg_prompts = neg_prompts * batch_size

        # Seed the generator for the initial latent noise; out-of-range
        # seeds are swapped for a random one.
        # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
        uint32_info = np.iinfo(np.uint32)
        uint32_min, uint32_max = uint32_info.min, uint32_info.max
        if seed < uint32_min or seed >= uint32_max:
            seed = randint(uint32_min, uint32_max)
        generator = torch.manual_seed(seed)

        # Get text embeddings with weight emphasis from prompts
        text_embeddings = self.encode_prompts_weight(
            prompts, neg_prompts, max_length
        )

        # 4. Preprocess image
        image = preprocess(image).to(dtype)

        # 5. Add noise to image
        noise_level = torch.tensor([noise_level], dtype=torch.long)
        noise = torch.randn(
            image.shape,
            generator=generator,
        ).to(dtype)
        image = self.low_res_scheduler.add_noise(image, noise, noise_level)
        # Duplicate for the unconditional / text halves of the batch.
        image = torch.cat([image] * 2)
        noise_level = torch.cat([noise_level] * image.shape[0])

        height, width = image.shape[2:]

        # Get initial latents
        init_latents = self.prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
            num_inference_steps=num_inference_steps,
            dtype=dtype,
        )

        eta = 0.0
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # guidance scale as a float32 tensor.
        # guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

        # Get Image latents
        latents = self.produce_img_latents(
            latents=init_latents,
            image=image,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            noise_level=noise_level,
            total_timesteps=self.scheduler.timesteps,
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
            extra_step_kwargs=extra_step_kwargs,
        )

        # Img latents -> PIL images
        all_imgs = []
        self.load_vae()
        for i in tqdm(range(0, latents.shape[0], batch_size)):
            imgs = self.decode_latents(
                latents=latents[i : i + batch_size],
                use_base_vae=use_base_vae,
                cpu_scheduling=cpu_scheduling,
            )
            all_imgs.extend(imgs)
        if self.ondemand:
            self.unload_vae()

        return all_imgs

View File

@@ -0,0 +1,842 @@
import torch
import numpy as np
from transformers import CLIPTokenizer
from PIL import Image
from tqdm.auto import tqdm
import time
from typing import Union
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
get_clip,
get_unet,
get_tokenizer,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
import sys
# Values stored in `StableDiffusionPipeline.status`. The denoising loop sets
# the status to idle when it starts and breaks out of the loop as soon as the
# status equals the cancel value.
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
class StableDiffusionPipeline:
    """Base class for the SHARK Stable Diffusion pipelines.

    Holds the scheduler and lazily loaded CLIP / UNet / VAE modules, and
    provides prompt encoding, the denoising loop and latent decoding.
    Timing information is appended to the `log` string attribute.
    """

    def __init__(
        self,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            KDPM2DiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
            DEISMultistepScheduler,
        ],
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
        ondemand: bool,
    ):
        self.vae = None
        self.text_encoder = None
        self.unet = None
        self.model_max_length = 77
        self.scheduler = scheduler
        # TODO: Implement using logging python utility.
        self.log = ""
        self.status = SD_STATE_IDLE
        self.sd_model = sd_model
        self.import_mlir = import_mlir
        self.use_lora = use_lora
        self.ondemand = ondemand
        # TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
        try:
            self.tokenizer = get_tokenizer()
        except Exception:
            # `Exception` rather than a bare `except`, so that
            # KeyboardInterrupt / SystemExit are not swallowed. Loading
            # the UNet once is the workaround that lets the second
            # get_tokenizer() call succeed.
            self.load_unet()
            self.unload_unet()
            self.tokenizer = get_tokenizer()

    def load_clip(self):
        """Load the text encoder unless it is already loaded.

        With `import_mlir` or `use_lora` set the module is built from
        `sd_model`; otherwise `get_clip()` is tried first and `sd_model`
        is the fallback when that raises.
        """
        if self.text_encoder is not None:
            return

        if self.import_mlir or self.use_lora:
            if not self.import_mlir:
                print(
                    "Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
                )
            self.text_encoder = self.sd_model.clip()
        else:
            try:
                self.text_encoder = get_clip()
            except Exception as e:
                print(e)
                print("download pipeline failed, falling back to import_mlir")
                self.text_encoder = self.sd_model.clip()

    def unload_clip(self):
        """Drop the text encoder so it can be reloaded later."""
        del self.text_encoder
        self.text_encoder = None

    def load_unet(self):
        """Load the UNet unless it is already loaded (see load_clip)."""
        if self.unet is not None:
            return

        if self.import_mlir or self.use_lora:
            self.unet = self.sd_model.unet()
        else:
            try:
                self.unet = get_unet()
            except Exception as e:
                print(e)
                print("download pipeline failed, falling back to import_mlir")
                self.unet = self.sd_model.unet()

    def unload_unet(self):
        """Drop the UNet so it can be reloaded later."""
        del self.unet
        self.unet = None

    def load_vae(self):
        """Load the VAE unless it is already loaded (see load_clip)."""
        if self.vae is not None:
            return

        if self.import_mlir or self.use_lora:
            self.vae = self.sd_model.vae()
        else:
            try:
                self.vae = get_vae()
            except Exception as e:
                print(e)
                print("download pipeline failed, falling back to import_mlir")
                self.vae = self.sd_model.vae()

    def unload_vae(self):
        """Drop the VAE so it can be reloaded later."""
        del self.vae
        self.vae = None

    def encode_prompts(self, prompts, neg_prompts, max_length):
        """Tokenize and encode prompts without attention weighting.

        Negative-prompt token ids are placed before the prompt token ids
        in the batch that is passed to the text encoder.

        Returns:
            Whatever the text encoder module returns for that batch.
        """
        # Tokenize text and get embeddings
        text_input = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        # Get unconditional embeddings as well
        uncond_input = self.tokenizer(
            neg_prompts,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])

        self.load_clip()
        clip_inf_start = time.time()
        text_embeddings = self.text_encoder("forward", (text_input,))
        clip_inf_time = (time.time() - clip_inf_start) * 1000
        if self.ondemand:
            self.unload_clip()
        self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"

        return text_embeddings

    def decode_latents(self, latents, use_base_vae, cpu_scheduling):
        """Decode latents with the VAE module into a list of PIL images.

        With `use_base_vae` the latents are divided by 0.18215 before the
        call and the output is multiplied by 255 and rounded afterwards;
        otherwise the VAE output is converted to uint8 as it is.
        """
        if use_base_vae:
            latents = 1 / 0.18215 * latents

        latents_numpy = latents
        if cpu_scheduling:
            latents_numpy = latents.detach().numpy()

        profile_device = start_profiling(file_path="vae.rdc")
        vae_start = time.time()
        images = self.vae("forward", (latents_numpy,))
        vae_inf_time = (time.time() - vae_start) * 1000
        end_profiling(profile_device)
        self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"

        if use_base_vae:
            images = torch.from_numpy(images)
            images = (images.detach().cpu() * 255.0).numpy()
            images = images.round()

        images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
        pil_images = [Image.fromarray(image) for image in images.numpy()]
        return pil_images

    def produce_img_latents(
        self,
        latents,
        text_embeddings,
        guidance_scale,
        total_timesteps,
        dtype,
        cpu_scheduling,
        mask=None,
        masked_image_latents=None,
        return_all_latents=False,
    ):
        """Run the denoising loop over `total_timesteps`.

        Args:
            latents: initial latents.
            text_embeddings: numpy array of prompt embeddings.
            guidance_scale: forwarded to the UNet module.
            total_timesteps: iterable of scheduler timesteps.
            dtype: torch dtype used for model inputs.
            cpu_scheduling: when true the UNet output is copied to host
                and `scheduler.step(...).prev_sample` is used.
            mask / masked_image_latents: when both are given they are
                concatenated to the latents along dim 1.
            return_all_latents: return every intermediate latent
                (concatenated along dim 0) instead of only the last.

        The loop stops early when `self.status` is SD_STATE_CANCEL.

        Returns:
            The final latents, or all latents when `return_all_latents`.
        """
        self.status = SD_STATE_IDLE
        step_time_sum = 0
        steps_run = 0
        latent_history = [latents]
        text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
        text_embeddings_numpy = text_embeddings.detach().numpy()
        self.load_unet()
        for i, t in tqdm(enumerate(total_timesteps)):
            step_start_time = time.time()
            timestep = torch.tensor([t]).to(dtype).detach().numpy()
            latent_model_input = self.scheduler.scale_model_input(latents, t)
            if mask is not None and masked_image_latents is not None:
                latent_model_input = torch.cat(
                    [
                        torch.from_numpy(np.asarray(latent_model_input)),
                        mask,
                        masked_image_latents,
                    ],
                    dim=1,
                ).to(dtype)
            if cpu_scheduling:
                latent_model_input = latent_model_input.detach().numpy()

            # Profiling Unet.
            profile_device = start_profiling(file_path="unet.rdc")
            noise_pred = self.unet(
                "forward",
                (
                    latent_model_input,
                    timestep,
                    text_embeddings_numpy,
                    guidance_scale,
                ),
                send_to_host=False,
            )
            end_profiling(profile_device)

            if cpu_scheduling:
                noise_pred = torch.from_numpy(noise_pred.to_host())
                latents = self.scheduler.step(
                    noise_pred, t, latents
                ).prev_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents)

            latent_history.append(latents)
            step_time = (time.time() - step_start_time) * 1000
            #  self.log += (
            #      f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
            #  )
            step_time_sum += step_time
            steps_run += 1

            if self.status == SD_STATE_CANCEL:
                break

        if self.ondemand:
            self.unload_unet()
        # Average over the steps that actually ran: the loop may stop
        # early on cancel, and the schedule may be empty.
        avg_step_time = step_time_sum / steps_run if steps_run else 0.0
        self.log += f"\nAverage step time: {avg_step_time}ms/it"

        if not return_all_latents:
            return latents
        all_latents = torch.cat(latent_history, dim=0)
        return all_latents

    @classmethod
    def from_pretrained(
        cls,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            KDPM2DiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
            DEISMultistepScheduler,
        ],
        import_mlir: bool,
        model_id: str,
        ckpt_loc: str,
        custom_vae: str,
        precision: str,
        max_length: int,
        batch_size: int,
        height: int,
        width: int,
        use_base_vae: bool,
        use_tuned: bool,
        ondemand: bool,
        low_cpu_mem_usage: bool = False,
        debug: bool = False,
        use_stencil: str = None,
        use_lora: str = "",
        ddpm_scheduler: DDPMScheduler = None,
        use_quantize=None,
    ):
        """Build a `SharkifyStableDiffusionModel` and wrap it in `cls`.

        The concrete pipeline is recognised by class name: the inpaint /
        outpaint and upscaler pipelines set the matching model flags, and
        the upscaler additionally receives `ddpm_scheduler` as its
        low-resolution scheduler. The process exits when a
        `StencilPipeline` is requested without `import_mlir` or
        `use_lora`.
        """
        if (
            not import_mlir
            and not use_lora
            and cls.__name__ == "StencilPipeline"
        ):
            sys.exit("StencilPipeline not supported with SharkTank currently.")

        is_inpaint = cls.__name__ in [
            "InpaintPipeline",
            "OutpaintPipeline",
        ]
        is_upscaler = cls.__name__ in ["UpscalerPipeline"]

        sd_model = SharkifyStableDiffusionModel(
            model_id,
            ckpt_loc,
            custom_vae,
            precision,
            max_len=max_length,
            batch_size=batch_size,
            height=height,
            width=width,
            use_base_vae=use_base_vae,
            use_tuned=use_tuned,
            low_cpu_mem_usage=low_cpu_mem_usage,
            debug=debug,
            is_inpaint=is_inpaint,
            is_upscaler=is_upscaler,
            use_stencil=use_stencil,
            use_lora=use_lora,
            use_quantize=use_quantize,
        )

        if is_upscaler:
            return cls(
                scheduler,
                ddpm_scheduler,
                sd_model,
                import_mlir,
                use_lora,
                ondemand,
            )

        return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)

    # #####################################################
    # Implements text embeddings with weights from prompts
    # https://huggingface.co/AlanB/lpw_stable_diffusion_mod
    # #####################################################
    def encode_prompts_weight(
        self,
        prompt,
        negative_prompt,
        model_max_length,
        do_classifier_free_guidance=True,
        max_embeddings_multiples=1,
        num_images_per_prompt=1,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            model_max_length (int):
                SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not,
                SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
            max_embeddings_multiples (`int`, *optional*, defaults to `1`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
                SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
                SHARK: num_images_per_prompt is not used (defaulted to 1)

        Returns:
            The embeddings as a numpy array; with classifier free guidance
            the negative-prompt embeddings come first.

        Raises:
            ValueError: if `negative_prompt` and `prompt` have different
                batch sizes.
        """

        # SHARK: Save model_max_length, load the clip and init inference time
        self.model_max_length = model_max_length
        self.load_clip()
        clip_inf_start = time.time()

        batch_size = len(prompt) if isinstance(prompt, list) else 1

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
            pipe=self,
            prompt=prompt,
            uncond_prompt=negative_prompt
            if do_classifier_free_guidance
            else None,
            max_embeddings_multiples=max_embeddings_multiples,
        )
        # SHARK: we are not using num_images_per_prompt
        # bs_embed, seq_len, _ = text_embeddings.shape
        # text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        # text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # SHARK: we are not using num_images_per_prompt
            # bs_embed, seq_len, _ = uncond_embeddings.shape
            # uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            # uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # SHARK: Report clip inference time
        clip_inf_time = (time.time() - clip_inf_start) * 1000
        if self.ondemand:
            self.unload_clip()
        self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"

        return text_embeddings.numpy()
from typing import List, Optional, Union
import re
# Tokenizer for the prompt attention mini-language. Compiled in verbose mode,
# so the line breaks inside the pattern are ignored. Alternatives are tried in
# order: escaped brackets/backslash first, then bare brackets, then an
# explicit ":<weight>)" close, then plain text, then a lone ":".
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.

    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    Brackets that are never closed apply their default multiplier to everything
    that follows them. Neighbouring fragments that end up with the same weight
    are merged, and an empty prompt yields ``[["", 1.0]]``.

    Args:
        text (`str`): the prompt to parse.

    Returns:
        `list` of two-element lists ``[fragment, weight]`` in prompt order.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention(r'\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    res = []
    # Each stack entry is the index in `res` where the bracket was opened.
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every fragment emitted since the bracket was opened.
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        # `token` (not `text`) so the function argument is not shadowed.
        token = m.group(0)
        weight = m.group(1)

        if token.startswith("\\"):
            # Escaped character: keep it literally, without the backslash.
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            # "(...:1.3)" closes the innermost round bracket with an explicit weight.
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Plain text, or a closing token with no matching opener.
            res.append([token, 1.0])

    # Unclosed brackets behave as if they were closed at the end of the prompt.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
def get_prompts_with_weights(
    pipe: StableDiffusionPipeline, prompt: List[str], max_length: int
):
    r"""
    Convert each prompt into token ids plus one attention weight per token.

    Every prompt is split by `parse_prompt_attention`; each fragment is run
    through `pipe.tokenizer`, its first and last ids are dropped, and the
    fragment's weight is repeated once per remaining id. No padding, starting
    or ending token is added here.

    Args:
        pipe: object exposing a callable `tokenizer` whose result has `input_ids`.
        prompt (`List[str]`): prompts to tokenize.
        max_length (`int`): maximum number of ids kept per prompt.

    Returns:
        `(tokens, weights)`: two parallel lists with one entry per prompt.
        A notice is printed if any prompt had to be cut to `max_length`.
    """
    tokens = []
    weights = []
    truncated = False

    for prompt_text in prompt:
        ids = []
        id_weights = []
        for fragment, fragment_weight in parse_prompt_attention(prompt_text):
            # tokenize the fragment, dropping the first and last id
            fragment_ids = pipe.tokenizer(fragment).input_ids[1:-1]
            ids += fragment_ids
            id_weights += [fragment_weight] * len(fragment_ids)
            # once over the limit, later fragments cannot fit; the check
            # after the loop records the truncation
            if len(ids) > max_length:
                break

        if len(ids) > max_length:
            truncated = True
            ids = ids[:max_length]
            id_weights = id_weights[:max_length]

        tokens.append(ids)
        weights.append(id_weights)

    if truncated:
        print(
            "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
        )
    return tokens, weights
def pad_tokens_and_weights(
    tokens,
    weights,
    max_length,
    bos,
    eos,
    no_boseos_middle=True,
    chunk_length=77,
):
    r"""
    Frame every token row with `bos`/`eos` and pad the weight rows with 1.0.

    Token rows become `[bos] + row + [eos] * k` with total length `max_length`.
    With `no_boseos_middle` the weight rows are padded the same way. Otherwise
    the weights are laid out per chunk of `chunk_length`, with a 1.0 for the
    starting and ending token of each chunk, and padded with 1.0 up to
    `chunks * chunk_length`.

    Both input lists are updated in place and also returned.
    """
    payload_per_chunk = chunk_length - 2
    chunk_count = (max_length - 2) // payload_per_chunk
    if no_boseos_middle:
        padded_weight_len = max_length
    else:
        padded_weight_len = chunk_count * chunk_length

    for row in range(len(tokens)):
        tail = max_length - 1 - len(tokens[row])
        tokens[row] = [bos] + tokens[row] + [eos] * tail

        row_weights = weights[row]
        if no_boseos_middle:
            weights[row] = (
                [1.0] + row_weights + [1.0] * (max_length - 1 - len(row_weights))
            )
            continue

        if len(row_weights) == 0:
            weights[row] = [1.0] * padded_weight_len
            continue

        laid_out = []
        for chunk in range(chunk_count):
            lo = chunk * payload_per_chunk
            hi = min(len(row_weights), lo + payload_per_chunk)
            laid_out.append(1.0)  # starting token of this chunk
            laid_out.extend(row_weights[lo:hi])
            laid_out.append(1.0)  # ending token of this chunk
        laid_out.extend([1.0] * (padded_weight_len - len(laid_out)))
        weights[row] = laid_out

    return tokens, weights
def get_unweighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    text_input: torch.Tensor,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    Encode padded token ids with the text encoder, one chunk at a time if needed.

    When the token length spans more than one encoder window, the ids are split
    into chunks of `chunk_length`, each chunk is encoded separately, and the
    per-chunk results are joined along the token axis.

    Args:
        pipe: object whose `text_encoder("forward", (ids,))` runs the encoder.
            NOTE(review): the code indexes the result with `[0]` and passes it
            to `torch.from_numpy`, so it is presumably a NumPy array whose
            first entry is (tokens, hidden) - confirm against the runner.
        text_input (`torch.Tensor`): token ids, indexed as (row, position).
        chunk_length (`int`): encoder window size including start/end tokens.
        no_boseos_middle (`bool`, *optional*, defaults to `True`):
            when chunking, drop the start/end token embeddings that fall in the
            middle of the joined sequence.

    Returns:
        `torch.Tensor`: the embeddings with a new leading axis of size 1.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[
                :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
            ].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            text_input_chunk[:, -1] = text_input[0, -1]
            # SHARK: duplicate the input, the Shark runner expects tokens and neg tokens
            formatted_text_input_chunk = torch.cat(
                [text_input_chunk, text_input_chunk]
            )
            text_embedding = pipe.text_encoder(
                "forward", (formatted_text_input_chunk,)
            )[0]

            if no_boseos_middle:
                # Trim along the token axis (second to last). The previous
                # `[:, ...]` slices hit the hidden axis once `[0]` had removed
                # the batch axis. `...` keeps this correct for both (tokens,
                # hidden) and (batch, tokens, hidden) results.
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[..., :-1, :]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[..., 1:, :]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[..., 1:-1, :]

            text_embeddings.append(text_embedding)
        # SHARK: join the chunks and convert the result to a tensor.
        # Concatenating the list directly gives the same result as stacking
        # first when all chunks have equal length, and also works after
        # trimming, when the chunks no longer have equal length.
        text_embeddings_np = np.concatenate(text_embeddings, axis=0)
        text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
    else:
        # SHARK: duplicate the input, the Shark runner expects tokens and neg tokens,
        # then convert the result to a tensor
        formatted_text_input = torch.cat([text_input, text_input])
        text_embeddings = pipe.text_encoder(
            "forward", (formatted_text_input,)
        )[0]
        text_embeddings = torch.from_numpy(text_embeddings)[None, :]
    return text_embeddings
def get_weighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
):
    r"""
    Build text embeddings in which bracketed parts of the prompt are re-weighted.

    For example, in 'A (very beautiful) masterpiece' the embeddings of the
    tokens for 'very beautiful' are multiplied by 1.1. After weighting, each
    embedding is rescaled so that its mean matches the mean before weighting.

    Args:
        pipe (`StableDiffusionPipeline`):
            Provides the tokenizer, the text encoder and `model_max_length`.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        uncond_prompt (`str` or `List[str]`, *optional*):
            The unconditional (negative) prompt or prompts. When given, its
            embeddings are computed as well and returned as the second item.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            Upper bound on the prompt length, in multiples of the text encoder's
            maximum length.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            When the tokens span several encoder windows, whether to drop the
            starting and ending tokens of the windows in the middle.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.

    Returns:
        `(text_embeddings, uncond_embeddings)`; the second item is `None` when
        no `uncond_prompt` was given.
    """
    encoder_window = pipe.model_max_length
    usable_per_window = encoder_window - 2
    has_uncond = uncond_prompt is not None

    token_budget = usable_per_window * max_embeddings_multiples + 2

    if isinstance(prompt, str):
        prompt = [prompt]
    if has_uncond and isinstance(uncond_prompt, str):
        uncond_prompt = [uncond_prompt]

    uncond_tokens = uncond_weights = None
    if skip_parsing:

        def tokenize_plain(texts):
            # Tokenize without bracket parsing; every token gets weight 1.0.
            rows = [
                ids[1:-1]
                for ids in pipe.tokenizer(
                    texts, max_length=token_budget, truncation=True
                ).input_ids
            ]
            return rows, [[1.0] * len(ids) for ids in rows]

        prompt_tokens, prompt_weights = tokenize_plain(prompt)
        if has_uncond:
            uncond_tokens, uncond_weights = tokenize_plain(uncond_prompt)
    else:
        prompt_tokens, prompt_weights = get_prompts_with_weights(
            pipe, prompt, token_budget - 2
        )
        if has_uncond:
            uncond_tokens, uncond_weights = get_prompts_with_weights(
                pipe, uncond_prompt, token_budget - 2
            )

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    longest = max([len(ids) for ids in prompt_tokens])
    if has_uncond:
        longest = max(longest, max([len(ids) for ids in uncond_tokens]))

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (longest - 1) // usable_per_window + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    padded_length = usable_per_window * max_embeddings_multiples + 2

    # pad the tokens and weights, then move the tokens into CPU tensors
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        padded_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=encoder_window,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
    if has_uncond:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            padded_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=encoder_window,
        )
        uncond_tokens = torch.tensor(
            uncond_tokens, dtype=torch.long, device="cpu"
        )

    # run the text encoder: prompt first, then the unconditional prompt
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        encoder_window,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(
        prompt_weights, dtype=torch.float, device="cpu"
    )
    uncond_embeddings = None
    if has_uncond:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            encoder_window,
            no_boseos_middle=no_boseos_middle,
        )
        uncond_weights = torch.tensor(
            uncond_weights, dtype=torch.float, device="cpu"
        )

    def reweight_in_place(embeddings, token_weights):
        # Scale each token embedding by its weight, then rescale the whole
        # embedding so its mean equals the mean before weighting.
        # TODO: should we normalize by chunk or in a whole (current implementation)?
        mean_before = (
            embeddings.float().mean(axis=[-2, -1]).to(embeddings.dtype)
        )
        embeddings *= token_weights.unsqueeze(-1)
        mean_after = (
            embeddings.float().mean(axis=[-2, -1]).to(embeddings.dtype)
        )
        embeddings *= (mean_before / mean_after).unsqueeze(-1).unsqueeze(-1)

    if not (skip_parsing or skip_weighting):
        reweight_in_place(text_embeddings, prompt_weights)
        if has_uncond:
            reweight_in_place(uncond_embeddings, uncond_weights)

    return text_embeddings, uncond_embeddings

View File

@@ -0,0 +1,4 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)

View File

@@ -0,0 +1,66 @@
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
def get_schedulers(model_id):
    """
    Load every supported scheduler from the "scheduler" subfolder of `model_id`.

    Returns a dict mapping the scheduler's display name to its instance. The
    "SharkEulerDiscrete" entry is additionally compiled before returning.
    """
    # (name, class) pairs in the order they are loaded and stored.
    scheduler_table = (
        ("PNDM", PNDMScheduler),
        ("DDPM", DDPMScheduler),
        ("KDPM2Discrete", KDPM2DiscreteScheduler),
        ("LMSDiscrete", LMSDiscreteScheduler),
        ("DDIM", DDIMScheduler),
        ("DPMSolverMultistep", DPMSolverMultistepScheduler),
        ("EulerDiscrete", EulerDiscreteScheduler),
        ("EulerAncestralDiscrete", EulerAncestralDiscreteScheduler),
        ("DEISMultistep", DEISMultistepScheduler),
        ("SharkEulerDiscrete", SharkEulerDiscreteScheduler),
    )
    schedulers = dict()
    for name, scheduler_cls in scheduler_table:
        schedulers[name] = scheduler_cls.from_pretrained(
            model_id,
            subfolder="scheduler",
        )
    schedulers["SharkEulerDiscrete"].compile()
    return schedulers

View File

@@ -0,0 +1,157 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_shark_model,
args,
)
import torch
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
    """
    Euler discrete scheduler whose input scaling and step run as compiled modules.

    `compile()` must be called before `scale_model_input()` or `step()`: it
    sets `self.scaling_model` and `self.step_model`, which those two methods
    invoke.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
    ):
        # All arguments are forwarded unchanged to the parent scheduler.
        super().__init__(
            num_train_timesteps,
            beta_start,
            beta_end,
            beta_schedule,
            trained_betas,
            prediction_type,
        )

    def compile(self):
        """
        Prepare `self.scaling_model` and `self.step_model`.

        With `args.import_mlir` set, both modules are compiled locally from
        small torch modules. Otherwise they are fetched with `get_shark_model`;
        if that fails, `args.import_mlir` is switched on and the local
        compilation path is used instead.
        """
        SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
        BATCH_SIZE = args.batch_size
        device = args.device.split(":", 1)[0].strip()

        # Example inputs used only for tracing/compilation.
        model_input = {
            "euler": {
                "latent": torch.randn(
                    BATCH_SIZE, 4, args.height // 8, args.width // 8
                ),
                "output": torch.randn(
                    BATCH_SIZE, 4, args.height // 8, args.width // 8
                ),
                "sigma": torch.tensor(1).to(torch.float32),
                "dt": torch.tensor(1).to(torch.float32),
            },
        }

        example_latent = model_input["euler"]["latent"]
        example_output = model_input["euler"]["output"]
        if args.precision == "fp16":
            example_latent = example_latent.half()
            example_output = example_output.half()
        example_sigma = model_input["euler"]["sigma"]
        example_dt = model_input["euler"]["dt"]

        class ScalingModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, latent, sigma):
                return latent / ((sigma**2 + 1) ** 0.5)

        class SchedulerStepModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, noise_pred, sigma, latent, dt):
                pred_original_sample = latent - sigma * noise_pred
                derivative = (latent - pred_original_sample) / sigma
                return latent + derivative * dt

        iree_flags = []
        if len(args.iree_vulkan_target_triple) > 0:
            iree_flags.append(
                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
            )
        # Disable bindings fusion to work with moltenVK.
        if sys.platform == "darwin":
            iree_flags.append("-iree-stream-fuse-binding=false")

        def _import(self):
            # Compile both helper modules locally through the FX path.
            scaling_model = ScalingModel()
            self.scaling_model, _ = compile_through_fx(
                model=scaling_model,
                inputs=(example_latent, example_sigma),
                extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
                + args.precision,
                extra_args=iree_flags,
            )

            step_model = SchedulerStepModel()
            self.step_model, _ = compile_through_fx(
                step_model,
                (example_output, example_sigma, example_latent, example_dt),
                extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
                + args.precision,
                extra_args=iree_flags,
            )

        if args.import_mlir:
            _import(self)
        else:
            try:
                self.scaling_model = get_shark_model(
                    SCHEDULER_BUCKET,
                    "euler_scale_model_input_" + args.precision,
                    iree_flags,
                )
                self.step_model = get_shark_model(
                    SCHEDULER_BUCKET,
                    "euler_step_" + args.precision,
                    iree_flags,
                )
            # Catch ordinary failures only: a bare `except:` would also
            # swallow KeyboardInterrupt/SystemExit and start a long local
            # compilation when the user is trying to abort.
            except Exception:
                print(
                    "failed to download model, falling back and using import_mlir"
                )
                args.import_mlir = True
                _import(self)

    def scale_model_input(self, sample, timestep):
        """Scale `sample` for `timestep` by running the compiled scaling module.

        The sigma is looked up at the position of `timestep` in
        `self.timesteps`; the module is invoked with `send_to_host=False`.
        """
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        return self.scaling_model(
            "forward",
            (
                sample,
                sigma,
            ),
            send_to_host=False,
        )

    def step(self, noise_pred, timestep, latent):
        """Advance `latent` by one Euler step using the compiled step module.

        `dt` is the difference between the next sigma and the sigma at the
        position of `timestep`; the module is invoked with `send_to_host=False`.
        """
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        dt = self.sigmas[step_index + 1] - sigma
        return self.step_model(
            "forward",
            (
                noise_pred,
                sigma,
                latent,
                dt,
            ),
            send_to_host=False,
        )

View File

@@ -0,0 +1,40 @@
from apps.stable_diffusion.src.utils.profiler import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.utils.resources import (
prompt_examples,
models_db,
base_models,
opt_flags,
resource_path,
)
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
controlnet_hint_conversion,
get_stencil_model_id,
)
from apps.stable_diffusion.src.utils.utils import (
get_shark_model,
compile_through_fx,
set_iree_runtime_flags,
map_device_to_name_path,
set_init_device_flags,
get_available_devices,
get_opt_flags,
preprocessCKPT,
convert_original_vae,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
get_path_stem,
get_extended_name,
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
clear_all,
save_output_img,
get_generation_text_info,
update_lora_weight,
resize_stencil,
)

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.utils.stable_args import args
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
    """
    Begin profiling on the device named by `args.device`, if enabled.

    Profiling starts only when `args.vulkan_debug_utils` is set and
    `args.device` contains "vulkan".

    Args:
        file_path: passed to the device's `begin_profiling` as `file_path`.
        profiling_mode: passed to the device's `begin_profiling` as `mode`.

    Returns:
        The device on which profiling was started (pass it to `end_profiling`),
        or `None` when profiling is not enabled.
    """
    if args.vulkan_debug_utils and "vulkan" in args.device:
        # Import the submodule explicitly: a plain `import iree` does not
        # guarantee that `iree.runtime` has been loaded.
        import iree.runtime

        print(f"Profiling and saving to {file_path}.")
        vulkan_device = iree.runtime.get_device(args.device)
        vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
        return vulkan_device
    return None
def end_profiling(device):
    """Stop profiling on `device` and return its result; do nothing for a falsy device."""
    if not device:
        return None
    return device.end_profiling()

View File

@@ -0,0 +1,37 @@
import os
import json
import sys
def resource_path(relative_path):
    """Get absolute path to resource, works for dev and for PyInstaller.

    The base directory is `sys._MEIPASS` when that attribute exists, otherwise
    the directory containing this file.
    """
    if hasattr(sys, "_MEIPASS"):
        root = sys._MEIPASS
    else:
        root = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(root, relative_path)
def get_json_file(path):
    """
    Load the JSON resource at `path` (resolved through `resource_path`).

    Returns the parsed content, or an empty list when the file does not exist.
    Prints a notice whenever the result is empty.
    """
    location = resource_path(path)
    content = []
    if os.path.exists(location):
        with open(location, encoding="utf-8") as handle:
            content = json.load(handle)

    if not content:
        print(f"Unable to fetch {path}")
    return content
# TODO: These loads should not live at module level: they run as a side effect
# of every import of this file and read all four JSON resources each time.
# Loaded from resources/prompts.json.
prompt_examples = get_json_file("resources/prompts.json")
# Loaded from resources/model_db.json.
models_db = get_json_file("resources/model_db.json")
# The base_model contains the input configuration for the different
# models and also helps in providing information for the variants.
base_models = get_json_file("resources/base_model.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")

View File

@@ -0,0 +1,296 @@
{
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"vae_upscaler": {
"latents" : {
"shape" : [
"1*batch_size",4,"8*height","8*width"
],
"dtype":"f32"
}
}
},
"unet": {
"stabilityai/stable-diffusion-2-1": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-2-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"runwayml/stable-diffusion-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-x4-upscaler": {
"latents": {
"shape": [
"2*batch_size",
7,
"8*height",
"8*width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"noise_level": {
"shape": [2],
"dtype": "i64"
}
}
},
"stencil_adaptor": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
"stencil_unet": {
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
},
"control1": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control2": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control3": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control4": {
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"control5": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control6": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control7": {
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"control8": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control9": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control10": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control11": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control12": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
}
}
}

View File

@@ -0,0 +1,23 @@
[
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting",
"stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]

View File

@@ -0,0 +1,19 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/nightly"
},
{
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan"
}
]

View File

@@ -0,0 +1,84 @@
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
"fp32": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
}
}

View File

@@ -0,0 +1,8 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]

View File

@@ -0,0 +1,259 @@
import os
import io
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
download_model,
download_public_file,
WORKDIR,
)
from shark.parser import shark_args
from apps.stable_diffusion.src.utils.stable_args import args
def get_device():
    """Return `args.device` without its "://..." suffix, if it has one."""
    if "://" in args.device:
        return args.device.split("://")[0]
    return args.device
def get_device_args():
    """
    Return `(device, flags)` for the current device.

    For "cuda" the flags come from `get_iree_gpu_args()`; for "vulkan" the
    single flag carries `args.iree_vulkan_target_triple`; for any other device
    the flag list is empty.
    """
    device = get_device()
    if device == "cuda":
        from shark.iree_utils.gpu_utils import get_iree_gpu_args

        return device, list(get_iree_gpu_args())
    if device == "vulkan":
        return device, [
            f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
        ]
    return device, []
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
    """
    Download the untuned model selected by `args.annotation_model`.

    The tank bucket and model name are looked up with `get_params` from keys
    built out of the variant/version of `args.hf_model_id`, the precision and,
    for the unet, `args.max_length`. As a side effect,
    `shark_args.local_tank_cache` is set from `args.local_tank_cache`.

    Returns:
        `(mlir_model, model_name)`.

    Raises:
        ValueError: if `args.annotation_model` is neither "unet" nor "vae".
    """
    from apps.stable_diffusion.src.models import (
        get_params,
        get_variant_version,
    )

    variant, version = get_variant_version(args.hf_model_id)

    shark_args.local_tank_cache = args.local_tank_cache
    bucket_key = f"{variant}/untuned"
    if args.annotation_model == "unet":
        model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
    elif args.annotation_model == "vae":
        is_base = "/base" if args.use_base_vae else ""
        model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on `model_key`.
        raise ValueError(
            "load_model_from_tank supports annotation_model 'unet' or 'vae', "
            f"got {args.annotation_model!r}"
        )

    # The third item returned by get_params is not needed here.
    bucket, model_name, _ = get_params(
        bucket_key, model_key, args.annotation_model, "untuned", args.precision
    )
    # Only the first of the four items returned by download_model is used.
    mlir_model, _, _, _ = download_model(
        model_name,
        tank_url=bucket,
        frontend="torch",
    )
    return mlir_model, model_name
# Download the tuned config files from shark_tank
def load_winograd_configs():
    """
    Download the Winograd config for `args.annotation_model` on the current device.

    Returns:
        The local path of the config file, `WORKDIR/configs/<name>.json`.
        Despite the variable name, this is a file path, not a directory.
    """
    device = get_device()
    config_bucket = "gs://shark_tank/sd_tuned/configs/"
    config_name = f"{args.annotation_model}_winograd_{device}.json"
    full_gs_url = config_bucket + config_name
    # makedirs(exist_ok=True) also creates missing parent directories and
    # does not fail if another process creates WORKDIR at the same time.
    os.makedirs(WORKDIR, exist_ok=True)
    winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
    print("Loading Winograd config file from ", winograd_config_dir)
    # NOTE(review): the meaning of the third argument (True) is defined by
    # download_public_file, which is not visible here - confirm.
    download_public_file(full_gs_url, winograd_config_dir, True)
    return winograd_config_dir
# Pick the lowering-config JSON that matches the model version, precision,
# device and device spec, download it from the config bucket, and return the
# local path it was downloaded to (WORKDIR/configs/<config_name>).
# base_model_id: optional model id; when empty it is derived from
# args.ckpt_loc or args.hf_model_id.
def load_lower_configs(base_model_id=None):
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
# Resolve the base model id when the caller did not supply one.
if not base_model_id:
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
# An empty lookup result falls back to the configured HF model id.
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
# Inpainting versions are mapped onto the config names of other versions.
if version == "inpaint_v1":
version = "v1_4"
elif version == "inpaint_v2":
version = "v2_1base"
config_bucket = "gs://shark_tank/sd_tuned_configs/"
device, device_spec_args = get_device_args()
# spec is the text after the last "=" of the last device flag; for vulkan
# only the part before the first "-" is kept.
spec = ""
if device_spec_args:
spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
spec = spec.split("-")[0]
# VAE config names do not include the model version.
if args.annotation_model == "vae":
# No spec, or rdna3/sm_80: use the config name without a spec suffix.
if not spec or spec in ["rdna3", "sm_80"]:
config_name = (
f"{args.annotation_model}_{args.precision}_{device}.json"
)
else:
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
# v2_1 / v2_1base at 768x768 has its own config file.
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
and args.width == 768
):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
# rdna2 config names for these versions also carry the image size.
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
full_gs_url = config_bucket + config_name
# Despite the variable name, this is the path of the config file itself.
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
    """Annotate conv ops of ``input_mlir`` using the Winograd config.

    Args:
        input_mlir: model contents passed to ``model_annotation``.
        winograd_config_dir: path of the Winograd config file.
        model_name: used to build the output file name when
            ``args.save_annotation`` is set.

    Returns:
        bytes: the annotated module serialized as bytecode.
    """
    with create_context() as ctx:
        winograd_model = model_annotation(
            ctx,
            input_contents=input_mlir,
            config_path=winograd_config_dir,
            search_op="conv",
            winograd=True,
        )

        # Serialize while the context used for annotation is still active.
        buffer = io.BytesIO()
        winograd_model.operation.write_bytecode(buffer)
        bytecode = buffer.getvalue()

        if args.save_annotation:
            # Avoid a doubled "_tuned" suffix when the name already ends in it.
            already_tuned = model_name.split("_")[-1] == "tuned"
            out_file_path = (
                os.path.join(
                    args.annotation_output, model_name + "_torch.mlir"
                )
                if already_tuned
                else os.path.join(
                    args.annotation_output, model_name + "_tuned_torch.mlir"
                )
            )
            # The textual form of the module is what gets saved to disk.
            with open(out_file_path, "w") as out_file:
                out_file.write(str(winograd_model))

    return bytecode
def dump_after_mlir(input_mlir, use_winograd):
    """Compile ``input_mlir`` only up to the preprocessing stage.

    Args:
        input_mlir: module contents handed to ``iree.compiler.compile_str``.
        use_winograd: when truthy, the conv2d-to-winograd pass is appended to
            the preprocessing pipeline.

    Returns:
        The result of ``compile_str`` run with ``--compile-to=preprocessing``.
    """
    import iree.compiler as ireec

    device, device_spec_args = get_device_args()

    winograd_pipeline_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
    default_pipeline_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
    preprocess_flag = (
        winograd_pipeline_flag if use_winograd else default_pipeline_flag
    )

    # Device flags come first, followed by the preprocessing-only options.
    compile_args = device_spec_args + [
        preprocess_flag,
        "--compile-to=preprocessing",
    ]
    return ireec.compile_str(
        input_mlir,
        target_backends=[iree_target_map(device)],
        extra_args=compile_args,
    )
# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
    input_mlir, lowering_config_dir, model_name, use_winograd
):
    """Apply the tuned lowering configs to the preprocessed form of a model.

    Args:
        input_mlir: model contents to preprocess and annotate.
        lowering_config_dir: path of the lowering config file.
        model_name: used for logging and for the saved file name.
        use_winograd: forwarded to ``dump_after_mlir``.

    Returns:
        bytes: the annotated module serialized as bytecode.
    """
    # Dump IR after padding/img2col/winograd passes
    preprocessed_ir = dump_after_mlir(input_mlir, use_winograd)
    print("Applying tuned configs on", model_name)

    # Annotate the model with lowering configs in the config file
    with create_context() as ctx:
        tuned_model = model_annotation(
            ctx,
            input_contents=preprocessed_ir,
            config_path=lowering_config_dir,
            search_op="all",
        )

        # Serialize while the context used for annotation is still active.
        buffer = io.BytesIO()
        tuned_model.operation.write_bytecode(buffer)
        bytecode = buffer.getvalue()

        if args.save_annotation:
            # Avoid a doubled "_tuned" suffix when the name already ends in it.
            already_tuned = model_name.split("_")[-1] == "tuned"
            out_file_path = (
                f"{args.annotation_output}/{model_name}_torch.mlir"
                if already_tuned
                else f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
            )
            # The textual form of the module is what gets saved to disk.
            with open(out_file_path, "w") as out_file:
                out_file.write(str(tuned_model))

    return bytecode
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
    """Annotate ``mlir_model`` according to the model type and device.

    * unet on vulkan: Winograd annotation followed by lowering configs.
    * vae on vulkan: Winograd annotation only, skipped (model returned
      unchanged) when the first component of the vulkan target triple
      contains "rdna2".
    * anything else: lowering configs only, without Winograd.

    Args:
        mlir_model: the model to annotate.
        model_name: name forwarded to the annotation helpers.
        base_model_id: optional id forwarded to ``load_lower_configs``.

    Returns:
        The annotated model, or ``mlir_model`` itself in the rdna2 vae case.
    """
    device = get_device()
    on_vulkan = device == "vulkan"

    if on_vulkan and args.annotation_model == "unet":
        winograd_config = load_winograd_configs()
        winograd_model = annotate_with_winograd(
            mlir_model, winograd_config, model_name
        )
        lowering_config = load_lower_configs(base_model_id)
        return annotate_with_lower_configs(
            winograd_model, lowering_config, model_name, True
        )

    if on_vulkan and args.annotation_model == "vae":
        target_arch = args.iree_vulkan_target_triple.split("-")[0]
        if "rdna2" in target_arch:
            return mlir_model
        winograd_config = load_winograd_configs()
        return annotate_with_winograd(
            mlir_model, winograd_config, model_name
        )

    lowering_config = load_lower_configs(base_model_id)
    return annotate_with_lower_configs(
        mlir_model, lowering_config, model_name, False
    )
# Script entry point: download the untuned model and run the annotation flow.
if __name__ == "__main__":
    mlir_model, model_name = load_model_from_tank()
    # The returned bytecode is discarded here; the annotation helpers write
    # output files themselves, and only when args.save_annotation is set.
    sd_model_annotation(mlir_model, model_name)

View File

@@ -0,0 +1,594 @@
import argparse
import os
from pathlib import Path
def path_expand(s):
    """Return ``s`` as a ``Path`` with ``~`` expanded and the path resolved."""
    expanded = Path(s).expanduser()
    return expanded.resolve()
def is_valid_file(arg):
    """Return ``arg`` if that path exists on disk, otherwise ``None``.

    Any existing path passes, including directories.
    """
    return arg if os.path.exists(arg) else None
# Command-line parser for all options defined below. The description comes
# from this module's docstring, and ArgumentDefaultsHelpFormatter appends each
# option's default value to its --help entry.
p = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

##############################################################################
### Stable Diffusion Params
##############################################################################

# No `choices` is set, so the value is not validated against the names listed
# in the help text.
p.add_argument(
    "-a",
    "--app",
    default="txt2img",
    help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
)

# nargs="+" collects one or more prompts into a list.
p.add_argument(
    "-p",
    "--prompts",
    nargs="+",
    default=["cyberpunk forest by Salvador Dali"],
    help="text of which images to be generated.",
)

p.add_argument(
    "--negative_prompts",
    nargs="+",
    default=["trees, green"],
    help="text you don't want to see in the generated image.",
)

# No default is given, so this is None unless passed on the command line.
p.add_argument(
    "--img_path",
    type=str,
    help="Path to the image input for img2img/inpainting",
)

p.add_argument(
    "--steps",
    type=int,
    default=50,
    help="the no. of steps to do the sampling.",
)

p.add_argument(
    "--seed",
    type=int,
    default=-1,
    help="the seed to use. -1 for a random one.",
)

# range(1, 4) limits the batch size to 1, 2 or 3.
p.add_argument(
    "--batch_size",
    type=int,
    default=1,
    choices=range(1, 4),
    help="the number of inferences to be made in a single `batch_count`.",
)

# range(128, 769, 8) accepts multiples of 8 from 128 through 768.
p.add_argument(
    "--height",
    type=int,
    default=512,
    choices=range(128, 769, 8),
    help="the height of the output image.",
)

# Same allowed values as --height.
p.add_argument(
    "--width",
    type=int,
    default=512,
    choices=range(128, 769, 8),
    help="the width of the output image.",
)

p.add_argument(
    "--guidance_scale",
    type=float,
    default=7.5,
    help="the value to be used for guidance scaling.",
)

p.add_argument(
    "--noise_level",
    type=int,
    default=20,
    help="the value to be used for noise level of upscaler.",
)

# The help names 64 and 77, but no `choices` enforces that here.
p.add_argument(
    "--max_length",
    type=int,
    default=64,
    help="max length of the tokenizer output, options are 64 and 77.",
)

p.add_argument(
    "--strength",
    type=float,
    default=0.8,
    help="the strength of change applied on the given input image for img2img",
)
##############################################################################
### Stable Diffusion Training Params
##############################################################################

# Both directory defaults are relative paths and are kept as plain strings
# (type=str), so they are not expanded or resolved here.
p.add_argument(
    "--lora_save_dir",
    type=str,
    default="models/lora/",
    help="Directory to save the lora fine tuned model",
)

p.add_argument(
    "--training_images_dir",
    type=str,
    default="models/lora/training_images/",
    help="Directory containing images that are an example of the prompt",
)

p.add_argument(
    "--training_steps",
    type=int,
    default=2000,
    help="The no. of steps to train",
)
##############################################################################
### Inpainting and Outpainting Params
##############################################################################

# No default is given, so this is None unless passed on the command line.
p.add_argument(
    "--mask_path",
    type=str,
    help="Path to the mask image input for inpainting",
)

# BooleanOptionalAction also registers a --no-<name> form for each flag below.
p.add_argument(
    "--inpaint_full_res",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="If inpaint only masked area or whole picture",
)

# range(0, 257, 4) accepts multiples of 4 from 0 through 256.
p.add_argument(
    "--inpaint_full_res_padding",
    type=int,
    default=32,
    choices=range(0, 257, 4),
    help="Number of pixels for only masked padding",
)

# range(8, 257, 8) accepts multiples of 8 from 8 through 256.
p.add_argument(
    "--pixels",
    type=int,
    default=128,
    choices=range(8, 257, 8),
    help="Number of expanded pixels for one direction for outpainting",
)

# range(0, 65) accepts every integer from 0 through 64.
p.add_argument(
    "--mask_blur",
    type=int,
    default=8,
    choices=range(0, 65),
    help="Number of blur pixels for outpainting",
)

# Direction switches for outpainting; all default to off.
p.add_argument(
    "--left",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="If expand left for outpainting",
)

p.add_argument(
    "--right",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="If expand right for outpainting",
)

p.add_argument(
    "--top",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="If expand top for outpainting",
)

p.add_argument(
    "--bottom",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="If expand bottom for outpainting",
)

# The min/max quoted in the two help texts below are not enforced by the parser.
p.add_argument(
    "--noise_q",
    type=float,
    default=1.0,
    help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
)

p.add_argument(
    "--color_variation",
    type=float,
    default=0.05,
    help="Color variation for outpainting (min=0.0, max=1.0)",
)
##############################################################################
### Model Config and Usage Params
##############################################################################

p.add_argument(
    "--device", type=str, default="vulkan", help="device to run the model."
)

p.add_argument(
    "--precision", type=str, default="fp16", help="precision to run the model."
)

# argparse.BooleanOptionalAction also registers a --no-<name> form, which is
# how the flags below that default to True are switched off.
p.add_argument(
    "--import_mlir",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)

p.add_argument(
    "--load_vmfb",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
)

p.add_argument(
    "--save_vmfb",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="saves the compiled flatbuffer to the local directory",
)

p.add_argument(
    "--use_tuned",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Download and use the tuned version of the model if available",
)

p.add_argument(
    "--use_base_vae",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Do conversion from the VAE output to pixel space on cpu.",
)

# Free-form string: the scheduler name is not validated by the parser.
p.add_argument(
    "--scheduler",
    type=str,
    default="SharkEulerDiscrete",
    help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)

# Free-form string: the format is not validated by the parser.
p.add_argument(
    "--output_img_format",
    type=str,
    default="png",
    help="specify the format in which output image is save. Supported options: jpg / png",
)

p.add_argument(
    "--output_dir",
    type=str,
    default=None,
    help="Directory path to save the output images and json",
)

p.add_argument(
    "--batch_count",
    type=int,
    default=1,
    help="number of batch to be generated with random seeds in single execution",
)

# An empty string (the default) means no checkpoint location was given.
p.add_argument(
    "--ckpt_loc",
    type=str,
    default="",
    help="Path to SD's .ckpt file.",
)

p.add_argument(
    "--custom_vae",
    type=str,
    default="",
    help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
)

p.add_argument(
    "--hf_model_id",
    type=str,
    default="stabilityai/stable-diffusion-2-1-base",
    help="The repo-id of hugging face.",
)

p.add_argument(
    "--low_cpu_mem_usage",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Use the accelerate package to reduce cpu memory consumption",
)

# Kept as a string because it accepts keywords as well as integers.
p.add_argument(
    "--attention_slicing",
    type=str,
    default="none",
    help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
)

# No default is given, so this is None unless one of the choices is passed.
p.add_argument(
    "--use_stencil",
    choices=["canny", "openpose", "scribble"],
    help="Enable the stencil feature.",
)

p.add_argument(
    "--use_lora",
    type=str,
    default="",
    help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
)

# The help text is a triple-quoted string that spans two source lines.
p.add_argument(
    "--use_quantize",
    type=str,
    default="none",
    help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
)

p.add_argument(
    "--ondemand",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Load and unload models for low VRAM",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################

# An empty string (the default) means no explicit target triple was given.
p.add_argument(
    "--iree_vulkan_target_triple",
    type=str,
    default="",
    help="Specify target triple for vulkan",
)

p.add_argument(
    "--vulkan_debug_utils",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Profiles vulkan device and collects the .rdc info",
)

# The default is kept as a string; 2073741824 bytes is roughly 2G, not the 4G
# the help text used to claim.
p.add_argument(
    "--vulkan_large_heap_block_size",
    default="2073741824",
    help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 2073741824 bytes (~2G)",
)

# Defaults to False, and passing the flag sets it to True, so the help now
# describes it as an enable switch.
# NOTE(review): assumes True turns the layers on; confirm against the consumer.
p.add_argument(
    "--vulkan_validation_layers",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag for enabling vulkan validation layers; leave disabled (the default) when benchmarking",
)
##############################################################################
### Misc. Debug and Optimization flags
##############################################################################

p.add_argument(
    "--use_compiled_scheduler",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="use the default scheduler precompiled into the model if available",
)

p.add_argument(
    "--local_tank_cache",
    default="",
    help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)

# Plain store_true: unlike the BooleanOptionalAction flags in this file, this
# one has no --no-dump_isa form.
p.add_argument(
    "--dump_isa",
    default=False,
    action="store_true",
    help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
)

# No type is given, so any value passed on the command line stays a string.
p.add_argument(
    "--dispatch_benchmarks",
    default=None,
    help='dispatches to return benchmark data on. use "All" for all, and None for none.',
)

p.add_argument(
    "--dispatch_benchmarks_dir",
    default="temp_dispatch_benchmarks",
    help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)

p.add_argument(
    "--enable_rgp",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag for inserting debug frames between iterations for use with rgp.",
)

p.add_argument(
    "--hide_steps",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="flag for hiding the details of iteration/sec for each step.",
)

# The ">= 0" bound in the help is not enforced by the parser.
p.add_argument(
    "--warmup_count",
    type=int,
    default=0,
    help="flag setting warmup count for clip and vae [>= 0].",
)

p.add_argument(
    "--clear_all",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
)

p.add_argument(
    "--save_metadata_to_json",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag for whether or not to save a generation information json file with the image.",
)

p.add_argument(
    "--write_metadata_to_png",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="flag for whether or not to save generation information in PNG chunk text to generated images.",
)

# When set, the parsing code at the end of this module points the
# IREE_SAVE_TEMPS environment variable at a per-model directory.
p.add_argument(
    "--import_debug",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
)
##############################################################################
### Web UI flags
##############################################################################

# NOTE(review): defaults to True, so presumably it is --no-progress_bar that
# removes the animation; confirm and reword the help if so.
p.add_argument(
    "--progress_bar",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="flag for removing the progress bar animation during image generation",
)

p.add_argument(
    "--ckpt_dir",
    type=str,
    default="",
    help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)

# TODO: replace API flag when these can be run together
# The default depends on the platform: "app" on Windows (os.name == "nt"),
# "web" everywhere else.
p.add_argument(
    "--ui",
    type=str,
    default="app" if os.name == "nt" else "web",
    help="one of: [api, app, web]",
)

p.add_argument(
    "--share",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag for generating a public URL",
)

p.add_argument(
    "--server_port",
    type=int,
    default=8080,
    help="flag for setting server port",
)

p.add_argument(
    "--api",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag for enabling rest API",
)

# NOTE(review): defaults to True, so presumably it is --no-output_gallery that
# removes the tab; confirm and reword the help if so.
p.add_argument(
    "--output_gallery",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="flag for removing the output gallery tab, and avoid exposing images under --output_dir in the UI",
)

p.add_argument(
    "--output_gallery_followlinks",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag for whether the output gallery tab in the UI should follow symlinks when listing subdirectories under --output_dir",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################

# type=path_expand turns the value into an expanded, resolved Path. argparse
# also applies `type` to string defaults, so "./" goes through it as well.
p.add_argument(
    "--annotation_output",
    type=path_expand,
    default="./",
    help="Directory to save the annotated mlir file",
)

# Free-form string: the parser does not restrict it to "unet" or "vae".
p.add_argument(
    "--annotation_model",
    type=str,
    default="unet",
    help="Options are unet and vae.",
)

p.add_argument(
    "--save_annotation",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Save annotated mlir file",
)
##############################################################################
### SD model auto-tuner flags
##############################################################################

# type=path_expand turns the value into an expanded, resolved Path. argparse
# also applies `type` to string defaults, so "./" goes through it as well.
p.add_argument(
    "--tuned_config_dir",
    type=path_expand,
    default="./",
    help="Directory to save the tuned config file",
)

p.add_argument(
    "--num_iters",
    type=int,
    default=400,
    help="Number of iterations for tuning",
)

# Free-form string: the parser does not restrict it to the listed op names.
p.add_argument(
    "--search_op",
    type=str,
    default="all",
    help="Op to be optimized, options are matmul, bmm, conv and all",
)
# Parsing happens at module level, i.e. as soon as this module is executed.
# parse_known_args() returns the parsed namespace plus a list of leftover
# arguments instead of erroring on flags this parser does not define.
args, unknown = p.parse_known_args()
if args.import_debug:
    # Point IREE_SAVE_TEMPS at <cwd>/<hf_model_id with "/" replaced by "_">.
    os.environ["IREE_SAVE_TEMPS"] = os.path.join(
        os.getcwd(), args.hf_model_id.replace("/", "_")
    )

View File

@@ -0,0 +1,2 @@
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector

View File

@@ -0,0 +1,6 @@
import cv2
class CannyDetector:
    """Callable wrapper around ``cv2.Canny``."""

    def __call__(self, img, low_threshold, high_threshold):
        """Run ``cv2.Canny`` on ``img`` with the two thresholds and return its output."""
        edges = cv2.Canny(img, low_threshold, high_threshold)
        return edges

View File

@@ -0,0 +1,62 @@
import requests
from pathlib import Path
import torch
import numpy as np
# from annotator.util import annotator_ckpts_path
from apps.stable_diffusion.src.utils.stencils.openpose.body import Body
from apps.stable_diffusion.src.utils.stencils.openpose.hand import Hand
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
draw_bodypose,
draw_handpose,
handDetect,
)
body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"
class OpenposeDetector:
    """Body (and optionally hand) pose annotator built on ``Body`` and ``Hand``.

    On construction the two pose checkpoints are downloaded into
    ``<cwd>/stencil_annotator`` if they are not already present there.
    """

    def __init__(self):
        """Ensure both checkpoints exist locally, then load the estimators.

        Raises:
            requests.HTTPError: if a checkpoint download returns an error status.
        """
        ckpt_path = Path(Path.cwd(), "stencil_annotator")
        ckpt_path.mkdir(parents=True, exist_ok=True)
        body_modelpath = ckpt_path / "body_pose_model.pth"
        hand_modelpath = ckpt_path / "hand_pose_model.pth"

        self._fetch_checkpoint(body_model_path, body_modelpath)
        self._fetch_checkpoint(hand_model_path, hand_modelpath)

        self.body_estimation = Body(body_modelpath)
        self.hand_estimation = Hand(hand_modelpath)

    @staticmethod
    def _fetch_checkpoint(url, destination):
        """Download ``url`` to ``destination`` unless that file already exists."""
        if destination.is_file():
            return
        response = requests.get(url, allow_redirects=True)
        # Fail here rather than saving an HTTP error page as a "checkpoint",
        # which would then be treated as already downloaded on later runs.
        response.raise_for_status()
        # The context manager closes the file handle deterministically.
        with open(destination, "wb") as ckpt_file:
            ckpt_file.write(response.content)

    def __call__(self, oriImg, hand=False):
        """Draw the detected pose for ``oriImg`` onto a blank canvas.

        Args:
            oriImg: image array indexed as ``[row, col, channel]``.
            hand: when truthy, hand keypoints are detected and drawn as well.

        Returns:
            tuple: ``(canvas, info)`` where ``canvas`` has the same shape as the
            input and ``info`` is a dict with the ``candidate`` and ``subset``
            results of the body estimator converted to lists.
        """
        # Reverse the last (channel) axis; presumably RGB <-> BGR — TODO confirm.
        oriImg = oriImg[:, :, ::-1].copy()
        with torch.no_grad():
            candidate, subset = self.body_estimation(oriImg)
            canvas = np.zeros_like(oriImg)
            canvas = draw_bodypose(canvas, candidate, subset)
            if hand:
                hands_list = handDetect(candidate, subset, oriImg)
                all_hand_peaks = []
                for x, y, w, is_left in hands_list:
                    # Run the hand estimator on the square crop for this hand.
                    peaks = self.hand_estimation(
                        oriImg[y : y + w, x : x + w, :]
                    )
                    # Shift crop-relative coordinates back into image space;
                    # zeros are left untouched.
                    peaks[:, 0] = np.where(
                        peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x
                    )
                    peaks[:, 1] = np.where(
                        peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y
                    )
                    all_hand_peaks.append(peaks)
                canvas = draw_handpose(canvas, all_hand_peaks)
            return canvas, dict(
                candidate=candidate.tolist(), subset=subset.tolist()
            )

Some files were not shown because too many files have changed in this diff Show More