Compare commits

..

1 Commits

Author SHA1 Message Date
Ean Garvey
2bd2bfa8b9 Allow SD model wrapper to be called for sharktank gen. 2023-04-18 21:09:36 -05:00
219 changed files with 15961 additions and 12116 deletions

View File

@@ -2,4 +2,4 @@
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
exclude = lit.cfg.py

View File

@@ -51,13 +51,26 @@ jobs:
run: |
./setup_venv.ps1
python process_skipfiles.py
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip install -e .
pip freeze -l
pyinstaller .\apps\shark_studio\shark_studio.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
python process_skipfiles.py
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
# GHA windows VM OOMs so disable for now
#- name: Build and validate the SHARK Runtime package
# shell: powershell
# run: |
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
#- uses: actions/upload-artifact@v2
# with:
# path: dist/*
- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1
@@ -65,7 +78,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./dist/nodai*
assets_path: ./dist/*
#asset_content_type: application/vnd.microsoft.portable-executable
- name: Publish Release
@@ -75,3 +88,80 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
linux-build:
runs-on: a100
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
backend: [IREE, SHARK]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Setup pip cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
- name: Build and validate the IREE package
if: ${{ matrix.backend == 'IREE' }}
continue-on-error: true
run: |
cd $GITHUB_WORKSPACE
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
if !(grep -Fxq " failed" pytest_results.txt)
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
rm -rf ./wheelhouse/nodai*
- name: Build and validate the SHARK Runtime package
if: ${{ matrix.backend == 'SHARK' }}
run: |
cd $GITHUB_WORKSPACE
./setup_venv.sh
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt

163
.github/workflows/test-models.yml vendored Normal file
View File

@@ -0,0 +1,163 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Validate Models on Shark Runtime
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
suite: [cpu,cuda,vulkan]
python-version: ["3.11"]
include:
- os: ubuntu-latest
suite: lint
exclude:
- os: ubuntu-latest
suite: vulkan
- os: ubuntu-latest
suite: cuda
- os: ubuntu-latest
suite: cpu
- os: MacStudio
suite: cuda
- os: MacStudio
suite: cpu
- os: icelake
suite: vulkan
- os: icelake
suite: cuda
- os: a100
suite: cpu
- os: 7950x
suite: cpu
- os: 7950x
suite: cuda
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'
- name: Set Environment Variables
if: matrix.os != '7950x'
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
run: |
# See https://github.com/actions/setup-python/issues/433
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
#cache: 'pip'
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- uses: actions/checkout@v2
if: matrix.os == '7950x'
- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check .
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
- name: Validate Models on NVIDIA GPU
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan --update_tank
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest -k vulkan -s --ci
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
python build_tools/stable_diffusion_testing.py --device=vulkan

View File

@@ -1,85 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Validate Shark Studio
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [nodai-ubuntu-builder-large]
suite: [cpu] #,cuda,vulkan]
python-version: ["3.11"]
include:
- os: nodai-ubuntu-builder-large
suite: lint
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set Environment Variables
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
run: |
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check apps/shark_studio
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
python${{ matrix.python-version }} -m venv shark.venv
source shark.venv/bin/activate
pip install -r requirements.txt --no-cache-dir
pip install -e .
# Disabled due to hang when exporting test llama2
# python apps/shark_studio/tests/api_test.py

28
.gitignore vendored
View File

@@ -2,8 +2,6 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb
# C extensions
*.so
@@ -159,12 +157,12 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
#.idea/
# vscode related
.vscode
# Shark related artifacts
# Shark related artefacts
*venv/
shark_tmp/
*.vmfb
@@ -172,7 +170,6 @@ shark_tmp/
tank/dict_configs.py
*.csv
reproducers/
apps/shark_studio/web/configs
# ORT related artefacts
cache_models/
@@ -183,29 +180,10 @@ generated_imgs/
# Custom model related artefacts
variants.json
/models/
*.safetensors
models/
# models folder
apps/stable_diffusion/web/models/
# model artifacts (SHARK)
*.tempfile
*.mlir
*.vmfb
# Stencil annotators.
stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData
# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/
# Webview2 runtime artefacts
EBWebView/

2
.gitmodules vendored
View File

@@ -1,4 +1,4 @@
[submodule "inference/thirdparty/shark-runtime"]
path = inference/thirdparty/shark-runtime
url =https://github.com/nod-ai/SRT.git
url =https://github.com/nod-ai/SHARK-Runtime.git
branch = shark-06032022

View File

@@ -2,20 +2,18 @@
High Performance Machine Learning Distribution
*We are currently rebuilding SHARK to take advantage of [Turbine](https://github.com/nod-ai/SHARK-Turbine). Until that is complete make sure you use an .exe release or a checkout of the `SHARK-1.0` branch, for a working SHARK*
[![Nightly Release](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
[![Validate torch-models on Shark Runtime](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
<details>
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
#### Linux Drivers
* MESA / RADV drivers wont work with FP16. Please use the latest AMGPU-PRO drivers (non-pro OSS drivers also wont work) or the latest NVidia Linux Drivers.
@@ -24,23 +22,23 @@ Other users please ensure you have your latest vendor drivers and Vulkan SDK fro
</details>
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
Install the Driver from (Prerequisites)[https://github.com/nod-ai/SHARK#install-your-hardware-drivers] above
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Download the [stable release](https://github.com/nod-ai/shark/releases/latest) or the most recent [SHARK 1.0 pre-release](https://github.com/nod-ai/shark/releases).
Download the [stable release](https://github.com/nod-ai/shark/releases/latest)
Double click the .exe, or [run from the command line](#running) (recommended), and you should have the [UI](http://localhost:8080/) in the browser.
Double click the .exe and you should have the [UI](http://localhost:8080/) in the browser.
If you have custom models put them in a `models/` directory where the .exe is.
If you have custom models put them in a `models/` directory where the .exe is.
Enjoy.
Enjoy.
<details>
<summary>More installation notes</summary>
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
## Running
@@ -48,22 +46,17 @@ Enjoy.
* The first run may take few minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.
* If you prefer to always run in the browser, use the `--ui=web` command argument when running the EXE.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
</details>
<details>
<summary>Advanced Installation (Only for developers)</summary>
## Advanced Installation (Windows, Linux and macOS) for developers
### Windows 10/11 Users
* Install Git for Windows from [here](https://git-scm.com/download/win) if you don't already have it.
## Check out the code
```shell
@@ -71,22 +64,14 @@ git clone https://github.com/nod-ai/SHARK.git
cd SHARK
```
## Switch to the Correct Branch (IMPORTANT!)
Currently SHARK is being rebuilt for [Turbine](https://github.com/nod-ai/SHARK-Turbine) on the `main` branch. For now you are strongly discouraged from using `main` unless you are working on the rebuild effort, and should not expect the code there to produce a working application for Image Generation, So for now you'll need switch over to the `SHARK-1.0` branch and use the stable code.
```shell
git checkout SHARK-1.0
```
The following setup instructions assume you are on this branch.
## Setup your Python VirtualEnvironment and Dependencies
### Windows 10/11 Users
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)
* Install Git for Windows from [here](https://git-scm.com/download/win)
#### Allow the install script to run in Powershell
```powershell
set-executionpolicy remotesigned
@@ -101,20 +86,21 @@ set-executionpolicy remotesigned
```shell
./setup_venv.sh
source shark1.venv/bin/activate
source shark.venv/bin/activate
```
### Run Stable Diffusion on your device - WebUI
#### Windows 10/11 Users
```powershell
(shark1.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark1.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
```
#### Linux / macOS Users
```shell
(shark1.venv) > cd apps/stable_diffusion/web
(shark1.venv) > python index.py
(shark.venv) > cd apps/stable_diffusion/web
(shark.venv) > python index.py
```
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
@@ -128,7 +114,7 @@ source shark1.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark1.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
@@ -156,7 +142,7 @@ Here are some samples generated:
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
<details>
@@ -184,7 +170,7 @@ python -m pip install --upgrade pip
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
### Run shark tank model tests.
@@ -219,7 +205,7 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
If you want to use Python3.11 and with TF Import tools you can use the environment variables like:
Set `USE_IREE=1` to use upstream IREE
```
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
```
### Run any of the hundreds of SHARK tank models via the test framework
@@ -228,7 +214,7 @@ python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use g
# Or a pytest
pytest tank/test_models.py -k "MiniLM"
```
### How to use your locally built IREE / Torch-MLIR with SHARK
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
@@ -254,12 +240,12 @@ Now the SHARK will use your locally build Torch-MLIR repo.
## Benchmarking Dispatches
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
```
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
```
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
@@ -268,6 +254,7 @@ if you want to instead incorporate this into a python script, you can pass the `
```
shark_module = SharkInference(
mlir_model,
func_name,
device=args.device,
mlir_dialect="tm_tensor",
dispatch_benchmarks="all",
@@ -278,7 +265,7 @@ shark_module = SharkInference(
Output will include:
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
- An .mlir file containing the dispatch benchmark
- An .mlir file containing the dispatch benchmark
- A compiled .vmfb file containing the dispatch benchmark
- An .mlir file containing just the hal executable
- A compiled .vmfb file of the hal executable
@@ -310,7 +297,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
from shark.shark_inference import SharkInference
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input))
@@ -333,20 +320,15 @@ mhlo_ir = r"""builtin.module {
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
shark_module.compile()
result = shark_module.forward((arg0, arg1))
```
</details>
## Examples Using the REST API
* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md)
* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md)
## Supported and Validated Models
SHARK is maintained to support the latest innovations in ML Models:
SHARK is maintained to support the latest innovations in ML Models:
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------|----------|-------------|
@@ -372,7 +354,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
* [Upstream IREE issues](https://github.com/google/iree/issues): Feature requests,
bugs, and other work tracking
* [Upstream IREE Discord server](https://discord.gg/wEWh6Z9nMU): Daily development
* [Upstream IREE Discord server](https://discord.gg/26P4xW4): Daily development
discussions with the core team and collaborators
* [iree-discuss email list](https://groups.google.com/forum/#!forum/iree-discuss):
Announcements, general and low-priority discussion
@@ -387,7 +369,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
* Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information.
* [MLIR topic within LLVM Discourse](https://llvm.discourse.group/c/llvm-project/mlir/31) SHARK and IREE is enabled by and heavily relies on [MLIR](https://mlir.llvm.org).
</details>
## License
nod.ai SHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions.

View File

@@ -1,107 +0,0 @@
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
import os
import PIL
import numpy as np
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
)
from datetime import datetime
from PIL import Image
from gradio.components.image_editor import (
EditorValue,
)
class control_adapter:
def __init__(
self,
model: str,
):
self.model = None
def export_control_adapter_model(model_keyword):
return None
def export_xl_control_adapter_model(model_keyword):
return None
class preprocessors:
def __init__(
self,
model: str,
):
self.model = None
def export_controlnet_model(model_keyword):
return None
control_adapter_map = {
"sd15": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
"openpose": {"initializer": control_adapter.export_control_adapter_model},
"scribble": {"initializer": control_adapter.export_control_adapter_model},
"zoedepth": {"initializer": control_adapter.export_control_adapter_model},
},
"sdxl": {
"canny": {"initializer": control_adapter.export_xl_control_adapter_model},
},
}
preprocessor_model_map = {
"canny": {"initializer": preprocessors.export_controlnet_model},
"openpose": {"initializer": preprocessors.export_controlnet_model},
"scribble": {"initializer": preprocessors.export_controlnet_model},
"zoedepth": {"initializer": preprocessors.export_controlnet_model},
}
class PreprocessorModel:
def __init__(
self,
hf_model_id,
device="cpu",
):
self.model = hf_model_id
self.device = device
def compile(self):
print("compile not implemented for preprocessor.")
return
def run(self, inputs):
print("run not implemented for preprocessor.")
return inputs
def cnet_preview(model, input_image):
curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
if not os.path.exists(control_imgs_path):
os.mkdir(control_imgs_path)
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
match model:
case "canny":
canny = PreprocessorModel("canny")
result = canny(
np.array(input_image),
100,
200,
)
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "openpose":
openpose = PreprocessorModel("openpose")
result = openpose(np.array(input_image))
Image.fromarray(result[0]).save(fp=img_dest)
return result, img_dest
case "zoedepth":
zoedepth = PreprocessorModel("ZoeDepth")
result = zoedepth(np.array(input_image))
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "scribble":
input_image.save(fp=img_dest)
return input_image, img_dest
case _:
return None, None

View File

@@ -1,125 +0,0 @@
import importlib
import os
import signal
import sys
import warnings
import json
from threading import Thread
from apps.shark_studio.modules.timer import startup_timer
from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
shark_tmp,
)
def imports():
import torch # noqa: F401
startup_timer.record("import torch")
warnings.filterwarnings(
action="ignore", category=DeprecationWarning, module="torch"
)
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
import gradio # noqa: F401
startup_timer.record("import gradio")
import apps.shark_studio.web.utils.globals as global_obj
global_obj._init()
startup_timer.record("initialize globals")
from apps.shark_studio.modules import (
img_processing,
) # noqa: F401
startup_timer.record("other imports")
def initialize():
configure_sigint_handler()
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
config_tmp()
# clear_tmp_mlir()
clear_tmp_imgs()
from apps.shark_studio.web.utils.file_utils import (
create_model_folders,
)
# Create custom models folders if they don't exist
create_model_folders()
import gradio as gr
# initialize_rest(reload_script_modules=False)
def initialize_rest(*, reload_script_modules=False):
"""
Called both from initialize() and when reloading the webui.
"""
# Keep this for adding reload options to the webUI.
def dumpstacks():
import threading
import traceback
id2name = {th.ident: th.name for th in threading.enumerate()}
code = []
for threadId, stack in sys._current_frames().items():
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
for filename, lineno, name, line in traceback.extract_stack(stack):
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
if line:
code.append(" " + line.strip())
with open(os.path.join(shark_tmp, "stack_dump.log"), "w") as f:
f.write("\n".join(code))
def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware
app.middleware_stack = (
None # reset current middleware to allow modifying user provided list
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly
def configure_cors_middleware(app):
from starlette.middleware.cors import CORSMiddleware
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
cors_options = {
"allow_methods": ["*"],
"allow_headers": ["*"],
"allow_credentials": True,
}
if cmd_opts.api_accept_origin:
cors_options["allow_origins"] = cmd_opts.api_accept_origin.split(",")
app.add_middleware(CORSMiddleware, **cors_options)
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f"Interrupted with signal {sig} in {frame}")
dumpstacks()
os._exit(0)
signal.signal(signal.SIGINT, sigint_handler)

View File

@@ -1,475 +0,0 @@
from turbine_models.custom_models import stateless_llama
from turbine_models.model_runner import vmfbRunner
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
from apps.shark_studio.web.utils.file_utils import (
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import parse_device
from urllib.request import urlopen
import iree.runtime as ireert
from itertools import chain
import gc
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
llm_model_map = {
"meta-llama/Llama-2-7b-chat-hf": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"TinyPixel/small-llama2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "TinyPixel/small-llama2",
"compile_flags": ["--iree-opt-const-expr-hoisting=True"],
"stop_token": 2,
"max_tokens": 1024,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
}
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<s>", "</s>"
DEFAULT_CHAT_SYS_PROMPT = """<s>[INST] <<SYS>>
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <</SYS>>\n\n
"""
def append_user_prompt(history, input_prompt):
user_prompt = f"{B_INST} {input_prompt} {E_INST}"
history += user_prompt
return history
class LanguageModel:
def __init__(
self,
model_name,
hf_auth_token=None,
device=None,
quantization="int4",
precision="",
external_weights=None,
use_system_prompt=True,
streaming_llm=False,
):
_, _, self.triple = parse_device(device)
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.device = device.split("=>")[-1].strip()
self.backend = self.device.split("://")[0]
self.driver = self.backend
if "cpu" in device:
self.device = "cpu"
self.backend = "llvm-cpu"
self.driver = "local-task"
print(f"Selected {self.backend} as IREE target backend.")
self.precision = "f32" if "cpu" in device else "f16"
self.quantization = quantization
self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_")
self.external_weight_file = None
# TODO: find a programmatic solution for model arch spec instead of hardcoding llama2
self.file_spec = "_".join(
[
self.safe_name,
self.precision,
]
)
if self.quantization != "None":
self.file_spec += "_" + self.quantization
if external_weights in ["safetensors", "gguf"]:
self.external_weight_file = get_resource_path(
os.path.join("..", self.file_spec + "." + external_weights)
)
else:
self.external_weights = None
self.external_weight_file = None
if streaming_llm:
# Add streaming suffix to file spec after setting external weights filename.
self.file_spec += "_streaming"
self.streaming_llm = streaming_llm
self.tempfile_name = get_resource_path(
os.path.join("..", f"{self.file_spec}.tempfile")
)
# TODO: Tag vmfb with target triple of device instead of HAL backend
self.vmfb_name = str(
get_resource_path(
os.path.join("..", f"{self.file_spec}_{self.backend}.vmfb.tempfile")
)
)
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.use_system_prompt = use_system_prompt
self.global_iter = 0
self.prev_token_len = 0
self.first_input = True
self.hf_auth_token = hf_auth_token
if self.external_weight_file is not None:
if not os.path.exists(self.external_weight_file):
print(
f"External weight file {self.external_weight_file} does not exist. Generating..."
)
gen_external_params(
hf_model_name=self.hf_model_name,
quantization=self.quantization,
weight_path=self.external_weight_file,
hf_auth_token=hf_auth_token,
precision=self.precision,
)
else:
print(
f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
)
self.external_weight_file = str(self.external_weight_file)
if os.path.exists(self.vmfb_name) and (
external_weights is None or os.path.exists(str(self.external_weight_file))
):
self.runner = vmfbRunner(
device=self.driver,
vmfb_path=self.vmfb_name,
external_weight_path=self.external_weight_file,
)
if self.streaming_llm:
self.model = self.runner.ctx.modules.streaming_state_update
else:
self.model = self.runner.ctx.modules.state_update
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
"initializer"
](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
external_weights=external_weights,
precision=self.precision,
quantization=self.quantization,
streaming_llm=self.streaming_llm,
decomp_attn=True,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.compile()
else:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
self.compile()
# Reserved for running HF torch model as reference.
self.hf_mod = None
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
# ONLY architecture/api-specific compile-time flags for each backend, if needed.
# hf_model_id-specific global flags currently in model map.
flags = []
if "cpu" in self.backend:
flags.extend(
[
"--iree-global-opt-enable-quantized-matmul-reassociation",
]
)
elif self.backend == "vulkan":
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
elif self.backend == "rocm":
flags.extend(
[
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-opt-outer-dim-concat=true",
"--iree-flow-enable-aggressive-fusion",
]
)
if "gfx9" in self.triple:
flags.extend(
[
f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
]
)
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
flatbuffer_blob = compile_module_to_flatbuffer(
self.tempfile_name,
device=self.device,
frontend="auto",
model_config_path=None,
extra_args=flags,
write_to=self.vmfb_name,
)
self.runner = vmfbRunner(
device=self.driver,
vmfb_path=self.vmfb_name,
external_weight_path=self.external_weight_file,
)
if self.streaming_llm:
self.model = self.runner.ctx.modules.streaming_state_update
else:
self.model = self.runner.ctx.modules.state_update
def sanitize_prompt(self, prompt):
if isinstance(prompt, list):
prompt = list(chain.from_iterable(prompt))
prompt = " ".join([x for x in prompt if isinstance(x, str)])
prompt = prompt.replace("\n", " ")
prompt = prompt.replace("\t", " ")
prompt = prompt.replace("\r", " ")
if self.use_system_prompt and self.global_iter == 0:
prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt)
return prompt
else:
return f"{B_INST} {prompt} {E_INST}"
def chat(self, prompt):
prompt = self.sanitize_prompt(prompt)
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
def format_out(results):
return torch.tensor(results.to_host()[0][0])
history = []
for iter in range(self.max_tokens):
if self.streaming_llm:
token_slice = max(self.prev_token_len - 1, 0)
input_tensor = input_tensor[:, token_slice:]
if self.streaming_llm and self.model["get_seq_step"]() > 600:
print("Evicting cache space!")
self.model["evict_kvcache_space"]()
token_len = input_tensor.shape[-1]
device_inputs = [
ireert.asdevicearray(self.runner.config.device, input_tensor)
]
if self.first_input or not self.streaming_llm:
st_time = time.time()
token = self.model["run_initialize"](*device_inputs)
total_time = time.time() - st_time
token_len += 1
self.first_input = False
else:
st_time = time.time()
token = self.model["run_cached_initialize"](*device_inputs)
total_time = time.time() - st_time
token_len += 1
history.append(format_out(token))
while (
format_out(token) != llm_model_map[self.hf_model_name]["stop_token"]
and len(history) < self.max_tokens
):
dec_time = time.time()
if self.streaming_llm and self.model["get_seq_step"]() > 600:
print("Evicting cache space!")
self.model["evict_kvcache_space"]()
token = self.model["run_forward"](token)
history.append(format_out(token))
total_time = time.time() - dec_time
yield self.tokenizer.decode(history), total_time
self.prev_token_len = token_len + len(history)
if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
self.global_iter += 1
return result_output, total_time
# Reference HF model function for sanity checks.
def chat_hf(self, prompt):
if self.hf_mod is None:
self.hf_mod = AutoModelForCausalLM.from_pretrained(
self.hf_model_name,
torch_dtype=torch.float,
token=self.hf_auth_token,
)
prompt = self.sanitize_prompt(prompt)
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
history = []
for iter in range(self.max_tokens):
token_len = input_tensor.shape[-1]
if self.first_input:
st_time = time.time()
result = self.hf_mod(input_tensor)
token = torch.argmax(result.logits[:, -1, :], dim=1)
total_time = time.time() - st_time
token_len += 1
pkv = result.past_key_values
self.first_input = False
history.append(int(token))
while token != llm_model_map[self.hf_model_name]["stop_token"]:
dec_time = time.time()
result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
history.append(int(token))
total_time = time.time() - dec_time
token = torch.argmax(result.logits[:, -1, :], dim=1)
pkv = result.past_key_values
yield self.tokenizer.decode(history), total_time
self.prev_token_len = token_len + len(history)
if token == llm_model_map[self.hf_model_name]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
self.global_iter += 1
return result_output, total_time
def get_mfma_spec_path(target_chip, save_dir):
url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
attn_spec = urlopen(url).read().decode("utf-8")
spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
if os.path.exists(spec_path):
return spec_path
with open(spec_path, "w") as f:
f.write(attn_spec)
return spec_path
def llm_chat_api(InputData: dict):
from datetime import datetime as dt
import apps.shark_studio.web.utils.globals as global_obj
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
if is_chat_completion_api:
print(f"message -> role : {InputData['messages'][0]['role']}")
print(f"message -> content : {InputData['messages'][0]['content']}")
else:
print(f"prompt : {InputData['prompt']}")
model_name = (
InputData["model"]
if "model" in InputData.keys()
else "meta-llama/Llama-2-7b-chat-hf"
)
model_path = llm_model_map[model_name]
device = InputData["device"] if "device" in InputData.keys() else "cpu"
precision = "fp16"
max_tokens = InputData["max_tokens"] if "max_tokens" in InputData.keys() else 4096
device_id = None
if not global_obj.get_llm_obj():
print("\n[LOG] Initializing new pipeline...")
global_obj.clear_cache()
gc.collect()
if "cuda" in device:
device = "cuda"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "cpu" in device:
device = "cpu"
precision = "fp32"
else:
print("unrecognized device")
llm_model = LanguageModel(
model_name=model_name,
hf_auth_token=cmd_opts.hf_auth_token,
device=device,
quantization=cmd_opts.quantization,
external_weights="safetensors",
use_system_prompt=True,
streaming_llm=False,
)
global_obj.set_llm_obj(llm_model)
else:
llm_model = global_obj.get_llm_obj()
llm_model.max_tokens = max_tokens
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add funtionality for multiple messages
prompt = append_user_prompt(
InputData["messages"][0]["role"], InputData["messages"][0]["content"]
)
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
for res_op, _ in llm_model.chat(prompt):
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion" if is_chat_completion_api else "text_completion",
"created": int(end_time),
"choices": choices,
}
if __name__ == "__main__":
lm = LanguageModel(
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
)
print("model loaded")
for i in lm.chat("hi, what are you?"):
print(i)

View File

@@ -1,505 +0,0 @@
import gc
import torch
import gradio as gr
import time
import os
import json
import numpy as np
import copy
import importlib.util
import sys
from tqdm.auto import tqdm
from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
SharkSDXLPipeline,
)
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.api.utils import parse_device
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import (
safe_name,
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.img_processing import (
save_output_img,
)
from apps.shark_studio.modules.ckpt_processing import (
preprocessCKPT,
save_irpa,
)
EMPTY_SD_MAP = {
"clip": None,
"scheduler": None,
"unet": None,
"vae_decode": None,
}
EMPTY_SDXL_MAP = {
"prompt_encoder": None,
"scheduled_unet": None,
"vae_decode": None,
"pipeline": None,
"full_pipeline": None,
}
EMPTY_FLAGS = {
"clip": None,
"unet": None,
"vae": None,
"pipeline": None,
}
def load_script(source, module_name):
"""
reads file source and loads it as a module
:param source: file to load
:param module_name: name of module to register in sys.modules
:return: loaded module
"""
spec = importlib.util.spec_from_file_location(module_name, source)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
class StableDiffusion:
# This class is responsible for executing image generation and creating
# /managing a set of compiled modules to run Stable Diffusion. The init
# aims to be as general as possible, and the class will infer and compile
# a list of necessary modules or a combined "pipeline module" for a
# specified job based on the inference task.
def __init__(
self,
base_model_id,
height: int,
width: int,
batch_size: int,
steps: int,
scheduler: str,
precision: str,
device: str,
target_triple: str = None,
custom_vae: str = None,
num_loras: int = 0,
import_ir: bool = True,
is_controlled: bool = False,
external_weights: str = "safetensors",
):
self.precision = precision
self.compiled_pipeline = False
self.base_model_id = base_model_id
self.custom_vae = custom_vae
self.is_sdxl = "xl" in self.base_model_id.lower()
self.is_custom = ".py" in self.base_model_id.lower()
if self.is_custom:
custom_module = load_script(
os.path.join(get_checkpoints_path("scripts"), self.base_model_id),
"custom_pipeline",
)
self.turbine_pipe = custom_module.StudioPipeline
self.model_map = custom_module.MODEL_MAP
elif self.is_sdxl:
self.turbine_pipe = SharkSDXLPipeline
self.model_map = EMPTY_SDXL_MAP
else:
self.turbine_pipe = SharkSDPipeline
self.model_map = EMPTY_SD_MAP
max_length = 64
target_backend, self.rt_device, triple = parse_device(device, target_triple)
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
str(max_length),
f"{str(height)}x{str(width)}",
precision,
triple,
]
if num_loras > 0:
pipe_id_list.append(str(num_loras) + "lora")
if is_controlled:
pipe_id_list.append("controlled")
if custom_vae:
pipe_id_list.append(custom_vae)
self.pipe_id = "_".join(pipe_id_list)
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.weights_path = Path(
os.path.join(
get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)
)
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
decomp_attn = True
attn_spec = None
if triple in ["gfx940", "gfx942", "gfx90a"]:
decomp_attn = False
attn_spec = "mfma"
elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
decomp_attn = False
attn_spec = "wmma"
if triple in ["gfx1103", "gfx1150"]:
# external weights have issues on igpu
external_weights = None
elif target_backend == "llvm-cpu":
decomp_attn = False
self.sd_pipe = self.turbine_pipe(
hf_model_name=base_model_id,
scheduler_id=scheduler,
height=height,
width=width,
precision=precision,
max_length=max_length,
batch_size=batch_size,
num_inference_steps=steps,
device=target_backend,
iree_target_triple=triple,
ireec_flags=EMPTY_FLAGS,
attn_spec=attn_spec,
decomp_attn=decomp_attn,
pipeline_dir=self.pipeline_dir,
external_weights_dir=self.weights_path,
external_weights=external_weights,
custom_vae=custom_vae,
)
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
gc.collect()
def prepare_pipe(
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
):
print(f"\n[LOG] Preparing pipeline...")
self.is_img2img = False
mlirs = copy.deepcopy(self.model_map)
vmfbs = copy.deepcopy(self.model_map)
weights = copy.deepcopy(self.model_map)
if not self.is_sdxl:
compiled_pipeline = False
self.compiled_pipeline = compiled_pipeline
if custom_weights:
custom_weights = os.path.join(
get_checkpoints_path("checkpoints"),
safe_name(self.base_model_id.split("/")[-1]),
custom_weights,
)
diffusers_weights_path = preprocessCKPT(custom_weights, self.precision)
for key in weights:
if key in ["scheduled_unet", "unet"]:
unet_weights_path = os.path.join(
diffusers_weights_path,
"unet",
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(unet_weights_path, "unet.")
elif key in ["clip", "prompt_encoder"]:
if not self.is_sdxl:
sd1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
weights[key] = save_irpa(sd1_path, "text_encoder_model.")
else:
clip_1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
clip_2_path = os.path.join(
diffusers_weights_path,
"text_encoder_2",
"model.safetensors",
)
weights[key] = [
save_irpa(clip_1_path, "text_encoder_model_1."),
save_irpa(clip_2_path, "text_encoder_model_2."),
]
elif key in ["vae_decode"] and weights[key] is None:
vae_weights_path = os.path.join(
diffusers_weights_path,
"vae",
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(vae_weights_path, "vae.")
vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
)
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
self.sd_pipe.load_pipeline(
vmfbs, weights, self.rt_device, self.compiled_pipeline
)
print(
"\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
)
return
def generate_images(
self,
prompt,
negative_prompt,
image,
strength,
guidance_scale,
seed,
ondemand,
resample_type,
control_mode,
hints,
):
img = self.sd_pipe.generate_images(
prompt,
negative_prompt,
1,
guidance_scale,
seed,
return_imgs=True,
)
return img
def shark_sd_fn_dict_input(
sd_kwargs: dict,
):
print("\n[LOG] Submitting Request...")
for key in sd_kwargs:
if sd_kwargs[key] in [None, []]:
sd_kwargs[key] = None
if sd_kwargs[key] in ["None"]:
sd_kwargs[key] = ""
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])
# TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
if not sd_kwargs["device"]:
gr.Warning("No device specified. Please specify a device.")
return None, ""
if sd_kwargs["height"] not in [512, 1024]:
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
return None, ""
if sd_kwargs["height"] != sd_kwargs["width"]:
gr.Warning("Height and width must be the same. This is a temporary limitation.")
return None, ""
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
if sd_kwargs["steps"] > 10:
gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
return None, ""
if sd_kwargs["guidance_scale"] > 3:
gr.Warning(
"sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
)
return None, ""
if sd_kwargs["target_triple"] == "":
if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "":
gr.Warning(
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
)
return None, ""
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
return generated_imgs
def shark_sd_fn(
prompt,
negative_prompt,
sd_init_image: list,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: list,
batch_count: int,
batch_size: int,
scheduler: str,
base_model_id: str,
custom_weights: str,
custom_vae: str,
precision: str,
device: str,
target_triple: str,
ondemand: bool,
compiled_pipeline: bool,
resample_type: str,
controlnets: dict,
embeddings: dict,
):
sd_kwargs = locals()
if not isinstance(sd_init_image, list):
sd_init_image = [sd_init_image]
is_img2img = True if sd_init_image[0] is not None else False
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
adapters = {}
is_controlled = False
control_mode = None
hints = []
num_loras = 0
import_ir = True
for i in embeddings:
num_loras += 1 if embeddings[i] else 0
if "model" in controlnets:
for i, model in enumerate(controlnets["model"]):
if "xl" not in base_model_id.lower():
adapters[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map["runwayml/stable-diffusion-v1-5"][
model
],
"strength": controlnets["strength"][i],
}
else:
adapters[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][
model
],
"strength": controlnets["strength"][i],
}
if model is not None:
is_controlled = True
control_mode = controlnets["control_mode"]
for i in controlnets["hint"]:
hints.append[i]
submit_pipe_kwargs = {
"base_model_id": base_model_id,
"height": height,
"width": width,
"batch_size": batch_size,
"precision": precision,
"device": device,
"target_triple": target_triple,
"custom_vae": custom_vae,
"num_loras": num_loras,
"import_ir": import_ir,
"is_controlled": is_controlled,
"steps": steps,
"scheduler": scheduler,
}
submit_prep_kwargs = {
"custom_weights": custom_weights,
"adapters": adapters,
"embeddings": embeddings,
"is_img2img": is_img2img,
"compiled_pipeline": compiled_pipeline,
}
submit_run_kwargs = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"image": sd_init_image,
"strength": strength,
"guidance_scale": guidance_scale,
"seed": seed,
"ondemand": ondemand,
"resample_type": resample_type,
"control_mode": control_mode,
"hints": hints,
}
if (
not global_obj.get_sd_obj()
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
):
print("\n[LOG] Initializing new pipeline...")
global_obj.clear_cache()
gc.collect()
# Initializes the pipeline and retrieves IR based on all
# parameters that are static in the turbine output format,
# which is currently MLIR in the torch dialect.
sd_pipe = StableDiffusion(
**submit_pipe_kwargs,
)
global_obj.set_sd_obj(sd_pipe)
global_obj.set_pipe_kwargs(submit_pipe_kwargs)
if (
not global_obj.get_prep_kwargs()
or global_obj.get_prep_kwargs() != submit_prep_kwargs
):
global_obj.set_prep_kwargs(submit_prep_kwargs)
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
generated_imgs = []
for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
if not isinstance(out_imgs, list):
out_imgs = [out_imgs]
# total_time = time.time() - start_time
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
# print(f"\n[LOG] {text_output}")
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
# break
# else:
for batch in range(batch_size):
save_output_img(
out_imgs[batch],
seed,
sd_kwargs,
)
generated_imgs.extend(out_imgs)
# TODO: make seed changes over batch counts more configurable.
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
return (generated_imgs, "")
def unload_sd():
print("Unloading models.")
import apps.shark_studio.web.utils.globals as global_obj
global_obj.clear_cache()
gc.collect()
def cancel_sd():
print("Inject call to cancel longer API calls.")
return
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
content = fopen.read()
return content
def safe_name(name):
return name.replace("/", "_").replace("\\", "_").replace(".", "_")
if __name__ == "__main__":
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
global_obj._init()
sd_json = view_json_file(
get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
)
sd_kwargs = json.loads(sd_json)
for arg in vars(cmd_opts):
if arg in sd_kwargs:
sd_kwargs[arg] = getattr(cmd_opts, arg)
for i in shark_sd_fn_dict_input(sd_kwargs):
print(i)

View File

@@ -1,389 +0,0 @@
import numpy as np
import json
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info
# TODO: migrate these utils to studio
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
if "local" in driver_name:
device_list.append(
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
# for drivers with single devices
# let the default device be selected without any indexing
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
hip_devices = get_devices_by_name("hip")
available_devices.extend(hip_devices)
for idx, device_str in enumerate(available_devices):
if "AMD Radeon(TM) Graphics =>" in device_str:
igpu_id_candidates = [
x.split("w/")[-1].split("=>")[0]
for x in available_devices
if "M Graphics" in x
]
for igpu_name in igpu_id_candidates:
if igpu_name:
available_devices[idx] = device_str.replace(
"AMD Radeon(TM) Graphics", igpu_name
)
break
return available_devices
def set_init_device_flags():
if "vulkan" in cmd_opts.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
cmd_opts.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_vulkan_target_triple}."
)
elif "cuda" in cmd_opts.device:
cmd_opts.device = "cuda"
elif "metal" in cmd_opts.device:
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_metal_target_platform:
from shark.iree_utils.metal_utils import get_metal_target_triple
triple = get_metal_target_triple(device_name)
if triple is not None:
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_metal_target_platform}."
)
elif "cpu" in cmd_opts.device:
cmd_opts.device = "cpu"
def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if cmd_opts.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if cmd_opts.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def parse_device(device_str, target_override=""):
from shark.iree_utils.compile_utils import (
clean_device_info,
get_iree_target_triple,
iree_target_map,
)
rt_driver, device_id = clean_device_info(device_str)
target_backend = iree_target_map(rt_driver)
if device_id:
rt_device = f"{rt_driver}://{device_id}"
else:
rt_device = rt_driver
if target_override:
return target_backend, rt_device, target_override
match target_backend:
case "vulkan-spirv":
triple = get_iree_target_triple(device_str)
return target_backend, rt_device, triple
case "rocm":
triple = get_rocm_target_chip(device_str)
return target_backend, rt_device, triple
case "llvm-cpu":
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
def get_rocm_target_chip(device_str):
# TODO: Use a data file to map device_str to target chip.
rocm_chip_map = {
"6700": "gfx1031",
"6800": "gfx1030",
"6900": "gfx1030",
"7900": "gfx1100",
"MI300X": "gfx942",
"MI300A": "gfx940",
"MI210": "gfx90a",
"MI250": "gfx90a",
"MI100": "gfx908",
"MI50": "gfx906",
"MI60": "gfx906",
"780M": "gfx1103",
}
for key in rocm_chip_map:
if key in device_str:
return rocm_chip_map[key]
raise AssertionError(
f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues."
)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired
combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return dev_dict["name"], f"{driver}://{dev_dict['path']}"
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def get_opt_flags(model, precision="fp16"):
iree_flags = []
if len(cmd_opts.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
)
if "rocm" in cmd_opts.device:
from shark.iree_utils.gpu_utils import get_iree_rocm_args
rocm_args = get_iree_rocm_args()
iree_flags.extend(rocm_args)
if cmd_opts.iree_constant_folding == False:
iree_flags.append("--iree-opt-const-expr-hoisting=False")
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
if cmd_opts.data_tiling == False:
iree_flags.append("--iree-opt-data-tiling=False")
if "vae" not in model:
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user
selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for
the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
if "local" in driver_name:
device_list.append(
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
# for drivers with single devices
# let the default device be selected without any indexing
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
return available_devices
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
if isinstance(seed_input, int):
return [seed_input]
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)

View File

@@ -1,145 +0,0 @@
import os
import json
import re
import requests
import torch
import safetensors
from shark_turbine.aot.params import (
ParameterArchiveBuilder,
)
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from omegaconf import OmegaConf
from diffusers import StableDiffusionPipeline
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
create_vae_diffusers_config,
convert_ldm_vae_checkpoint,
)
def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = os.path.join("diffusers", path.stem + f"_{precision}")
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers
def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
return path_to_diffusers
else:
print(
"Diffusers' checkpoint will be identified here : ",
path_to_diffusers,
)
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but
# non-EMA weights have been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
# they want to go for EMA weight extraction or not.
extract_ema = False
print("Loading diffusers' pipeline from original stable diffusion checkpoint")
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
)
if precision == "fp16":
pipe.to(dtype=torch.float16)
pipe.save_pretrained(path_to_diffusers)
del pipe
print("Loading complete")
return path_to_diffusers
def save_irpa(weights_path, prepend_str):
weights = safetensors.torch.load_file(weights_path)
archive = ParameterArchiveBuilder()
for key in weights.keys():
new_key = prepend_str + key
archive.add_tensor(new_key, weights[key])
irpa_file = weights_path.replace(".safetensors", ".irpa")
archive.save(irpa_file)
return irpa_file
def convert_original_vae(vae_checkpoint):
vae_state_dict = {}
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
"main/configs/stable-diffusion/v1-inference.yaml"
)
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
return converted_vae_checkpoint
def process_custom_pipe_weights(custom_weights):
if custom_weights != "":
if custom_weights.startswith("https://civitai.com/api/"):
# download the checkpoint from civitai if we don't already have it
weights_path = get_civitai_checkpoint(custom_weights)
# act as if we were given the local file as custom_weights originally
custom_weights_tgt = get_path_to_diffusers_checkpoint(weights_path)
custom_weights_params = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights_params = custom_weights
return custom_weights_params, custom_weights_tgt
def get_civitai_checkpoint(url: str):
with requests.get(url, allow_redirects=True, stream=True) as response:
response.raise_for_status()
# civitai api returns the filename in the content disposition
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
destination_path = Path.cwd() / (cmd_opts.model_dir or "models") / base_filename
# we don't have this model downloaded yet
if not destination_path.is_file():
print(f"downloading civitai model from {url} to {destination_path}")
size = int(response.headers["content-length"], 0)
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
with open(destination_path, "wb") as f:
for chunk in response.iter_content(chunk_size=65536):
f.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
# we already have this model downloaded
else:
print(f"civitai model already downloaded to {destination_path}")
response.close()
return destination_path.as_posix()

View File

@@ -1,185 +0,0 @@
import os
import sys
import torch
import json
import safetensors
from dataclasses import dataclass
from safetensors.torch import load_file
from apps.shark_studio.web.utils.file_utils import (
get_checkpoint_pathfile,
get_path_stem,
)
@dataclass
class LoRAweight:
up: torch.tensor
down: torch.tensor
mid: torch.tensor
alpha: torch.float32 = 1.0
def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
state_dict = torch.load(use_lora)
# gather the weights from the LoRA in a more convenient form, assumes
# everything will have an up.weight.
weight_dict: dict[str, LoRAweight] = {}
for key in state_dict:
if key.startswith(splitting_prefix) and key.endswith("up.weight"):
stem = key.split("up.weight")[0]
weight_key = stem.removesuffix(".lora_")
weight_key = weight_key.removesuffix("_lora_")
weight_key = weight_key.removesuffix(".lora_linear_layer.")
if weight_key not in weight_dict:
weight_dict[weight_key] = LoRAweight(
state_dict[f"{stem}up.weight"],
state_dict[f"{stem}down.weight"],
state_dict.get(f"{stem}mid.weight", None),
(
state_dict[f"{weight_key}.alpha"]
/ state_dict[f"{stem}up.weight"].shape[1]
if f"{weight_key}.alpha" in state_dict
else 1.0
),
)
# Directly update weight in model
# Mostly adaptions of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
# and similar code in https://github.com/huggingface/diffusers/issues/3064
# TODO: handle mid weights (how do they even work?)
for key, lora_weight in weight_dict.items():
curr_layer = model
layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
weight = curr_layer.weight.data
scale = lora_weight.alpha * lora_strength
if len(weight.size()) == 2:
if len(lora_weight.up.shape) == 4:
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
change = torch.mm(lora_weight.up, lora_weight.down)
elif lora_weight.down.size()[2:4] == (1, 1):
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
change = torch.nn.functional.conv2d(
lora_weight.down.permute(1, 0, 2, 3),
lora_weight.up,
).permute(1, 0, 2, 3)
curr_layer.weight.data += change * scale
return model
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume if it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
return unet
main_file_name = get_path_stem(use_lora)
if ".bin" in use_lora:
main_file_name += ".bin"
elif ".safetensors" in use_lora:
main_file_name += ".safetensors"
elif ".pt" in use_lora:
main_file_name += ".pt"
else:
sys.exit("Only .bin and .safetensors format for LoRA is supported")
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
def update_lora_weight(model, use_lora, model_name, lora_strength=1.0):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora, lora_strength)
try:
return processLoRA(model, use_lora, "lora_te_", lora_strength)
except:
return None
def get_lora_metadata(lora_filename):
# get the metadata from the file
filename = get_checkpoint_pathfile(lora_filename, "lora")
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
metadata = f.metadata()
# guard clause for if there isn't any metadata
if not metadata:
return None
# metadata is a dictionary of strings, the values of the keys we're
# interested in are actually json, and need to be loaded as such
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
tag_dirs = [dir for dir in tag_frequencies.keys()]
# gather the tag frequency information for all the datasets trained
all_frequencies = {}
for dataset in tag_dirs:
frequencies = sorted(
[entry for entry in tag_frequencies[dataset].items()],
reverse=True,
key=lambda x: x[1],
)
# get a figure for the total number of images processed for this dataset
# either then number actually listed or in its dataset_dir entry or
# the highest frequency's number if that doesn't exist
img_count = dataset_dirs.get(dir, {}).get("img_count", frequencies[0][1])
# add the dataset frequencies to the overall frequencies replacing the
# frequency counts on the tags with a percentage/ratio
all_frequencies.update(
[(entry[0], entry[1] / img_count) for entry in frequencies]
)
trained_model_id = " ".join(
[
metadata.get("ss_sd_model_hash", ""),
metadata.get("ss_sd_model_name", ""),
metadata.get("ss_base_model_version", ""),
]
).strip()
# return the topmost <count> of all frequencies in all datasets
return {
"model": trained_model_id,
"frequencies": sorted(
all_frequencies.items(), reverse=True, key=lambda x: x[1]
),
}

View File

@@ -1,202 +0,0 @@
import os
import re
import json
import torch
import numpy as np
from csv import DictWriter
from PIL import Image, PngImagePlugin
from pathlib import Path
from datetime import datetime as dt
from base64 import decode
resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
"Nearest Neighbor": Image.Resampling.NEAREST,
"Bilinear": Image.Resampling.BILINEAR,
"Bicubic": Image.Resampling.BICUBIC,
"Hamming": Image.Resampling.HAMMING,
"Box": Image.Resampling.BOX,
}
resampler_list = resamplers.keys()
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info=None):
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
if extra_info is None:
extra_info = {}
generated_imgs_path = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15])
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
img_model = extra_info["base_model_id"]
if extra_info["custom_weights"] not in [None, "None"]:
img_model = Path(os.path.basename(extra_info["custom_weights"])).stem
img_vae = None
if extra_info["custom_vae"]:
img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem
img_loras = None
if extra_info["embeddings"]:
img_lora = []
for i in extra_info["embeddings"]:
img_lora += Path(os.path.basename(cmd_opts.use_lora)).stem
img_loras = ", ".join(img_lora)
if cmd_opts.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if cmd_opts.write_metadata_to_png:
# Using a conditional expression caused problems, so setting a new
# variable for now.
# if cmd_opts.use_hiresfix:
# png_size_text = (
# f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
# )
# else:
png_size_text = f"{extra_info['width']}x{extra_info['height']}"
pngInfo.add_text(
"parameters",
f"{extra_info['prompt'][0]}"
f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
f"\nSteps: {extra_info['steps']},"
f"Sampler: {extra_info['scheduler']}, "
f"CFG scale: {extra_info['guidance_scale']}, "
f"Seed: {img_seed},"
f"Size: {png_size_text}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_loras}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if cmd_opts.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {cmd_opts.output_img_format} is not "
f"supported yet. Image saved as png instead."
f"Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
# importance for each data point. Something to consider.
new_entry = {}
new_entry.update(extra_info)
csv_mode = "a" if os.path.isfile(csv_path) else "w"
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
if csv_mode == "w":
dictwriter_obj.writeheader()
dictwriter_obj.writerow(new_entry)
csv_obj.close()
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms with our model constraints :-
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image, width, height, resampler_type=None):
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 128:
n_size = 128
if width == min_size:
width = n_size
height = n_size / aspect_ratio
else:
height = n_size
width = n_size * aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
min_size = min(width, height)
if min_size > 768:
n_size = 768
if width == min_size:
height = n_size
width = n_size * aspect_ratio
else:
width = n_size
height = n_size / aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
if resampler_type in resamplers:
resampler = resamplers[resampler_type]
else:
resampler = resamplers["Nearest Neighbor"]
new_image = image.resize((n_width, n_height), resampler=resampler)
return new_image, n_width, n_height
def process_sd_init_image(self, sd_init_image, resample_type):
if isinstance(sd_init_image, list):
images = []
for img in sd_init_image:
img, _ = self.process_sd_init_image(img, resample_type)
images.append(img)
is_img2img = True
return images, is_img2img
if isinstance(sd_init_image, str):
if os.path.isfile(sd_init_image):
sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type)
else:
image = None
is_img2img = False
elif isinstance(sd_init_image, Image.Image):
image = sd_init_image.convert("RGB")
elif sd_init_image:
image = sd_init_image["image"].convert("RGB")
else:
image = None
is_img2img = False
if image:
resample_type = (
resamplers[resample_type]
if resample_type in resampler_list
# Fallback to Lanczos
else Image.Resampling.LANCZOS
)
image = image.resize((self.width, self.height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
image_arr = 2 * (image_arr - 0.5)
is_img2img = True
image = image_arr
return image, is_img2img

View File

@@ -1,37 +0,0 @@
import sys
class Logger:
def __init__(self, filename, filter=None):
self.terminal = sys.stdout
self.log = open(filename, "w")
self.filter = filter
def write(self, message):
for x in message.split("\n"):
if self.filter in x:
self.log.write(message)
else:
self.terminal.write(message)
def flush(self):
self.terminal.flush()
self.log.flush()
def isatty(self):
return False
def logger_test(x):
print("[LOG] This is a test")
print(f"This is another test, without the filter")
return x
def read_sd_logs():
sys.stdout.flush()
with open("shark_tmp/sd.log", "r") as f:
return f.read()
sys.stdout = Logger("shark_tmp/sd.log", filter="[LOG]")

View File

@@ -1,205 +0,0 @@
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
clean_device_info,
get_iree_target_triple,
)
from apps.shark_studio.web.utils.file_utils import (
get_checkpoints_path,
get_resource_path,
)
from apps.shark_studio.modules.shared_cmd_opts import (
cmd_opts,
)
from iree import runtime as ireert
from pathlib import Path
import gc
import os
class SharkPipelineBase:
# This class is a lightweight base for managing an
# inference API class. It should provide methods for:
# - compiling a set (model map) of torch IR modules
# - preparing weights for an inference job
# - loading weights for an inference job
# - utilites like benchmarks, tests
def __init__(
self,
model_map: dict,
base_model_id: str,
static_kwargs: dict,
device: str,
import_mlir: bool = True,
):
self.model_map = model_map
self.pipe_map = {}
self.static_kwargs = static_kwargs
self.base_model_id = base_model_id
self.triple = get_iree_target_triple(device)
self.device, self.device_id = clean_device_info(device)
self.import_mlir = import_mlir
self.iree_module_dict = {}
self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
if not os.path.exists(self.tmp_dir):
os.mkdir(self.tmp_dir)
self.tempfiles = {}
self.pipe_vmfb_path = ""
def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
# First checks whether we have .vmfbs precompiled, then populates the map
# with the precompiled executables and fetches executables for the rest of the map.
# The weights aren't static here anymore so this function should be a part of pipeline
# initialization. As soon as you have a pipeline ID unique to your static torch IR parameters,
# and your model map is populated with any IR - unique model IDs and their static params,
# call this method to get the artifacts associated with your map.
self.pipe_id = self.safe_name(pipe_id)
self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
if submodel == "None":
print("\n[LOG] Gathering any pre-compiled artifacts....")
for key in self.model_map:
self.get_compiled_map(pipe_id, submodel=key)
else:
self.pipe_map[submodel] = {}
self.get_precompiled(self.pipe_id, submodel)
ireec_flags = []
if submodel in self.iree_module_dict:
return
elif "vmfb_path" in self.pipe_map[submodel]:
return
elif submodel not in self.tempfiles:
print(
f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..."
)
if submodel in self.static_kwargs:
init_kwargs = self.static_kwargs[submodel]
for key in self.static_kwargs["pipe"]:
if key not in init_kwargs:
init_kwargs[key] = self.static_kwargs["pipe"][key]
self.import_torch_ir(submodel, init_kwargs)
self.get_compiled_map(pipe_id, submodel)
else:
ireec_flags = (
self.model_map[submodel]["ireec_flags"]
if "ireec_flags" in self.model_map[submodel]
else []
)
weights_path = self.get_io_params(submodel)
if weights_path:
ireec_flags.append("--iree-opt-const-eval=False")
self.iree_module_dict[submodel] = get_iree_compiled_module(
self.tempfiles[submodel],
device=self.device,
frontend="torch",
mmap=True,
external_weight_file=weights_path,
extra_args=ireec_flags,
write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
)
return
def get_io_params(self, submodel):
if "external_weight_file" in self.static_kwargs[submodel]:
# we are using custom weights
weights_path = self.static_kwargs[submodel]["external_weight_file"]
elif "external_weight_path" in self.static_kwargs[submodel]:
# we are using the default weights for the HF model
weights_path = self.static_kwargs[submodel]["external_weight_path"]
else:
# assume the torch IR contains the weights.
weights_path = None
return weights_path
def get_precompiled(self, pipe_id, submodel="None"):
if submodel == "None":
for model in self.model_map:
self.get_precompiled(pipe_id, model)
vmfbs = []
for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path):
vmfbs.extend(filenames)
break
for file in vmfbs:
if submodel in file:
self.pipe_map[submodel]["vmfb_path"] = os.path.join(
self.pipe_vmfb_path, file
)
return
def import_torch_ir(self, submodel, kwargs):
torch_ir = self.model_map[submodel]["initializer"](
**self.safe_dict(kwargs), compile_to="torch"
)
if submodel == "clip":
# clip.export_clip_model returns (torch_ir, tokenizer)
torch_ir = torch_ir[0]
self.tempfiles[submodel] = os.path.join(
self.tmp_dir, f"{submodel}.torch.tempfile"
)
with open(self.tempfiles[submodel], "w+") as f:
f.write(torch_ir)
del torch_ir
gc.collect()
return
def load_submodels(self, submodels: list):
for submodel in submodels:
if submodel in self.iree_module_dict:
print(f"\n[LOG] {submodel} is ready for inference.")
continue
if "vmfb_path" in self.pipe_map[submodel]:
weights_path = self.get_io_params(submodel)
# print(
# f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
# )
self.iree_module_dict[submodel] = {}
(
self.iree_module_dict[submodel]["vmfb"],
self.iree_module_dict[submodel]["config"],
self.iree_module_dict[submodel]["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
self.pipe_map[submodel]["vmfb_path"],
self.device,
device_idx=0,
rt_flags=[],
external_weight_file=weights_path,
)
else:
self.get_compiled_map(self.pipe_id, submodel)
return
def unload_submodels(self, submodels: list):
for submodel in submodels:
if submodel in self.iree_module_dict:
del self.iree_module_dict[submodel]
gc.collect()
return
def run(self, submodel, inputs):
if not isinstance(inputs, list):
inputs = [inputs]
inp = [
ireert.asdevicearray(
self.iree_module_dict[submodel]["config"].device, input
)
for input in inputs
]
return self.iree_module_dict[submodel]["vmfb"]["main"](*inp)
def safe_name(self, name):
return name.replace("/", "_").replace("-", "_").replace("\\", "_")
def safe_dict(self, kwargs: dict):
flat_args = {}
for i in kwargs:
if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
else:
flat_args[i] = kwargs[i]
return flat_args

View File

@@ -1,376 +0,0 @@
from typing import List, Optional, Union
from iree import runtime as ireert
import re
import torch
import numpy as np
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs:
text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens,
weights,
max_length,
bos,
eos,
no_boseos_middle=True,
chunk_length=77,
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
if no_boseos_middle:
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j
* (chunk_length - 2) : min(
len(weights[i]), (j + 1) * (chunk_length - 2)
)
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
def get_unweighted_text_embeddings(
pipe,
text_input,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.run("clip", text_input_chunk)[0].to_host()
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)
else:
text_embeddings = pipe.run("clip", text_input)[0]
text_embeddings = torch.from_numpy(text_embeddings.to_host())
return text_embeddings
# This function deals with NoneType values occuring in tokens after padding
# It switches out None with 49407 as truncating None values causes matrix dimension errors,
def filter_nonetype_tokens(tokens: List[List]):
return [[49407 if token is None else token for token in tokens[0]]]
def get_weighted_text_embeddings(
pipe,
prompt: List[str],
uncond_prompt: List[str] = None,
max_embeddings_multiples: Optional[int] = 8,
no_boseos_middle: Optional[bool] = True,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
prompt_tokens = filter_nonetype_tokens(prompt_tokens)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
uncond_tokens = filter_nonetype_tokens(uncond_tokens)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu")
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu")
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu")
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None

View File

@@ -1,118 +0,0 @@
# from shark_turbine.turbine_models.schedulers import export_scheduler_model
from diffusers import (
LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
def get_schedulers(model_id):
# TODO: switch over to turbine and run all on GPU
print(f"\n[LOG] Initializing schedulers from model id: {model_id}")
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
# schedulers["DDPM"] = DDPMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DDIM"] = DDIMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
# model_id, subfolder="scheduler", algorithm_type="dpmsolver"
# )
# schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
# model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
# )
# schedulers["DPMSolverMultistepKarras"] = (
# DPMSolverMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# use_karras_sigmas=True,
# )
# )
# schedulers["DPMSolverMultistepKarras++"] = (
# DPMSolverMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# algorithm_type="dpmsolver++",
# use_karras_sigmas=True,
# )
# )
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["EulerAncestralDiscrete"] = (
EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
)
# schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["KDPM2AncestralDiscrete"] = (
# KDPM2AncestralDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# )
# schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
return schedulers
def export_scheduler_model(model):
return "None", "None"
scheduler_model_map = {
"PNDM": export_scheduler_model("PNDMScheduler"),
# "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
# "LCM": export_scheduler_model("LCMScheduler"),
# "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
# "DDPM": export_scheduler_model("DDPMScheduler"),
# "DDIM": export_scheduler_model("DDIMScheduler"),
# "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
# "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
# "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
# "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
# "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
# "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
}

View File

@@ -1,66 +0,0 @@
import numpy as np
import json
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
if isinstance(seed_input, int):
return [seed_input]
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)
# Generate a set of seeds from an input expression for batch_count batches,
# optionally using that input as the rng seed for any randomly generated seeds.
def batch_seeds(seed_input: str | list | int, batch_count: int, repeatable=False):
# turn the input into a list if possible
seeds = parse_seed_input(seed_input)
# slice or pad the list to be of batch_count length
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
if repeatable:
if all(seed < 0 for seed in seeds):
seeds[0] = sanitize_seed(seeds[0])
# set seed for the rng based on what we have so far
saved_random_state = random_getstate()
seed_random(str([n for n in seeds if n > -1]))
# generate any seeds that are unspecified
seeds = [sanitize_seed(seed) for seed in seeds]
if repeatable:
# reset the rng back to normal
random_setstate(saved_random_state)
return seeds

View File

@@ -1,791 +0,0 @@
import argparse
import os
from pathlib import Path
from apps.shark_studio.modules.img_processing import resampler_list
def path_expand(s):
return Path(s).expanduser().resolve()
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
# Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
)
p.add_argument(
"-p",
"--prompt",
nargs="+",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smoke coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
],
help="Text of which images to be generated.",
)
p.add_argument(
"--negative_prompt",
nargs="+",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
"blurry, ugly, blur, oversaturated, cropped"
],
help="Text you don't want to see in the generated image.",
)
p.add_argument(
"--sd_init_image",
type=str,
help="Path to the image input for img2img/inpainting.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="The number of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=str,
default=-1,
help="The seed or list of seeds to use. -1 for a random one.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="The number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 1025, 8),
help="The height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 1025, 8),
help="The width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="The value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="The value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="Max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max "
"output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="The strength of change applied on the given input image for " "img2img.",
)
p.add_argument(
"--use_hiresfix",
type=bool,
default=False,
help="Use Hires Fix to do higher resolution images, while trying to "
"avoid the issues that come with it. This is accomplished by first "
"generating an image using txt2img, then running it through img2img.",
)
p.add_argument(
"--hiresfix_height",
type=int,
default=768,
choices=range(128, 769, 8),
help="The height of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_width",
type=int,
default=768,
choices=range(128, 769, 8),
help="The width of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_strength",
type=float,
default=0.6,
help="The denoising strength to apply for the Hires Fix.",
)
p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=resampler_list,
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
##############################################################################
# Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model.",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt.",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The number of steps to train.",
)
##############################################################################
# Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting.",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture.",
)
p.add_argument(
"--inpaint_full_res_padding",
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding.",
)
p.add_argument(
"--pixels",
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting.",
)
p.add_argument(
"--mask_blur",
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting.",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend right for outpainting.",
)
p.add_argument(
"--up",
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend top for outpainting.",
)
p.add_argument(
"--down",
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend bottom for outpainting.",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) "
"(min=0.0, max=4.0).",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0).",
)
##############################################################################
# Model Config and Usage Params
##############################################################################
p.add_argument("--device", type=str, default="vulkan", help="Device to run the model.")
p.add_argument(
"--precision", type=str, default="fp16", help="Precision to run the model."
)
p.add_argument(
"--import_mlir",
default=True,
action=argparse.BooleanOptionalAction,
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
)
p.add_argument(
"--use_tuned",
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available.",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="DDIM",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
"HeunDiscrete].",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="Specify the format in which output image is save. "
"Supported options: jpg / png.",
)
p.add_argument(
"--output_dir",
type=str,
default=os.path.join(os.getcwd(), "generated_imgs"),
help="Directory path to save the output images and json.",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="Number of batches to be generated with random seeds in " "single execution.",
)
p.add_argument(
"--repeatable_seeds",
default=False,
action=argparse.BooleanOptionalAction,
help="The seed of the first batch will be used as the rng seed to "
"generate the subsequent seeds for subsequent batches in that run.",
)
p.add_argument(
"--custom_weights",
type=str,
default="",
help="Path to a .safetensors or .ckpt file for SD pipeline weights.",
)
p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
"needs to be plugged in.",
)
p.add_argument(
"--base_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption.",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
"or an integer).",
)
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble", "zoedepth"],
help="Enable the stencil feature.",
)
p.add_argument(
"--control_mode",
choices=["Prompt", "Balanced", "Controlnet"],
default="Balanced",
help="How Controlnet injection should be prioritized.",
)
p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint " "file (~3 MB).",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="Runs the quantized version of stable diffusion model. "
"This is currently in experimental phase. "
"Currently, only runs the stable-diffusion-2-1-base model in "
"int8 quantization.",
)
p.add_argument(
"--lowvram",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM.",
)
p.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
p.add_argument(
"--external_weights",
type=str,
default=None,
help="What type of externalized weights to use. Currently options are 'safetensors' and defaults to inlined weights.",
)
p.add_argument(
"--device_allocator_heap_key",
type=str,
default="",
help="Specify heap key for device caching allocator."
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan.",
)
p.add_argument(
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal.",
)
##############################################################################
# Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="Use the default scheduler precompiled into the model if available.",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. "
"If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. " "Use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help="Dispatches to return benchmark data on. "
'Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help="Directory where you want to store dispatch data "
'generated with "--dispatch_benchmarks".',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for inserting debug frames between iterations " "for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="Flag setting warmup count for CLIP and VAE [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to clear all mlir and vmfb from common locations. "
"Recompiling will take several minutes.",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save a generation information "
"json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save generation information in "
"PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="If import_mlir is True, saves mlir via the debug option "
"in shark importer. Does nothing if import_mlir is false (the default).",
)
p.add_argument(
"--compile_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to toggle debug assert/verify flags for imported IR in the"
"iree-compiler. Default to false.",
)
p.add_argument(
"--iree_constant_folding",
default=True,
action=argparse.BooleanOptionalAction,
help="Controls constant folding in iree-compile for all SD models.",
)
p.add_argument(
"--data_tiling",
default=False,
action=argparse.BooleanOptionalAction,
help="Controls data tiling in iree-compile for all SD models.",
)
p.add_argument(
"--quantization",
type=str,
default="None",
help="Quantization to be used for api-exposed model.",
)
##############################################################################
# Web UI flags
##############################################################################
p.add_argument(
"--webui",
default=True,
action=argparse.BooleanOptionalAction,
help="controls whether the webui is launched.",
)
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the progress bar animation during " "image generation.",
)
p.add_argument(
"--tmp_dir",
type=str,
default=os.path.join(os.getcwd(), "shark_tmp"),
help="Path to tmp directory",
)
p.add_argument(
"--config_dir",
type=str,
default=os.path.join(os.getcwd(), "configs"),
help="Path to config directory",
)
p.add_argument(
"--model_dir",
type=str,
default=os.path.join(os.getcwd(), "models"),
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="One of: [api, app, web].",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for generating a public URL.",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="Flag for setting server port.",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling rest API.",
)
p.add_argument(
"--api_accept_origin",
action="append",
type=str,
help="An origin to be accepted by the REST api for Cross Origin"
"Resource Sharing (CORS). Use multiple times for multiple origins, "
'or use --api_accept_origin="*" to accept all origins. If no origins '
"are set no CORS headers will be returned by the api. Use, for "
"instance, if you need to access the REST api from Javascript running "
"in a web browser.",
)
p.add_argument(
"--debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling debugging log in WebUI.",
)
p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the output gallery tab, and avoid exposing "
"images under --output_dir in the UI.",
)
p.add_argument(
"--configs_path",
default=None,
type=str,
help="Path to .json config directory.",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether the output gallery tab in the UI should "
"follow symlinks when listing subdirectories under --output_dir.",
)
p.add_argument(
"--api_log",
default=False,
action=argparse.BooleanOptionalAction,
help="Enables Compatibility API logging.",
)
##############################################################################
# SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file.",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file.",
)
##############################################################################
# SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file.",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning.",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all.",
)
##############################################################################
# DocuChat Flags
##############################################################################
p.add_argument(
"--run_docuchat_web",
default=False,
action=argparse.BooleanOptionalAction,
help="Specifies whether the docuchat's web version is running or not.",
)
##############################################################################
# rocm Flags
##############################################################################
p.add_argument(
"--iree_rocm_target_chip",
type=str,
default="",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
)
cmd_opts, unknown = p.parse_known_args()
if cmd_opts.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
os.getcwd(), cmd_opts.hf_model_id.replace("/", "_")
)

View File

@@ -1,106 +0,0 @@
import time
import argparse
class TimerSubcategory:
def __init__(self, timer, category):
self.timer = timer
self.category = category
self.start = None
self.original_base_category = timer.base_category
def __enter__(self):
self.start = time.time()
self.timer.base_category = self.original_base_category + self.category + "/"
self.timer.subcategory_level += 1
if self.timer.print_log:
print(f"{' ' * self.timer.subcategory_level}{self.category}:")
def __exit__(self, exc_type, exc_val, exc_tb):
elapsed_for_subcategroy = time.time() - self.start
self.timer.base_category = self.original_base_category
self.timer.add_time_to_record(
self.original_base_category + self.category,
elapsed_for_subcategroy,
)
self.timer.subcategory_level -= 1
self.timer.record(self.category, disable_log=True)
class Timer:
def __init__(self, print_log=False):
self.start = time.time()
self.records = {}
self.total = 0
self.base_category = ""
self.print_log = print_log
self.subcategory_level = 0
def elapsed(self):
end = time.time()
res = end - self.start
self.start = end
return res
def add_time_to_record(self, category, amount):
if category not in self.records:
self.records[category] = 0
self.records[category] += amount
def record(self, category, extra_time=0, disable_log=False):
e = self.elapsed()
self.add_time_to_record(self.base_category + category, e + extra_time)
self.total += e + extra_time
if self.print_log and not disable_log:
print(
f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s"
)
def subcategory(self, name):
self.elapsed()
subcat = TimerSubcategory(self, name)
return subcat
def summary(self):
res = f"{self.total:.1f}s"
additions = [
(category, time_taken)
for category, time_taken in self.records.items()
if time_taken >= 0.1 and "/" not in category
]
if not additions:
return res
res += " ("
res += ", ".join(
[f"{category}: {time_taken:.1f}s" for category, time_taken in additions]
)
res += ")"
return res
def dump(self):
return {"total": self.total, "records": self.records}
def reset(self):
self.__init__()
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
"--log-startup",
action="store_true",
help="print a detailed log of what's happening at startup",
)
args = parser.parse_known_args()[0]
startup_timer = Timer(print_log=args.log_startup)
startup_record = None

View File

@@ -1,48 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.shark_studio.studio_imports import pathex, datas, hiddenimports
binaries = []
block_cipher = None
a = Analysis(
['web/index.py'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='nodai_shark_studio',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -1,68 +0,0 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# python path for pyinstaller
pathex = [
".",
]
# datafiles for pyinstaller
datas = []
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("omegaconf")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += copy_metadata("scipy")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree", include_py_files=True)
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("scipy", include_py_files=True)
datas += [
("web/ui/css/*", "ui/css"),
("web/ui/js/*", "ui/js"),
("web/ui/logos/*", "logos"),
]
# hidden imports for pyinstaller
hiddenimports = ["shark", "apps"]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
blacklist = ["tests", "convert"]
hiddenimports += [
x
for x in collect_submodules("transformers")
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
hiddenimports += ["iree._runtime"]
hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x]

View File

@@ -1,58 +0,0 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import logging
import unittest
import json
import gc
from apps.shark_studio.api.llm import LanguageModel, llm_chat_api
from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file
from apps.shark_studio.web.utils.file_utils import get_resource_path
# class SDAPITest(unittest.TestCase):
# def testSDSimple(self):
# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
# import apps.shark_studio.web.utils.globals as global_obj
# global_obj._init()
# sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
# sd_kwargs = json.loads(sd_json)
# for arg in vars(cmd_opts):
# if arg in sd_kwargs:
# sd_kwargs[arg] = getattr(cmd_opts, arg)
# for i in shark_sd_fn_dict_input(sd_kwargs):
# print(i)
class LLMAPITest(unittest.TestCase):
def test01_LLMSmall(self):
lm = LanguageModel(
"TinyPixel/small-llama2",
hf_auth_token=None,
device="cpu",
precision="fp32",
quantization="None",
streaming_llm=True,
)
count = 0
label = "Turkishoure Turkish"
for msg, _ in lm.chat("hi, what are you?"):
# skip first token output
if count == 0:
count += 1
continue
assert (
msg.strip(" ") == label
), f"LLM API failed to return correct response, expected '{label}', received {msg}"
break
del lm
gc.collect()
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()

View File

@@ -1,41 +0,0 @@
import torch
from diffusers import (
UNet2DConditionModel,
)
from torch.fx.experimental.proxy_tensor import make_fx
class UnetModel(torch.nn.Module):
def __init__(self, hf_model_name):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
hf_model_name,
subfolder="unet",
)
def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
samples = torch.cat([sample] * 2)
unet_out = self.unet.forward(
samples, timestep, encoder_hidden_states, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
if __name__ == "__main__":
hf_model_name = "CompVis/stable-diffusion-v1-4"
unet = UnetModel(hf_model_name)
inputs = (torch.randn(1, 4, 64, 64), 1, torch.randn(2, 77, 768), 7.5)
fx_g = make_fx(
unet,
decomposition_table={},
tracing_mode="symbolic",
_allow_non_fake_inputs=True,
_allow_fake_constant=False,
)(*inputs)
print(fx_g)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 347 KiB

View File

@@ -1,45 +0,0 @@
import requests
from PIL import Image
import base64
from io import BytesIO
import json
def llm_chat_test(verbose=False):
# Define values here
prompt = "What is the significance of the number 42?"
url = "http://127.0.0.1:8080/v1/chat/completions"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"model": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"messages": [
{
"role": "",
"content": prompt,
}
],
"device": "vulkan://0",
"max_tokens": 4096,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
res_dict = json.loads(res.content.decode("utf-8"))
print(f"[chat] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(f"\n{res_dict['choices'][0]['message']['content']}\n")
if __name__ == "__main__":
# "Exercises the chatbot REST API of Shark. Make sure "
# "Shark is running in API mode on 127.0.0.1:8080 before running"
# "this script."
llm_chat_test(verbose=True)

View File

@@ -1,286 +0,0 @@
import base64
import io
import os
import time
import datetime
import uvicorn
import ipaddress
import requests
import threading
import collections
import gradio as gr
from PIL import Image, PngImagePlugin
from threading import Lock
from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
# from sdapi_v1 import shark_sd_api
from apps.shark_studio.api.llm import llm_chat_api
def decode_base64_to_image(encoding):
if encoding.startswith("http://") or encoding.startswith("https://"):
headers = {}
response = requests.get(encoding, timeout=30, headers=headers)
try:
image = Image.open(BytesIO(response.content))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid image url") from e
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
image.save(
output_bytes,
format="PNG",
pnginfo=(metadata if use_metadata else None),
)
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
class FIFOLock(object):
def __init__(self):
self._lock = threading.Lock()
self._inner_lock = threading.Lock()
self._pending_threads = collections.deque()
def acquire(self, blocking=True):
with self._inner_lock:
lock_acquired = self._lock.acquire(False)
if lock_acquired:
return True
elif not blocking:
return False
release_event = threading.Event()
self._pending_threads.append(release_event)
release_event.wait()
return self._lock.acquire()
def release(self):
with self._inner_lock:
if self._pending_threads:
release_event = self._pending_threads.popleft()
release_event.set()
self._lock.release()
__enter__ = acquire
def __exit__(self, t, v, tb):
self.release()
def api_middleware(app: FastAPI):
rich_available = False
try:
if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
import anyio # importing just so it can be placed on silent list
import starlette # importing just so it can be placed on silent list
from rich.console import Console
console = Console()
rich_available = True
except Exception:
pass
@app.middleware("http")
async def log_and_time(req: Request, call_next):
ts = time.time()
res: Response = await call_next(req)
duration = str(round(time.time() - ts, 4))
res.headers["X-Process-Time"] = duration
endpoint = req.scope.get("path", "err")
if cmd_opts.api_log and endpoint.startswith("/sdapi"):
print(
"API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format(
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code=res.status_code,
ver=req.scope.get("http_version", "0.0"),
cli=req.scope.get("client", ("0:0.0.0", 0))[0],
prot=req.scope.get("scheme", "err"),
method=req.scope.get("method", "err"),
endpoint=endpoint,
duration=duration,
)
)
return res
def handle_exception(request: Request, e: Exception):
err = {
"error": type(e).__name__,
"detail": vars(e).get("detail", ""),
"body": vars(e).get("body", ""),
"errors": str(e),
}
if not isinstance(
e, HTTPException
): # do not print backtrace on known httpexceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
console.print_exception(
show_locals=True,
max_frames=2,
extra_lines=1,
suppress=[anyio, starlette],
word_wrap=False,
width=min([console.width, 200]),
)
else:
print(message)
raise (e)
return JSONResponse(
status_code=vars(e).get("status_code", 500),
content=jsonable_encoder(err),
)
@app.middleware("http")
async def exception_handling(request: Request, call_next):
try:
return await call_next(request)
except Exception as e:
return handle_exception(request, e)
@app.exception_handler(Exception)
async def fastapi_exception_handler(request: Request, e: Exception):
return handle_exception(request, e)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, e: HTTPException):
return handle_exception(request, e)
class ApiCompat:
def __init__(self, app: FastAPI, queue_lock: Lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
api_middleware(self.app)
# self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
# self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"])
# self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["POST"])
# self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
# self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
# self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
# self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
# self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
# self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
# self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
# self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
# self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
# self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
# self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
# self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
# self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
# self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
# self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
# self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
# self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
# self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
# self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
# self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
# self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
# self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
# self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
# self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
# self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
# self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
# self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
# self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
# chat APIs needed for compatibility with multiple extensions using OpenAI API
self.add_api_route("/v1/chat/completions", llm_chat_api, methods=["POST"])
self.add_api_route("/v1/completions", llm_chat_api, methods=["POST"])
self.add_api_route("/chat/completions", llm_chat_api, methods=["POST"])
self.add_api_route("/completions", llm_chat_api, methods=["POST"])
self.add_api_route(
"/v1/engines/codegen/completions", llm_chat_api, methods=["POST"]
)
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
# def refresh_checkpoints(self):
# with self.queue_lock:
# studio_data.refresh_checkpoints()
# def refresh_vae(self):
# with self.queue_lock:
# studio_data.refresh_vae_list()
# def unloadapi(self):
# unload_model_weights()
# return {}
# def reloadapi(self):
# reload_model_weights()
# return {}
# def skip(self):
# studio.state.skip()
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
uvicorn.run(
self.app,
host=server_name,
port=port,
root_path=root_path,
)
# def kill_studio(self):
# restart.stop_program()
# def restart_studio(self):
# if restart.is_restartable():
# restart.restart_program()
# return Response(status_code=501)
# def preprocess(self, args: dict):
# try:
# studio.state.begin(job="preprocess")
# preprocess(**args)
# studio.state.end()
# return models.PreprocessResponse(info="preprocess complete")
# except:
# studio.state.end()
# def stop_studio(request):
# studio.state.server_command = "stop"
# return Response("Stopping.")

View File

@@ -1 +0,0 @@

View File

@@ -1,222 +0,0 @@
from multiprocessing import Process, freeze_support
freeze_support()
from PIL import Image
import os
import time
import sys
import logging
import apps.shark_studio.api.initializers as initialize
from apps.shark_studio.modules import timer
startup_timer = timer.startup_timer
startup_timer.record("launcher")
initialize.imports()
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
def create_api(app):
from apps.shark_studio.web.api.compat import ApiCompat, FIFOLock
queue_lock = FIFOLock()
api = ApiCompat(app, queue_lock)
return api
def api_only():
from fastapi import FastAPI
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
initialize.initialize()
app = FastAPI()
initialize.setup_middleware(app)
api = create_api(app)
# from modules import script_callbacks
# script_callbacks.before_ui_callback()
# script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")
api.launch(
server_name="0.0.0.0",
port=cmd_opts.server_port,
root_path="",
)
def launch_webui(address):
from tkinter import Tk
import webview
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
def webui():
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.ui.utils import (
amdicon_loc,
amdlogo_loc,
)
launch_api = cmd_opts.api
initialize.initialize()
from ui.chat import chat_element
from ui.sd import sd_element
from ui.outputgallery import outputgallery_element
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.shark_studio.api.llm import (
# chat,
# )
# from apps.shark_studio.web.api import sdapi
#
# from fastapi import FastAPI, APIRouter
# from fastapi.middleware.cors import CORSMiddleware
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
#
# api = FastAPI()
# api.mount("/sdapi/", sdapi)
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# api.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# api.add_api_route("/completions", llm_chat_api, methods=["post"])
# api.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# api.include_router(APIRouter())
#
# # deal with CORS requests if CORS accept origins are set
# if args.api_accept_origin:
# print(
# f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
# )
# api.add_middleware(
# CORSMiddleware,
# allow_origins=args.api_accept_origin,
# allow_methods=["GET", "POST"],
# allow_headers=["*"],
# )
# else:
# print("API not configured for CORS")
#
# uvicorn.run(api, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
import gradio as gr
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")
# from apps.shark_studio.web.ui import load_ui_from_script
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme,
js=gradio_workarounds,
analytics_enabled=False,
title="Shark Studio 2.0 Beta",
) as studio_web:
amd_logo = Image.open(amdlogo_loc)
gr.Image(
value=amd_logo,
show_label=False,
interactive=False,
elem_id="tab_bar_logo",
show_download_button=False,
)
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
with gr.TabItem(label="Stable Diffusion", id=0):
sd_element.render()
with gr.TabItem(label="Output Gallery", id=1):
outputgallery_element.render()
with gr.TabItem(label="Chat Bot", id=2):
chat_element.render()
studio_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
studio_web.launch(
share=cmd_opts.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=cmd_opts.server_port,
favicon_path=amdicon_loc,
)
if __name__ == "__main__":
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
if cmd_opts.webui == False:
api_only()
else:
webui()

View File

@@ -1,239 +0,0 @@
import gradio as gr
import time
import os
from pathlib import Path
from datetime import datetime as dt
import json
import sys
from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
B_SYS, E_SYS = "<s>", "</s>"
B_SYS, E_SYS = "<s>", "</s>"
B_SYS, E_SYS = "<s>", "</s>"
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
def append_bot_prompt(history, input_prompt):
user_prompt = f"{input_prompt} {E_SYS} {E_SYS}"
history += user_prompt
return history
language_model = None
def get_default_config():
return False
# model_vmfb_key = ""
def chat_fn(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
streaming_llm,
cli=False,
):
global language_model
if streaming_llm and prompt_prefix == "Clear":
language_model = None
return "Clearing history...", ""
if language_model is None:
history[-1][-1] = "Getting the model ready..."
yield history, ""
language_model = LanguageModel(
model,
device=device,
precision=precision,
external_weights="safetensors",
use_system_prompt=prompt_prefix,
streaming_llm=streaming_llm,
hf_auth_token=cmd_opts.hf_auth_token,
)
history[-1][-1] = "Getting the model ready... Done"
yield history, ""
history[-1][-1] = ""
token_count = 0
total_time = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, exec_time in language_model.chat(history):
history[-1][-1] = f"{text}{E_SYS}"
if is_first:
prefill_time = exec_time
is_first = False
yield history, f"Prefill: {prefill_time:.2f}"
else:
total_time += exec_time
token_count += 1
tokens_per_sec = token_count / total_time
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
content = fopen.read()
return content
with gr.Blocks(title="Chat") as chat_element:
with gr.Row():
model_choices = list(llm_model_map.keys())
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = global_obj.get_device_list()
enabled = True
if len(supported_devices) == 0:
supported_devices = ["cpu-task"]
supported_devices = [x for x in supported_devices if "sync" not in x]
device = gr.Dropdown(
label="Device",
value=supported_devices[0],
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
value="fp32",
choices=[
# "int4",
# "int8",
# "fp16",
"fp32",
],
visible=False,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=False,
interactive=True,
visible=False,
)
streaming_llm = gr.Checkbox(
label="Run in streaming mode (requires recompilation)",
value=True,
interactive=False,
visible=False,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=True,
interactive=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(label="Upload sharding configuration", visible=False)
json_view_button = gr.Button("View as JSON", visible=False)
json_view = gr.JSON(visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
streaming_llm,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
streaming_llm,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(
fn=chat_fn,
inputs=[
clear,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
streaming_llm,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
).then(lambda: None, None, [chatbot], queue=False)

View File

@@ -1,67 +0,0 @@
from apps.shark_studio.web.ui.utils import (
HSLHue,
hsl_color,
)
from apps.shark_studio.modules.embeddings import get_lora_metadata
# Answers HTML to show the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_files):
# tag frequency percentage, that gets maximum amount of the staring hue
TAG_COLOR_THRESHOLD = 0.55
# tag frequency percentage, above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
TAG_HTML_TEMPLATE = (
'<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
)
output = []
for lora_file in lora_files:
if lora_file == "":
output.extend(["<div><i>No LoRA selected</i></div>"])
elif not lora_file.lower().endswith(".safetensors"):
output.extend(
[
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
]
)
else:
metadata = get_lora_metadata(lora_file)
if metadata:
frequencies = metadata["frequencies"]
output.extend(
[
"".join(
[
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
]
+ [
TAG_HTML_TEMPLATE.format(
color=hsl_color(
(tag[1] - TAG_COLOR_THRESHOLD)
/ (1 - TAG_COLOR_THRESHOLD),
start=HSLHue.RED,
end=HSLHue.GREEN,
),
tag=tag[0],
)
for tag in frequencies
if tag[1] > TAG_DISPLAY_THRESHOLD
],
)
]
)
elif metadata is None:
output.extend(
[
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
]
)
else:
output.extend(
[
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]
)
return output

View File

@@ -1,49 +0,0 @@
// workaround gradio after 4.7, not applying any @media rules form the custom .css file
() => {
console.log(`innerWidth: ${window.innerWidth}` )
// 1536px rules
const mediaQuery1536 = window.matchMedia('(min-width: 1536px)')
function handleWidth1536(event) {
// display in full width for desktop devices
document.querySelectorAll(".gradio-container")
.forEach( (node) => {
if (event.matches) {
node.classList.add("gradio-container-size-full");
} else {
node.classList.remove("gradio-container-size-full")
}
});
}
mediaQuery1536.addEventListener("change", handleWidth1536);
mediaQuery1536.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1536}));
// 1921px rules
const mediaQuery1921 = window.matchMedia('(min-width: 1921px)')
function handleWidth1921(event) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
document.querySelectorAll("#gallery")
.forEach( (node) => {
if (event.matches) {
node.classList.add("gallery-force-height768");
node.classList.add("gallery-limit-height768");
} else {
node.classList.remove("gallery-force-height768");
node.classList.remove("gallery-limit-height768");
}
});
}
mediaQuery1921.addEventListener("change", handleWidth1921);
mediaQuery1921.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1921}));
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.4 KiB

View File

@@ -1,406 +0,0 @@
import glob
import gradio as gr
import os
import subprocess
import sys
from PIL import Image
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
from apps.shark_studio.web.ui.utils import amdlogo_loc
from apps.shark_studio.web.utils.metadata import displayable_metadata
# -- Functions for file, directory and image info querying
output_dir = get_generated_imgs_path()
def outputgallery_filenames(subdir) -> list[str]:
new_dir_path = os.path.join(output_dir, subdir)
if os.path.exists(new_dir_path):
filenames = [
glob.glob(new_dir_path + "/" + ext) for ext in ("*.png", "*.jpg", "*.jpeg")
]
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
else:
return []
def output_subdirs() -> list[str]:
# Gets a list of subdirectories of output_dir and below, as relative paths.
relative_paths = [
os.path.relpath(entry[0], output_dir)
for entry in os.walk(
output_dir, followlinks=cmd_opts.output_gallery_followlinks
)
]
# It is less confusing to always including the subdir that will take any
# images generated today even if it doesn't exist yet
if get_generated_imgs_todays_subdir() not in relative_paths:
relative_paths.append(get_generated_imgs_todays_subdir())
# sort subdirectories so that the date named ones we probably
# created in this or previous sessions come first, sorted with the most
# recent first. Other subdirs are listed after.
generated_paths = sorted(
[path for path in relative_paths if path.isnumeric()], reverse=True
)
result_paths = generated_paths + sorted(
[path for path in relative_paths if (not path.isnumeric()) and path != "."]
)
return result_paths
# --- Define UI layout for Gradio
with gr.Blocks() as outputgallery_element:
amd_logo = Image.open(amdlogo_loc)
with gr.Row(elem_id="outputgallery_gallery"):
# needed to workaround gradio issue:
# https://github.com/gradio-app/gradio/issues/2907
dev_null = gr.Textbox("", visible=False)
gallery_files = gr.State(value=[])
subdirectory_paths = gr.State(value=[])
with gr.Column(scale=6):
logo = gr.Image(
label="Getting subdirectories...",
value=amd_logo,
interactive=False,
visible=True,
show_label=True,
elem_id="top_logo",
elem_classes="logo_centered",
show_download_button=False,
)
gallery = gr.Gallery(
label="",
value=gallery_files.value,
visible=False,
show_label=True,
columns=4,
)
with gr.Column(scale=4):
with gr.Group():
with gr.Row():
with gr.Column(
scale=15,
min_width=160,
elem_id="output_subdir_container",
):
subdirectories = gr.Dropdown(
label=f"Subdirectories of {output_dir}",
type="value",
choices=subdirectory_paths.value,
value="",
interactive=True,
elem_classes="dropdown_no_container",
allow_custom_value=True,
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
open_subdir = gr.Button(
variant="secondary",
value="\U0001F5C1", # unicode open folder
interactive=False,
size="sm",
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
refresh = gr.Button(
variant="secondary",
value="\u21BB", # unicode clockwise arrow circle
size="sm",
)
image_columns = gr.Slider(
label="Columns shown", value=4, minimum=1, maximum=16, step=1
)
outputgallery_filename = gr.Textbox(
label="Filename",
value="None",
interactive=False,
show_copy_button=True,
)
with gr.Accordion(
label="Parameter Information", open=False
) as parameters_accordian:
image_parameters = gr.DataFrame(
headers=["Parameter", "Value"],
col_count=2,
wrap=True,
elem_classes="output_parameters_dataframe",
value=[["Status", "No image selected"]],
interactive=True,
)
with gr.Accordion(label="Send To", open=True):
with gr.Row():
outputgallery_sendto_sd = gr.Button(
value="Stable Diffusion",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
# --- Event handlers
def on_clear_gallery():
return [
gr.Gallery(
value=[],
visible=False,
),
gr.Image(
visible=True,
),
]
def on_image_columns_change(columns):
return gr.Gallery(columns=columns)
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
return [
new_images,
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_open_subdir(subdir):
subdir_path = os.path.normpath(os.path.join(output_dir, subdir))
if os.path.isdir(subdir_path):
if sys.platform == "linux":
subprocess.run(["xdg-open", subdir_path])
elif sys.platform == "darwin":
subprocess.run(["open", subdir_path])
elif sys.platform == "win32":
os.startfile(subdir_path)
def on_refresh(current_subdir: str) -> list:
# get an up-to-date subdirectory list
refreshed_subdirs = output_subdirs()
# get the images using either the current subdirectory or the most
# recent valid one
new_subdir = (
current_subdir
if current_subdir in refreshed_subdirs
else refreshed_subdirs[0]
)
new_images = outputgallery_filenames(new_subdir)
new_label = (
f"{len(new_images)} images in " f"{os.path.join(output_dir, new_subdir)}"
)
return [
gr.Dropdown(
choices=refreshed_subdirs,
value=new_subdir,
),
refreshed_subdirs,
new_images,
gr.Gallery(value=new_images, label=new_label, visible=len(new_images) > 0),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_new_image(subdir, subdir_paths, status) -> list:
# prevent error triggered when an image generates before the tab
# has even been selected
subdir_paths = (
subdir_paths
if len(subdir_paths) > 0
else [get_generated_imgs_todays_subdir()]
)
# only update if the current subdir is the most recent one as
# new images only go there
if subdir_paths[0] == subdir:
new_images = outputgallery_filenames(subdir)
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, subdir)} - {status}"
)
return [
new_images,
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
else:
# otherwise change nothing,
# (only untyped gradio gr.update() does this)
return [gr.update(), gr.update(), gr.update()]
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
# evt.index is an index into the full list of filenames for
# the current subdirectory
filename = images[evt.index]
params = displayable_metadata(filename)
if params:
if params["source"] == "missing":
return [
"Could not find this image file, refresh the gallery and update the images",
[["Status", "File missing"]],
]
else:
return [
filename,
list(map(list, params["parameters"].items())),
]
return [
filename,
[["Status", "No parameters found"]],
]
def on_outputgallery_filename_change(filename: str) -> list:
exists = filename != "None" and os.path.exists(filename)
return [
# disable or enable each of the sendto button based on whether
# an image is selected
gr.Button(interactive=exists),
]
# The time first our tab is selected we need to do an initial refresh
# to populate the subdirectory select box and the images from the most
# recent subdirectory.
#
# We do it at this point rather than setting this up in the controls'
# definitions as when you refresh the browser you always get what was
# *initially* set, which won't include any new subdirectories or images
# that might have created since the application was started. Doing it
# this way means a browser refresh/reload always gets the most
# up-to-date data.
def on_select_tab(subdir_paths, request: gr.Request):
local_client = request.headers["host"].startswith(
"127.0.0.1:"
) or request.headers["host"].startswith("localhost:")
if len(subdir_paths) == 0:
return on_refresh("") + [gr.update(interactive=local_client)]
else:
return (
# Change nothing, (only untyped gr.update() does this)
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
# replacement phase.
clear_gallery = dict(
fn=on_clear_gallery,
inputs=None,
outputs=[gallery, logo],
queue=False,
)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
)
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
)
image_columns.change(
fn=on_image_columns_change,
inputs=[image_columns],
outputs=[gallery],
queue=False,
)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
)
outputgallery_filename.change(
on_outputgallery_filename_change,
[outputgallery_filename],
[
outputgallery_sendto_sd,
],
queue=False,
)
# We should have been given the .select function for our tab, so set it up
def outputgallery_tab_select(select):
select(
fn=on_select_tab,
inputs=[subdirectory_paths],
outputs=[
subdirectories,
subdirectory_paths,
gallery_files,
gallery,
logo,
open_subdir,
],
queue=False,
)
# We should have been passed a list of components on other tabs that update
# when a new image has generated on that tab, so set things up so the user
# will see that new image if they are looking at today's subdirectory
def outputgallery_watch(components: gr.Textbox):
for component in components:
component.change(
on_new_image,
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
)

View File

@@ -1,777 +0,0 @@
import os
import json
import gradio as gr
import numpy as np
from inspect import signature
from PIL import Image
from pathlib import Path
from datetime import datetime as dt
from gradio.components.image_editor import (
EditorValue,
)
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_checkpoints_path,
get_checkpoints,
get_configs_path,
write_default_sd_configs,
)
from apps.shark_studio.api.sd import (
shark_sd_fn_dict_input,
cancel_sd,
unload_sd,
)
from apps.shark_studio.api.controlnet import (
cnet_preview,
)
from apps.shark_studio.modules.schedulers import (
scheduler_model_map,
)
from apps.shark_studio.modules.img_processing import (
resampler_list,
resize_stencil,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.ui.utils import (
amdlogo_loc,
none_to_str_none,
str_none_to_none,
)
from apps.shark_studio.web.utils.state import (
status_label,
)
from apps.shark_studio.web.ui.common_events import lora_changed
from apps.shark_studio.modules import logger
import apps.shark_studio.web.utils.globals as global_obj
sd_default_models = [
"runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2-1-base",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-xl-base-1.0",
"stabilityai/sdxl-turbo",
]
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
content = fopen.read()
return content
def submit_to_cnet_config(
stencil: str,
preprocessed_hint: str,
cnet_strength: int,
control_mode: str,
curr_config: dict,
):
if any(i in [None, ""] for i in [stencil, preprocessed_hint]):
return gr.update()
if curr_config is not None:
if "controlnets" in curr_config:
curr_config["controlnets"]["control_mode"] = control_mode
curr_config["controlnets"]["model"].append(stencil)
curr_config["controlnets"]["hint"].append(preprocessed_hint)
curr_config["controlnets"]["strength"].append(cnet_strength)
return curr_config
cnet_map = {}
cnet_map["controlnets"] = {
"control_mode": control_mode,
"model": [stencil],
"hint": [preprocessed_hint],
"strength": [cnet_strength],
}
return cnet_map
def update_embeddings_json(embedding):
return {"embeddings": [embedding]}
def submit_to_main_config(input_cfg: dict, main_cfg: dict):
if main_cfg in [None, "", {}]:
return input_cfg
for base_key in input_cfg:
main_cfg[base_key] = input_cfg[base_key]
return main_cfg
def pull_sd_configs(
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
controlnets,
embeddings,
):
sd_args = str_none_to_none(locals())
sd_cfg = {}
for arg in sd_args:
if arg in [
"prompt",
"negative_prompt",
"sd_init_image",
]:
sd_cfg[arg] = [sd_args[arg]]
elif arg in ["controlnets", "embeddings"]:
if isinstance(arg, dict):
sd_cfg[arg] = json.loads(sd_args[arg])
else:
sd_cfg[arg] = {}
else:
sd_cfg[arg] = sd_args[arg]
return json.dumps(sd_cfg)
def load_sd_cfg(sd_json: dict, load_sd_config: str):
new_sd_config = none_to_str_none(json.loads(view_json_file(load_sd_config)))
if sd_json:
for key in new_sd_config:
sd_json[key] = new_sd_config[key]
else:
sd_json = new_sd_config
for i in sd_json["sd_init_image"]:
if i is not None:
if os.path.isfile(i):
sd_image = [Image.open(i, mode="r")]
else:
sd_image = None
return [
sd_json["prompt"][0],
sd_json["negative_prompt"][0],
sd_image,
sd_json["height"],
sd_json["width"],
sd_json["steps"],
sd_json["strength"],
sd_json["guidance_scale"],
sd_json["seed"],
sd_json["batch_count"],
sd_json["batch_size"],
sd_json["scheduler"],
sd_json["base_model_id"],
sd_json["custom_weights"],
sd_json["custom_vae"],
sd_json["precision"],
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
sd_json["compiled_pipeline"],
sd_json["resample_type"],
sd_json["controlnets"],
sd_json["embeddings"],
sd_json,
]
def save_sd_cfg(config: dict, save_name: str):
if os.path.exists(save_name):
filepath = save_name
elif cmd_opts.configs_path:
filepath = os.path.join(cmd_opts.configs_path, save_name)
else:
filepath = os.path.join(get_configs_path(), save_name)
if ".json" not in filepath:
filepath += ".json"
with open(filepath, mode="w") as f:
f.write(json.dumps(config))
return "..."
def create_canvas(width, height):
data = Image.fromarray(
np.zeros(
shape=(height, width, 3),
dtype=np.uint8,
)
+ 255
)
img_dict = {
"background": data,
"layers": [],
"composite": None,
}
return EditorValue(img_dict)
def import_original(original_img, width, height):
if original_img is None:
resized_img = create_canvas(width, height)
return resized_img
else:
resized_img, _, _ = resize_stencil(original_img, width, height)
img_dict = {
"background": resized_img,
"layers": [],
"composite": None,
}
return EditorValue(img_dict)
def base_model_changed(base_model_id):
new_choices = get_checkpoints(
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
) + get_checkpoints(model_type="checkpoints")
return gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
choices=["None"] + new_choices,
)
with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Column(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=2, min_width=600):
with gr.Accordion(
label="\U0001F4D0\U0000FE0F Device Settings", open=False
):
device = gr.Dropdown(
elem_id="device",
label="Device",
value=global_obj.get_device_list()[0],
choices=global_obj.get_device_list(),
allow_custom_value=False,
)
target_triple = gr.Textbox(
elem_id="target_triple",
label="Architecture",
value="",
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
label="Low VRAM",
interactive=True,
)
precision = gr.Radio(
label="Precision",
value=cmd_opts.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
base_model_id = gr.Dropdown(
label="\U000026F0\U0000FE0F Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2-1-base",
choices=sd_default_models,
allow_custom_value=True,
) # base_model_id
with gr.Row():
height = gr.Slider(
384,
1024,
value=cmd_opts.height,
step=8,
label="\U00002195\U0000FE0F Height",
)
width = gr.Slider(
384,
1024,
value=cmd_opts.width,
step=8,
label="\U00002194\U0000FE0F Width",
)
with gr.Accordion(
label="\U00002696\U0000FE0F Model Weights", open=False
):
with gr.Column():
custom_weights = gr.Dropdown(
label="Checkpoint Weights",
info="Select or enter HF model ID",
elem_id="custom_model",
value="None",
allow_custom_value=True,
choices=["None"]
+ get_checkpoints(os.path.basename(str(base_model_id))),
) # custom_weights
base_model_id.change(
fn=base_model_changed,
inputs=[base_model_id],
outputs=[custom_weights],
)
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
"\\", "\n\\"
)
sd_vae_info = f"VAE Path: {sd_vae_info}"
custom_vae = gr.Dropdown(
label=f"VAE Model",
info=sd_vae_info,
elem_id="custom_model",
value=(
os.path.basename(cmd_opts.custom_vae)
if cmd_opts.custom_vae
else "None"
),
choices=["None"] + get_checkpoints("vae"),
allow_custom_value=True,
scale=1,
)
sd_lora_info = (str(get_checkpoints_path("loras"))).replace(
"\\", "\n\\"
)
lora_opt = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=sd_lora_info,
elem_id="lora_weights",
value=None,
multiselect=True,
choices=[] + get_checkpoints("lora"),
scale=2,
)
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
embeddings_config = gr.JSON(
label="Embeddings Options", min_width=50, scale=1
)
gr.on(
triggers=[lora_opt.change],
fn=lora_changed,
inputs=[lora_opt],
outputs=[lora_tags],
queue=True,
show_progress=False,
).then(
fn=update_embeddings_json,
inputs=[lora_opt],
outputs=[embeddings_config],
show_progress=False,
)
with gr.Accordion(
label="\U0001F9EA\U0000FE0F Input Image Processing", open=False
):
strength = gr.Slider(
0,
1,
value=cmd_opts.strength,
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=cmd_opts.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="\U00002795\U0000FE0F Prompt",
value=cmd_opts.prompt[0],
lines=2,
elem_id="prompt_box",
show_copy_button=True,
)
negative_prompt = gr.Textbox(
label="\U00002796\U0000FE0F Negative Prompt",
value=cmd_opts.negative_prompt[0],
lines=2,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Row(equal_height=True):
seed = gr.Textbox(
value=cmd_opts.seed,
label="\U0001F331\U0000FE0F Seed",
info="An integer or a JSON list of integers, -1 for random",
show_copy_button=True,
)
scheduler = gr.Dropdown(
elem_id="scheduler",
label="\U0001F4C5\U0000FE0F Scheduler",
info="\U000E0020", # forces same height as seed
value="EulerDiscrete",
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
with gr.Row():
steps = gr.Slider(
1,
100,
value=cmd_opts.steps,
step=1,
label="\U0001F3C3\U0000FE0F Steps",
)
guidance_scale = gr.Slider(
0,
50,
value=cmd_opts.guidance_scale,
step=0.1,
label="\U0001F5C3\U0000FE0F CFG Scale",
)
with gr.Accordion(
label="Controlnet Options",
open=False,
visible=False,
):
preprocessed_hints = gr.State([])
with gr.Column():
sd_cnet_info = (
str(get_checkpoints_path("controlnet"))
).replace("\\", "\n\\")
with gr.Row():
cnet_config = gr.JSON()
with gr.Column():
clear_config = gr.ClearButton(
value="Clear Controlnet Config",
size="sm",
components=cnet_config,
)
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
with gr.Row():
with gr.Column(scale=1):
cnet_model = gr.Dropdown(
allow_custom_value=True,
label=f"Controlnet Model",
info=sd_cnet_info,
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
]
+ get_checkpoints("controlnet"),
)
cnet_strength = gr.Slider(
label="Controlnet Strength",
minimum=0,
maximum=100,
value=50,
step=1,
)
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=8,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=8,
)
make_canvas = gr.Button(
value="Make Canvas!",
)
use_input_img = gr.Button(
value="Use Original Image",
size="sm",
)
cnet_input = gr.Image(
value=None,
type="pil",
image_mode="RGB",
interactive=True,
)
with gr.Column(scale=1):
cnet_output = gr.Image(
value=None,
visible=True,
label="Preprocessed Hint",
interactive=False,
show_label=True,
)
cnet_gen = gr.Button(
value="Preprocess controlnet input",
)
use_result = gr.Button(
"Submit",
size="sm",
)
make_canvas.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[cnet_input],
queue=False,
)
cnet_gen.click(
fn=cnet_preview,
inputs=[
cnet_model,
cnet_input,
],
outputs=[
cnet_output,
preprocessed_hints,
],
)
use_result.click(
fn=submit_to_cnet_config,
inputs=[
cnet_model,
cnet_output,
cnet_strength,
control_mode,
cnet_config,
],
outputs=[
cnet_config,
],
queue=False,
)
with gr.Column(scale=3, min_width=600):
with gr.Tabs() as sd_tabs:
sd_element.load(
# Workaround for Gradio issue #7085
# TODO: revert to setting selected= in gr.Tabs declaration
# once this is resolved in Gradio
lambda: gr.Tabs(selected=101),
outputs=[sd_tabs],
)
with gr.Tab(label="Input Image", id=100) as sd_tab_init_image:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
# TODO: make this import image prompt info if it exists
sd_init_image = gr.Image(
type="pil",
interactive=True,
show_label=False,
)
use_input_img.click(
fn=import_original,
inputs=[
sd_init_image,
canvas_width,
canvas_height,
],
outputs=[cnet_input],
queue=False,
)
with gr.Tab(label="Generate Images", id=101) as sd_tab_gallery:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
sd_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=2,
object_fit="fit",
preview=True,
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=cmd_opts.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=cmd_opts.batch_size,
step=1,
label="Batch Size",
interactive=True,
visible=True,
)
compiled_pipeline = gr.Checkbox(
False,
label="Faster txt2img (SDXL only)",
)
with gr.Row():
stable_diffusion = gr.Button("Start")
unload = gr.Button("Unload Models")
unload.click(
fn=unload_sd,
queue=False,
show_progress=False,
)
stop_batch = gr.Button("Stop")
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_configs(get_configs_path())
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
)
with gr.Row():
with gr.Column(scale=3):
load_sd_config = gr.FileExplorer(
label="Load Config",
file_count="single",
root_dir=(
cmd_opts.configs_path
if cmd_opts.configs_path
else get_configs_path()
),
height=75,
)
with gr.Column(scale=1):
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
clear_sd_config = gr.ClearButton(
value="Clear Config",
size="sm",
components=sd_json,
)
with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
outputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
sd_json,
],
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Tab(label="Log", id=103) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
pull_kwargs = dict(
fn=pull_sd_configs,
inputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
],
outputs=[
sd_json,
],
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=sd_status,
)
gen_kwargs = dict(
fn=shark_sd_fn_dict_input,
inputs=[sd_json],
outputs=[
sd_gallery,
sd_status,
],
)
prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs)
generate_click = (
stable_diffusion.click(**status_kwargs).then(**pull_kwargs).then(**gen_kwargs)
)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -1,43 +0,0 @@
from enum import IntEnum
import math
import sys
import os
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
amdlogo_loc = resource_path("logos/amd-logo.jpg")
amdicon_loc = resource_path("logos/amd-icon.jpg")
class HSLHue(IntEnum):
RED = 0
YELLOW = 60
GREEN = 120
CYAN = 180
BLUE = 240
MAGENTA = 300
def hsl_color(alpha: float, start, end):
b = (end - start) * (alpha if alpha > 0 else 0)
result = b + start
# Return a CSS HSL string
return f"hsl({math.floor(result)}, 80%, 35%)"
def none_to_str_none(props: dict):
for key in props:
props[key] = "None" if props[key] == None else props[key]
return props
def str_none_to_none(props: dict):
for key in props:
props[key] = None if props[key] == "None" else props[key]
return props

View File

@@ -1,12 +0,0 @@
import os
import sys
def get_available_devices():
return ["cpu-task"]
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)

View File

@@ -1,95 +0,0 @@
default_sd_config = r"""{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
sdxl_30steps = r"""{
"prompt": [
"a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 1024,
"width": 1024,
"steps": 30,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-xl-base-1.0",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
sdxl_turbo = r"""{
"prompt": [
"A cat wearing a hat that says 'TURBO' on it. The cat is sitting on a skateboard."
],
"negative_prompt": [
""
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 2,
"strength": 0.8,
"guidance_scale": 0,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerAncestralDiscrete",
"base_model_id": "stabilityai/sdxl-turbo",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
default_sd_configs = {
"default_sd_config.json": default_sd_config,
"sdxl-30steps.json": sdxl_30steps,
"sdxl-turbo.json": sdxl_turbo,
}

View File

@@ -1,102 +0,0 @@
import os
import sys
import glob
from datetime import datetime as dt
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
checkpoints_filetypes = (
"*.ckpt",
"*.safetensors",
)
from apps.shark_studio.web.utils.default_configs import default_sd_configs
def write_default_sd_configs(path):
for key in default_sd_configs.keys():
config_fpath = os.path.join(path, key)
with open(config_fpath, "w") as f:
f.write(default_sd_configs[key])
def safe_name(name):
return name.split("/")[-1].replace("-", "_")
def get_path_stem(path):
path = Path(path)
return path.stem
def get_resource_path(path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
if os.path.isabs(path):
return path
else:
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
result = Path(os.path.join(base_path, path)).resolve(strict=False)
return result
def get_configs_path() -> Path:
configs = get_resource_path(cmd_opts.config_dir)
if not os.path.exists(configs):
os.mkdir(configs)
return Path(configs)
def get_generated_imgs_path() -> Path:
outputs = get_resource_path(cmd_opts.output_dir)
if not os.path.exists(outputs):
os.mkdir(outputs)
return Path(outputs)
def get_tmp_path() -> Path:
tmpdir = get_resource_path(cmd_opts.model_dir)
if not os.path.exists(tmpdir):
os.mkdir(tmpdir)
return Path(tmpdir)
def get_generated_imgs_todays_subdir() -> str:
return dt.now().strftime("%Y%m%d")
def create_model_folders():
dir = ["checkpoints", "vae", "lora", "vmfb"]
if not os.path.isdir(cmd_opts.model_dir):
try:
os.makedirs(cmd_opts.model_dir)
except OSError:
sys.exit(
f"Invalid --model_dir argument, "
f"{cmd_opts.model_dir} folder does not exist, and cannot be created."
)
for root in dir:
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)
def get_checkpoints_path(model_type=""):
return get_resource_path(os.path.join(cmd_opts.model_dir, model_type))
def get_checkpoints(model_type="checkpoints"):
ckpt_files = []
file_types = checkpoints_filetypes
if model_type == "lora":
file_types = file_types + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_checkpoints_path(model_type), extn))
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_checkpoint_pathfile(checkpoint_name, model_type="checkpoints"):
return os.path.join(get_checkpoints_path(model_type), checkpoint_name)

View File

@@ -1,134 +0,0 @@
import gc
from ...api.utils import get_available_devices
"""
The global objects include SD pipeline and config.
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
Also we could avoid memory leak when switching models by clearing the cache.
"""
def _init():
global _sd_obj
global _llm_obj
global _devices
global _pipe_kwargs
global _prep_kwargs
global _gen_kwargs
global _schedulers
_sd_obj = None
_llm_obj = None
_devices = None
_pipe_kwargs = None
_prep_kwargs = None
_gen_kwargs = None
_schedulers = None
set_devices()
def set_sd_obj(value):
global _sd_obj
global _llm_obj
_llm_obj = None
_sd_obj = value
def set_llm_obj(value):
global _sd_obj
global _llm_obj
_llm_obj = value
_sd_obj = None
def set_devices():
global _devices
_devices = get_available_devices()
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
global _sd_obj
_sd_obj.status = value
def set_pipe_kwargs(value):
global _pipe_kwargs
_pipe_kwargs = value
def set_prep_kwargs(value):
global _prep_kwargs
_prep_kwargs = value
def set_gen_kwargs(value):
global _gen_kwargs
_gen_kwargs = value
def set_schedulers(value):
global _schedulers
_schedulers = value
def get_sd_obj():
global _sd_obj
return _sd_obj
def get_llm_obj():
global _llm_obj
return _llm_obj
def get_device_list():
global _devices
return _devices
def get_sd_status():
global _sd_obj
return _sd_obj.status
def get_pipe_kwargs():
global _pipe_kwargs
return _pipe_kwargs
def get_prep_kwargs():
global _prep_kwargs
return _prep_kwargs
def get_gen_kwargs():
global _gen_kwargs
return _gen_kwargs
def get_scheduler(key):
global _schedulers
return _schedulers[key]
def clear_cache():
global _sd_obj
global _llm_obj
global _pipe_kwargs
global _prep_kwargs
global _gen_kwargs
global _schedulers
del _sd_obj
del _llm_obj
del _schedulers
gc.collect()
_sd_obj = None
_llm_obj = None
_pipe_kwargs = None
_prep_kwargs = None
_gen_kwargs = None
_schedulers = None

View File

@@ -1,6 +0,0 @@
from .png_metadata import (
import_png_metadata,
)
from .display import (
displayable_metadata,
)

View File

@@ -1,43 +0,0 @@
import csv
import os
from .format import humanize, humanizable
def csv_path(image_filename: str):
return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
def has_csv(image_filename: str) -> bool:
return os.path.exists(csv_path(image_filename))
def matching_filename(image_filename: str, row):
# we assume the final column of the csv has the original filename with full path and match that
# against the image_filename if we are given a list. Otherwise we assume a dict and and take
# the value of the OUTPUT key
return os.path.basename(image_filename) in (
row[-1] if isinstance(row, list) else row["OUTPUT"]
)
def parse_csv(image_filename: str):
csv_filename = csv_path(image_filename)
with open(csv_filename, "r", newline="") as csv_file:
# We use a reader or DictReader here for images_details.csv depending on whether we think it
# has headers or not. Having headers means less guessing of the format.
has_header = csv.Sniffer().has_header(csv_file.read(2048))
csv_file.seek(0)
reader = csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
matches = [
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
humanize(row)
for row in reader
if row
and (has_header or humanizable(row))
and matching_filename(image_filename, row)
]
return matches[0] if matches else {}

View File

@@ -1,53 +0,0 @@
import json
import os
from PIL import Image
from .png_metadata import parse_generation_parameters
from .exif_metadata import has_exif, parse_exif
from .csv_metadata import has_csv, parse_csv
from .format import compact, humanize
def displayable_metadata(image_filename: str) -> dict:
if not os.path.isfile(image_filename):
return {"source": "missing", "parameters": {}}
pil_image = Image.open(image_filename)
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
# and we go via that for SendTo, and is directly tied to the image)
if "parameters" in pil_image.info:
return {
"source": "png",
"parameters": compact(
parse_generation_parameters(pil_image.info["parameters"])
),
}
# we have a matching json file (next most likely to be accurate when it's there)
json_path = os.path.splitext(image_filename)[0] + ".json"
if os.path.isfile(json_path):
with open(json_path) as params_file:
return {
"source": "json",
"parameters": compact(
humanize(json.load(params_file), includes_filename=False)
),
}
# we have a CSV file so try that (can be different shapes, and it usually has no
# headers/param names so of the things we we *know* have parameters, it's the
# last resort)
if has_csv(image_filename):
params = parse_csv(image_filename)
if params: # we might not have found the filename in the csv
return {
"source": "csv",
"parameters": compact(params), # already humanized
}
# EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
if has_exif(image_filename):
return {"source": "exif", "parameters": parse_exif(pil_image)}
# we've got nothing
return None

View File

@@ -1,52 +0,0 @@
from PIL import Image
from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
def has_exif(image_filename: str) -> bool:
return True if Image.open(image_filename).getexif() else False
def parse_exif(pil_image: Image) -> dict:
img_exif = pil_image.getexif()
# See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594
# I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I
# I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a
# dependency
exif_tags = {
TAGS.get(key, key): str(val)
for (key, val) in img_exif.items()
if key in TAGS
and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
def try_get_ifd(ifd_id):
try:
return img_exif.get_ifd(ifd_id).items()
except KeyError:
return {}
ifd_tags = {
TAGS.get(key, key): str(val)
for ifd_id in IFD
for (key, val) in try_get_ifd(ifd_id)
if ifd_id != IFD.GPSInfo
and key in TAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
gps_tags = {
GPSTAGS.get(key, key): str(val)
for (key, val) in try_get_ifd(IFD.GPSInfo)
if key in GPSTAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
return {**exif_tags, **ifd_tags, **gps_tags}

View File

@@ -1,139 +0,0 @@
# As SHARK has evolved more columns have been added to images_details.csv. However, since
# no version of the CSV has any headers (yet) we don't actually have anything within the
# file that tells us which parameter each column is for. So this is a list of known patterns
# indexed by length which is what we're going to have to use to guess which columns are the
# right ones for the file we're looking at.
# The same ordering is used for JSON, but these do have key names, however they are not very
# human friendly, nor do they match up with the what is written to the .png headers
# So these are functions to try and get something consistent out the raw input from all
# these sources
PARAMS_FORMATS = {
9: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
10: {
"MODEL": "Model",
"VARIANT": "Variant",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
12: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
},
}
PARAMS_FORMAT_CURRENT = {
"VARIANT": "Model",
"VAE": "VAE",
"LORA": "LoRA",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
}
def compact(metadata: dict) -> dict:
# we don't want to alter the original dictionary
result = dict(metadata)
# discard the filename because we should already have it
if result.keys() & {"Filename"}:
result.pop("Filename")
# make showing the sizes more compact by using only one line each
if result.keys() & {"Size-1", "Size-2"}:
result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
elif result.keys() & {"Height", "Width"}:
result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
if result.keys() & {"Hires resize-1", "Hires resize-1"}:
hires_y = result.pop("Hires resize-1")
hires_x = result.pop("Hires resize-2")
if hires_x == 0 and hires_y == 0:
result["Hires resize"] = "None"
else:
result["Hires resize"] = f"{hires_y}x{hires_x}"
# remove VAE if it exists and is empty
if (result.keys() & {"VAE"}) and (not result["VAE"] or result["VAE"] == "None"):
result.pop("VAE")
# remove LoRA if it exists and is empty
if (result.keys() & {"LoRA"}) and (not result["LoRA"] or result["LoRA"] == "None"):
result.pop("LoRA")
return result
def humanizable(metadata: dict | list[str], includes_filename=True) -> dict:
lookup_key = len(metadata) + (0 if includes_filename else 1)
return lookup_key in PARAMS_FORMATS.keys()
def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
lookup_key = len(metadata) + (0 if includes_filename else 1)
# For lists we can only work based on the length, we have no other information
if isinstance(metadata, list):
if humanizable(metadata, includes_filename):
return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
else:
raise KeyError(
f"Humanize could not find the format for a parameter list of length {len(metadata)}"
)
# For dictionaries we try to use the matching length parameter format if
# available, otherwise we just use the current format which is assumed to
# have everything currently known about. Then we swap keys in the metadata
# that match keys in the format for the friendlier name that we have set
# in the format value
if isinstance(metadata, dict):
if humanizable(metadata, includes_filename):
format = PARAMS_FORMATS[lookup_key]
else:
format = PARAMS_FORMAT_CURRENT
return {
format[key]: metadata[key]
for key in format.keys()
if key in metadata.keys() and metadata[key]
}
raise TypeError("Can only humanize parameter lists or dictionaries")

View File

@@ -1,216 +0,0 @@
import re
from pathlib import Path
from apps.shark_studio.web.utils.file_utils import (
get_checkpoint_pathfile,
)
from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map
from apps.shark_studio.modules.schedulers import (
scheduler_model_map,
)
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
def parse_generation_parameters(x: str):
res = {}
prompt = ""
negative_prompt = ""
done_with_prompt = False
*lines, lastline = x.strip().split("\n")
if len(re_param.findall(lastline)) < 3:
lines.append(lastline)
lastline = ""
for i, line in enumerate(lines):
line = line.strip()
if line.startswith("Negative prompt:"):
done_with_prompt = True
line = line[16:].strip()
if done_with_prompt:
negative_prompt += ("" if negative_prompt == "" else "\n") + line
else:
prompt += ("" if prompt == "" else "\n") + line
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v)
if m is not None:
res[k + "-1"] = m.group(1)
res[k + "-2"] = m.group(2)
else:
res[k] = v
# Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res:
res["Clip skip"] = "1"
hypernet = res.get("Hypernet", None)
if hypernet is not None:
res[
"Prompt"
] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
res["Hires resize-2"] = 0
return res
def try_find_model_base_from_png_metadata(file: str, folder: str = "models") -> str:
custom = ""
# Remove extension from file info
if file.endswith(".safetensors") or file.endswith(".ckpt"):
file = Path(file).stem
# Check for the file name match with one of the local ckpt or safetensors files
if Path(get_checkpoint_pathfile(file + ".ckpt", folder)).is_file():
custom = file + ".ckpt"
if Path(get_checkpoint_pathfile(file + ".safetensors", folder)).is_file():
custom = file + ".safetensors"
return custom
def find_model_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
png_hf_id = ""
png_custom = ""
if key in metadata:
model_file = metadata[key]
png_custom = try_find_model_base_from_png_metadata(model_file)
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if model_file in sd_model_map:
png_custom = model_file
# If nothing had matched, check vendor/hf_model_id
if not png_custom and model_file.count("/"):
png_hf_id = model_file
# No matching model was found
if not png_custom and not png_hf_id:
print(
"Import PNG info: Unable to find a matching model for %s" % model_file
)
return png_custom, png_hf_id
def find_vae_from_png_metadata(key: str, metadata: dict[str, str | int]) -> str:
vae_custom = ""
if key in metadata:
vae_file = metadata[key]
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
# VAE input is optional, should not print or throw an error if missing
return vae_custom
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
lora_custom = ""
if key in metadata:
lora_file = metadata[key]
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
def import_png_metadata(
pil_data,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
cfg_scale = float(metadata["CFG scale"])
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
if "Prompt" in metadata:
prompt = metadata["Prompt"]
if "Sampler" in metadata:
if metadata["Sampler"] in scheduler_model_map:
sampler = metadata["Sampler"]
else:
print(
"Import PNG info: Unable to find a scheduler for %s"
% metadata["Sampler"]
)
except Exception as ex:
if pil_data and pil_data.info.get("parameters"):
print("import_png_metadata failed with %s" % ex)
pass
return (
None,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
)

View File

@@ -1,39 +0,0 @@
import apps.shark_studio.web.utils.globals as global_obj
import gc
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
if batch_index < batch_count:
bs = f"x{batch_size}" if batch_size > 1 else ""
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
else:
return f"{tab_name} complete"
def get_generation_text_info(seeds, device):
cfg_dump = {}
for cfg in global_obj.get_config_dict():
cfg_dump[cfg] = cfg
text_output = f"prompt={cfg_dump['prompts']}"
text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}"
text_output += (
f"\nmodel_id={cfg_dump['hf_model_id']}, " f"ckpt_loc={cfg_dump['ckpt_loc']}"
)
text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}"
text_output += (
f"\nsteps={cfg_dump['steps']}, "
f"guidance_scale={cfg_dump['guidance_scale']}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={cfg_dump['height']}x{cfg_dump['width']}, "
if not cfg_dump.use_hiresfix
else f"\nsize={cfg_dump['hiresfix_height']}x{cfg_dump['hiresfix_width']}, "
)
text_output += (
f"batch_count={cfg_dump['batch_count']}, "
f"batch_size={cfg_dump['batch_size']}, "
f"max_length={cfg_dump['max_length']}"
)
return text_output

View File

@@ -1,75 +0,0 @@
import os
import shutil
from time import time
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
shark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "shark_tmp/")
def clear_tmp_mlir():
cleanup_start = time()
print("Clearing .mlir temporary files from a prior run. This may take some time...")
mlir_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.endswith(".mlir")
]
for filename in mlir_files:
os.remove(os.path.join(shark_tmp, filename))
print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.")
def clear_tmp_imgs():
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
print(
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
)
# Clear all gradio tmp images from the last session
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
cleanup_start = time()
print(
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
)
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
print(
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
)
# older SHARK versions had to workaround gradio bugs and stored things differently
else:
image_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
print(
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
)
cleanup_start = time()
for filename in image_files:
os.remove(shark_tmp + filename)
print(
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
)
else:
print("No temporary images files to clear.")
def config_tmp():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
clear_tmp_mlir()
clear_tmp_imgs()

View File

@@ -0,0 +1,87 @@
Compile / Run Instructions:
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=mhlo for tf models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
```
Run / Benchmark Command (FP32 - NCHW):
(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.)
```shell
## Vulkan AMD:
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CUDA:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CPU:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
Run via vulkan_gui for RGP Profiling:
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
```shell
./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
</details>
<details>
<summary>Debug Commands</summary>
## Debug commands and other advanced usage follows.
```shell
python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
```
## dump all dispatch .spv and isa using amdllpc
```shell
python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
```
## Compile and save the .vmfb (using vulkan fp16 as an example):
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
```
## Capture an RGP trace
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
```
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
```shell
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
```
## Run the unet module with iree-benchmark-module (same config as above):
```shell
##if you want to use .npz inputs:
unzip ~/.local/shark_tank/<your unet>/inputs.npz
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy
```
</details>

View File

@@ -0,0 +1,4 @@
from apps.stable_diffusion.scripts.inpaint import inpaint_inf
from apps.stable_diffusion.scripts.outpaint import outpaint_inf
from apps.stable_diffusion.scripts.upscaler import upscaler_inf
from apps.stable_diffusion.scripts.train_lora_word import lora_train

View File

@@ -0,0 +1,126 @@
import sys
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
StencilPipeline,
resize_stencil,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
def main():
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
image = Image.open(args.img_path).convert("RGB")
# When the models get uploaded, it should be default to False.
args.import_mlir = True
use_stencil = args.use_stencil
if use_stencil:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, args.width, args.height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model
if use_stencil:
img2img_obj = StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
start_time = time.time()
generated_imgs = img2img_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.batch_size,
args.height,
args.width,
args.steps,
args.strength,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, strength={args.strength}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += img2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
extra_info = {"STRENGTH": args.strength}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,287 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def inpaint_inf(
prompt: str,
negative_prompt: str,
image_dict,
height: int,
width: int,
inpaint_full_res: bool,
inpaint_full_res_padding: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.mask_path = "not none"
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.custom_vae = custom_vae
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"inpaint",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
def main():
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
if args.mask_path is None:
print("Flag --mask_path is required.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
model_id = (
args.hf_model_id
if "inpaint" in args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
image = Image.open(args.img_path)
mask_image = Image.open(args.mask_path)
inpaint_obj = InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = inpaint_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
mask_image,
args.batch_size,
args.height,
args.width,
args.inpaint_full_res,
args.inpaint_full_res_padding,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += inpaint_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,19 @@
from apps.stable_diffusion.src import args
from apps.stable_diffusion.scripts import (
img2img,
txt2img,
# inpaint,
# outpaint,
)
if __name__ == "__main__":
if args.app == "txt2img":
txt2img.main()
elif args.app == "img2img":
img2img.main()
# elif args.app == "inpaint":
# inpaint.main()
# elif args.app == "outpaint":
# outpaint.main()
else:
print(f"args.app value is {args.app} but this isn't supported")

View File

@@ -0,0 +1,312 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def outpaint_inf(
prompt: str,
negative_prompt: str,
init_image,
pixels: int,
mask_blur: int,
directions: list,
noise_q: float,
color_variation: float,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.custom_vae = custom_vae
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"outpaint",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
left = True if "left" in directions else False
right = True if "right" in directions else False
top = True if "up" in directions else False
bottom = True if "down" in directions else False
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
init_image,
pixels,
mask_blur,
left,
right,
top,
bottom,
noise_q,
color_variation,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
def main():
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
model_id = (
args.hf_model_id
if "inpaint" in args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
image = Image.open(args.img_path)
outpaint_obj = OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = outpaint_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.pixels,
args.mask_blur,
args.left,
args.right,
args.top,
args.bottom,
args.noise_q,
args.color_variation,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += outpaint_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
# save this information as metadata of output generated image.
directions = []
if args.left:
directions.append("left")
if args.right:
directions.append("right")
if args.top:
directions.append("up")
if args.bottom:
directions.append("down")
extra_info = {
"PIXELS": args.pixels,
"MASK_BLUR": args.mask_blur,
"DIRECTIONS": directions,
"NOISE_Q": args.noise_q,
"COLOR_VARIATION": args.color_variation,
}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,240 @@
import logging
import os
from models.stable_diffusion.main import stable_diff_inf
from models.stable_diffusion.utils import get_available_devices
from dotenv import load_dotenv
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram import BotCommand
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
from io import BytesIO
import random
log = logging.getLogger("TG.Bot")
logging.basicConfig()
log.warning("Start")
load_dotenv()
os.environ["AMD_ENABLE_LLPC"] = "0"
TG_TOKEN = os.getenv("TG_TOKEN")
SELECTED_MODEL = "stablediffusion"
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
STEPS = 30
NEGATIVE_PROMPT = (
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
" legs, fused fingers, too many fingers"
)
GUIDANCE_SCALE = 6
available_devices = get_available_devices()
models_list = [
"stablediffusion",
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]
sheds_list = [
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
]
def image_to_bytes(image):
bio = BytesIO()
bio.name = "image.jpeg"
image.save(bio, "JPEG")
bio.seek(0)
return bio
def get_try_again_markup():
keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]]
reply_markup = InlineKeyboardMarkup(keyboard)
return reply_markup
def generate_image(prompt):
seed = random.randint(1, 10000)
log.warning(SELECTED_MODEL)
log.warning(STEPS)
image, text = stable_diff_inf(
prompt=prompt,
negative_prompt=NEGATIVE_PROMPT,
steps=STEPS,
guidance_scale=GUIDANCE_SCALE,
seed=seed,
scheduler_key=SELECTED_SCHEDULER,
variant=SELECTED_MODEL,
device_key=available_devices[0],
)
return image, seed
async def generate_and_send_photo(
update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
progress_msg = await update.message.reply_text(
"Generating image...", reply_to_message_id=update.message.message_id
)
im, seed = generate_image(prompt=update.message.text)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{update.message.text}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=update.message.message_id,
)
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
query = update.callback_query
if query.data in models_list:
global SELECTED_MODEL
SELECTED_MODEL = query.data
await query.answer()
await query.edit_message_text(text=f"Selected model: {query.data}")
return
if query.data in sheds_list:
global SELECTED_SCHEDULER
SELECTED_SCHEDULER = query.data
await query.answer()
await query.edit_message_text(text=f"Selected scheduler: {query.data}")
return
replied_message = query.message.reply_to_message
await query.answer()
progress_msg = await query.message.reply_text(
"Generating image...", reply_to_message_id=replied_message.message_id
)
if query.data == "TRYAGAIN":
prompt = replied_message.text
im, seed = generate_image(prompt)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{prompt}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=replied_message.message_id,
)
async def select_model_handler(update, context):
text = "Select model"
keyboard = []
for model in models_list:
keyboard.append(
[
InlineKeyboardButton(text=model, callback_data=model),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def select_scheduler_handler(update, context):
text = "Select schedule"
keyboard = []
for shed in sheds_list:
keyboard.append(
[
InlineKeyboardButton(text=shed, callback_data=shed),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def set_steps_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_steps ")[1]
global STEPS
STEPS = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_steps 30"
)
await update.message.reply_text(input_args)
async def set_negative_prompt_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_negative_prompt ")[1]
global NEGATIVE_PROMPT
NEGATIVE_PROMPT = input_args
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_negative_prompt ugly, bad art, mutated"
)
await update.message.reply_text(input_args)
async def set_guidance_scale_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_guidance_scale ")[1]
global GUIDANCE_SCALE
GUIDANCE_SCALE = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_guidance_scale 7"
)
await update.message.reply_text(input_args)
async def setup_bot_commands(application: Application) -> None:
await application.bot.set_my_commands(
[
BotCommand("select_model", "to select model"),
BotCommand("select_scheduler", "to select scheduler"),
BotCommand("set_steps", "to set steps"),
BotCommand("set_guidance_scale", "to set guidance scale"),
BotCommand("set_negative_prompt", "to set negative prompt"),
]
)
app = (
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
)
app.add_handler(CommandHandler("select_model", select_model_handler))
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
app.add_handler(CommandHandler("set_steps", set_steps_handler))
app.add_handler(
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
)
app.add_handler(
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
)
app.add_handler(
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
)
app.add_handler(CallbackQueryHandler(button))
log.warning("Start bot")
app.run_polling()

View File

@@ -0,0 +1,692 @@
# Install the required libs
# pip install -U git+https://github.com/huggingface/diffusers.git
# pip install accelerate transformers ftfy
# HuggingFace Token
# YOUR_TOKEN = "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"
# Import required libraries
import itertools
import math
import os
from typing import List
import random
import torch_mlir
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
import logging
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from shark.shark_inference import SharkInference
torch._dynamo.config.verbose = True
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
)
from io import BytesIO
from dataclasses import dataclass
from apps.stable_diffusion.src import (
args,
get_schedulers,
set_init_device_flags,
clear_all,
)
from apps.stable_diffusion.src.utils import update_lora_weight
# Setup the dataset
class LoraDataset(Dataset):
def __init__(
self,
data_root,
tokenizer,
size=512,
repeats=100,
interpolation="bicubic",
set="train",
prompt="myloraprompt",
center_crop=False,
):
self.data_root = data_root
self.tokenizer = tokenizer
self.size = size
self.center_crop = center_crop
self.prompt = prompt
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
self.num_images = len(self.image_paths)
self._length = self.num_images
if set == "train":
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
example["input_ids"] = self.tokenizer(
self.prompt,
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids[0]
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = np.array(image).astype(np.uint8)
image = (image / 127.5 - 1.0).astype(np.float32)
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
return example
def torch_device(device):
device_tokens = device.split("=>")
if len(device_tokens) == 1:
device_str = device_tokens[0].strip()
else:
device_str = device_tokens[1].strip()
device_type_tokens = device_str.split("://")
if device_type_tokens[0] == "metal":
device_type_tokens[0] = "vulkan"
if len(device_type_tokens) > 1:
return device_type_tokens[0] + ":" + device_type_tokens[1]
else:
return device_type_tokens[0]
########## Setting up the model ##########
def lora_train(
prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
training_images_dir: str,
lora_save_dir: str,
use_lora: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
print(
"Note LoRA training is not compatible with the latest torch-mlir branch"
)
print(
"To run LoRA training you'll need this to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
)
torch.manual_seed(seed)
args.prompts = [prompt]
args.steps = steps
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = custom_model
else:
args.hf_model_id = custom_model
args.training_images_dir = training_images_dir
args.lora_save_dir = lora_save_dir
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = torch_device(device)
args.use_lora = use_lora
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(
args.hf_model_id, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(
args.hf_model_id, subfolder="unet"
)
def freeze_params(params):
for param in params:
param.requires_grad = False
# Freeze everything but LoRA
freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
# Move vae and unet to device
vae.to(args.device)
unet.to(args.device)
text_encoder.to(args.device)
if use_lora != "":
update_lora_weight(unet, args.use_lora, "unet")
else:
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[
block_id
]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = vae
def forward(self, input):
x = self.vae.encode(input, return_dict=False)[0]
return x
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = unet
def forward(self, x, y, z):
return self.unet.forward(x, y, z, return_dict=False)[0]
shark_vae = VaeModel()
shark_unet = UnetModel()
####### Creating our training data ########
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id,
subfolder="tokenizer",
)
# Let's create the Dataset and Dataloader
train_dataset = LoraDataset(
data_root=args.training_images_dir,
tokenizer=tokenizer,
size=vae.sample_size,
prompt=args.prompts[0],
repeats=100,
center_crop=False,
set="train",
)
def create_dataloader(train_batch_size=1):
return torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, shuffle=True
)
# Create noise_scheduler for training
noise_scheduler = DDPMScheduler.from_config(
args.hf_model_id, subfolder="scheduler"
)
######## Training ###########
# Define hyperparameters for our training. If you are not happy with your results,
# you can tune the `learning_rate` and the `max_train_steps`
# Setting up all training args
hyperparameters = {
"learning_rate": 5e-04,
"scale_lr": True,
"max_train_steps": steps,
"train_batch_size": batch_size,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"mixed_precision": "fp16",
"seed": 42,
"output_dir": "sd-concept-output",
}
# creating output directory
cwd = os.getcwd()
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
while not os.path.exists(str(out_dir)):
try:
os.mkdir(out_dir)
except OSError as error:
print("Output directory not created")
###### Torch-MLIR Compilation ######
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
@make_simple_dynamo_backend
def refbackend_torchdynamo_backend(
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
):
# handling usage of empty tensor without initializing
transform_fx(fx_graph)
fx_graph.recompile()
if _returns_nothing(fx_graph):
return fx_graph
removed_none_indexes = _remove_nones(fx_graph)
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
mlir_module = torch_mlir.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors"
)
bytecode_stream = BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
def compiled_callable(*inputs):
inputs = [x.numpy() for x in inputs]
result = shark_module("forward", inputs)
if was_unwrapped:
result = [
result,
]
if not isinstance(result, list):
result = torch.from_numpy(result)
else:
result = tuple(torch.from_numpy(x) for x in result)
result = list(result)
for removed_index in removed_none_indexes:
result.insert(removed_index, None)
result = tuple(result)
return result
return compiled_callable
def predictions(torch_func, jit_func, batchA, batchB):
res = jit_func(batchA.numpy(), batchB.numpy())
if res is not None:
# prediction = torch.from_numpy(res)
prediction = res
else:
prediction = None
return prediction
logger = logging.getLogger(__name__)
train_batch_size = hyperparameters["train_batch_size"]
gradient_accumulation_steps = hyperparameters[
"gradient_accumulation_steps"
]
learning_rate = hyperparameters["learning_rate"]
if hyperparameters["scale_lr"]:
learning_rate = (
learning_rate
* gradient_accumulation_steps
* train_batch_size
# * accelerator.num_processes
)
# Initialize the optimizer
optimizer = torch.optim.AdamW(
lora_layers.parameters(), # only optimize the embeddings
lr=learning_rate,
)
# Training function
def train_func(batch_pixel_values, batch_input_ids):
# Convert images to latent space
latents = shark_vae(batch_pixel_values).sample().detach()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.num_train_timesteps,
(bsz,),
device=latents.device,
).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch_input_ids)[0]
# Predict the noise residual
noise_pred = shark_unet(
noisy_latents,
timesteps,
encoder_hidden_states,
)
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
)
loss = (
F.mse_loss(noise_pred, target, reduction="none")
.mean([1, 2, 3])
.mean()
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return loss
def training_function():
max_train_steps = hyperparameters["max_train_steps"]
output_dir = hyperparameters["output_dir"]
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
train_dataloader = create_dataloader(train_batch_size)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
num_train_epochs = math.ceil(
max_train_steps / num_update_steps_per_epoch
)
# Train!
total_batch_size = (
train_batch_size
* gradient_accumulation_steps
# train_batch_size * accelerator.num_processes * gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(
f" Instantaneous batch size per device = {train_batch_size}"
)
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(
# range(max_train_steps), disable=not accelerator.is_local_main_process
range(max_train_steps)
)
progress_bar.set_description("Steps")
global_step = 0
params__ = [
i for i in text_encoder.get_input_embeddings().parameters()
]
for epoch in range(num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
dynamo_callable = dynamo.optimize(
refbackend_torchdynamo_backend
)(train_func)
lam_func = lambda x, y: dynamo_callable(
torch.from_numpy(x), torch.from_numpy(y)
)
loss = predictions(
train_func,
lam_func,
batch["pixel_values"],
batch["input_ids"],
)
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
if global_step >= max_train_steps:
break
training_function()
# Save the lora weights
unet.save_attn_procs(args.lora_save_dir)
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
if param.grad is not None:
del param.grad # free some memory
torch.cuda.empty_cache()
if __name__ == "__main__":
if args.clear_all:
clear_all()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
if len(args.prompts) != 1:
print("Need exactly one prompt for the LoRA word")
lora_train(
args.prompts[0],
args.height,
args.width,
args.training_steps,
args.guidance_scale,
args.seed,
args.batch_count,
args.batch_size,
args.scheduler,
"None",
args.hf_model_id,
args.precision,
args.device,
args.max_length,
args.training_images_dir,
args.lora_save_dir,
args.use_lora,
)

View File

@@ -0,0 +1,114 @@
import os
from pathlib import Path
from shark_tuner.codegen_tuner import SharkCodegenTuner
from shark_tuner.iree_utils import (
dump_dispatches,
create_context,
export_module_to_mlir_file,
)
from shark_tuner.model_annotation import model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.utils import set_init_device_flags
from apps.stable_diffusion.src.utils.sd_annotation import (
get_device_args,
load_winograd_configs,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
max_len=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,
)
if args.annotation_model == "unet":
mlir_module = sd_model.unet()
model_name = sd_model.model_name["unet"]
elif args.annotation_model == "vae":
mlir_module = sd_model.vae()
model_name = sd_model.model_name["vae"]
else:
raise ValueError(
f"{args.annotation_model} is not supported for tuning."
)
return mlir_module, model_name
def main():
args.use_tuned = False
set_init_device_flags()
mlir_module, model_name = load_mlir_module()
# Get device and device specific arguments
device, device_spec_args = get_device_args()
device_spec = ""
if device_spec_args:
device_spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
device_spec = device_spec.split("-")[0]
# Add winograd annotation for vulkan device
winograd_config = (
load_winograd_configs()
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
else ""
)
with create_context() as ctx:
input_module = model_annotation(
ctx,
input_contents=mlir_module,
config_path=winograd_config,
search_op="conv",
winograd=True,
)
# Dump model dispatches
if device == "vulkan" and device_spec == "rdna3":
device = "vulkan/RX 7900"
generates_dir = Path.home() / "tmp"
if not os.path.exists(generates_dir):
os.makedirs(generates_dir)
dump_mlir = generates_dir / "temp.mlir"
dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches"
export_module_to_mlir_file(input_module, dump_mlir)
dump_dispatches(dump_mlir, device, dispatch_dir, False)
# Tune each dispatch
dtype = "f16" if args.precision == "fp16" else "f32"
config_filename = f"{model_name}_{device_spec}_configs.json"
for f_path in os.listdir(dispatch_dir):
if not f_path.endswith(".mlir"):
continue
model_dir = os.path.join(dispatch_dir, f_path)
tuner = SharkCodegenTuner(
model_dir,
device,
"random",
args.num_iters,
args.tuned_config_dir,
dtype,
args.search_op,
batch_size=1,
config_filename=config_filename,
use_dispatch=True,
)
tuner.tune()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,85 @@
import torch
import transformers
import time
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
if args.clear_all:
clear_all()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,280 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def upscaler_inf(
prompt: str,
negative_prompt: str,
init_image,
height: int,
width: int,
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB").resize((height, width))
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.custom_vae = custom_vae
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
args.height = 128
args.width = 128
new_config_obj = Config(
"upscaler",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
args.height,
args.width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_size = batch_size
args.max_length = max_length
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
"DDPM"
)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"NOISE LEVEL": noise_level}
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
low_res_img = image
high_res_img = Image.new("RGB", (height * 4, width * 4))
for i in range(0, width, 128):
for j in range(0, height, 128):
box = (j, i, j + 128, i + 128)
upscaled_image = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
low_res_img.crop(box),
batch_size,
args.height,
args.width,
steps,
noise_level,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
save_output_img(high_res_img, img_seed, extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
# When the models get uploaded, it should be default to False.
args.import_mlir = True
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
image = (
Image.open(args.img_path)
.convert("RGB")
.resize((args.height, args.width))
)
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model
upscaler_obj = UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ddpm_scheduler=schedulers["DDPM"],
ondemand=args.ondemand,
)
start_time = time.time()
generated_imgs = upscaler_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.batch_size,
args.height,
args.width,
args.steps,
args.noise_level,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, noise_level={args.noise_level}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += upscaler_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
extra_info = {"NOISE LEVEL": args.noise_level}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)

View File

@@ -0,0 +1,85 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
( 'web/ui/css/*', 'ui/css' ),
( 'web/ui/logos/*', 'logos' )
]
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
a = Analysis(
['web/index.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -0,0 +1,83 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
]
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
a = Analysis(
['scripts/main.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.utils import (
args,
set_init_device_flags,
prompt_examples,
get_available_devices,
clear_all,
save_output_img,
resize_stencil,
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,
StencilPipeline,
UpscalerPipeline,
)
from apps.stable_diffusion.src.schedulers import get_schedulers

View File

@@ -0,0 +1,12 @@
from apps.stable_diffusion.src.models.model_wrappers import (
SharkifyStableDiffusionModel,
)
from apps.stable_diffusion.src.models.opt_params import (
get_vae_encode,
get_vae,
get_unet,
get_clip,
get_tokenizer,
get_params,
get_variant_version,
)

View File

@@ -0,0 +1,712 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from collections import defaultdict
import torch
import safetensors.torch
import traceback
import sys
import os
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
base_models,
args,
preprocessCKPT,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
get_stencil_model_id,
update_lora_weight,
)
# These shapes are parameter dependent.
def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape = []
for i in range(len(shape)):
if shape[i] == "max_len":
new_shape.append(max_len)
elif shape[i] == "height":
new_shape.append(height)
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
if "*" in shape[i]:
mul_val = int(shape[i].split("*")[0])
if "batch_size" in shape[i]:
new_shape.append(batch_size * mul_val)
elif "height" in shape[i]:
new_shape.append(height * mul_val)
elif "width" in shape[i]:
new_shape.append(width * mul_val)
elif "/" in shape[i]:
import math
div_val = int(shape[i].split("/")[1])
if "batch_size" in shape[i]:
new_shape.append(math.ceil(batch_size / div_val))
elif "height" in shape[i]:
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
else:
new_shape.append(shape[i])
return new_shape
def check_compilation(model, model_name):
if not model:
raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
custom_vae: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
low_cpu_mem_usage: bool = False,
debug: bool = False,
sharktank_dir: str = "",
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
self.model_id = "stabilityai/stable-diffusion-2-1-base"
self.custom_vae = custom_vae
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
"_"
+ str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
print(f'use_tuned? sharkify: {use_tuned}')
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
print(self.model_name)
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
self.sharktank_dir = sharktank_dir
self.generate_vmfb = generate_vmfb
self.inputs = dict()
self.model_to_run = ""
if self.custom_weights != "":
self.model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
else:
self.model_to_run = args.hf_model_id
self.custom_vae = self.process_custom_vae()
self.base_model_id = fetch_and_update_base_model_id(self.model_to_run)
if self.base_model_id != "" and args.ckpt_loc != "":
args.hf_model_id = self.base_model_id
self.return_mlir = return_mlir
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
index = 0
for model in sub_model_list:
sub_model = model
model_config = self.model_name
if "vae" == model:
if self.custom_vae != "":
model_config = model_config + get_path_stem(self.custom_vae)
if self.base_vae:
sub_model = "base_vae"
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 128):
sys.exit("width should be greater than 128 and multiple of 8")
if not (height % 8 == 0 and height >= 128):
sys.exit("height should be greater than 128 and multiple of 8")
# Get the input info for a model i.e. "unet", "clip", "vae", etc.
def get_input_info_for(self, model_info):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = []
for inp in model_info:
shape = model_info[inp]["shape"]
dtype = dtype_config[model_info[inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, self.max_len, self.width, self.height, self.batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
latents = self.vae.encode(input).latent_dist.sample()
return 0.18215 * latents
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
shark_vae_encode, vae_encode_mlir = compile_through_fx(
vae_encode,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae_encode"],
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae_encode",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae_encode, vae_encode_mlir
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
super().__init__()
self.vae = None
if custom_vae == "":
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae
def forward(self, input):
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0
return x.round()
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae, vae_mlir
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
control2, control3, control4, control5, control6, control7,
control8, control9, control10, control11, control12, control13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
mb_res_samples = control13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
down_block_additional_residuals=db_res_samples,
mid_block_additional_residual=mb_res_samples,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = ControlledUnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["stencil_unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.cnet.in_channels
self.train(False)
def forward(
self,
latent,
timestep,
text_embedding,
stencil_image_input,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
# TODO: guidance NOT NEEDED change in `get_input_info` later
latents = torch.cat(
[latent] * 2
) # needs to be same as controlledUNET latents
stencil_image = torch.cat(
[stencil_image_input] * 2
) # needs to be same as controlledUNET latents
down_block_res_samples, mid_block_res_sample = self.cnet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
controlnet_cond=stencil_image,
return_dict=False,
)
return tuple(list(down_block_res_samples) + [mid_block_res_sample])
scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
input_mask = [True, True, True, True]
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name["stencil_adaptor"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_adaptor",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_cnet, cnet_mlir
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
if(args.attention_slicing.isdigit()):
self.unet.set_attention_slice(int(args.attention_slicing))
else:
self.unet.set_attention_slice(args.attention_slicing)
# TODO: Instead of flattening the `control` try to use the list.
def forward(
self, latent, timestep, text_embedding, guidance_scale,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_unet_upscaler(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, latent, timestep, text_embedding, noise_level):
unet_out = self.unet.forward(
latent,
timestep,
text_embedding,
noise_level,
return_dict=False,
)[0]
return unet_out
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
extended_model_name=self.model_name["clip"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip, clip_mlir
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
return self.custom_vae
try:
preprocessCKPT(self.custom_vae)
return get_path_to_diffusers_checkpoint(self.custom_vae)
except:
print("Processing standalone Vae checkpoint")
vae_checkpoint = None
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
if custom_vae.endswith(".ckpt"):
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
else:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
def compile_unet_variants(self, model):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
return get_unet()
else:
return self.get_unet()
else:
return self.get_controlled_unet()
def vae_encode(self):
try:
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
if self.return_mlir:
return vae_encode_mlir
return compiled_vae_encode
except Exception as e:
sys.exit(e)
def clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(base_models["clip"])
compiled_clip, clip_mlir = self.get_clip()
check_compilation(compiled_clip, "Clip")
if self.return_mlir:
return clip_mlir
return compiled_clip
except Exception as e:
sys.exit(e)
def unet(self):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(self.model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
break
check_compilation(compiled_unet, "Unet")
if self.return_mlir:
return unet_mlir
return compiled_unet
except Exception as e:
sys.exit(e)
def vae(self):
try:
vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
self.inputs["vae"] = self.get_input_info_for(vae_input)
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")
if self.return_mlir:
return vae_mlir
return compiled_vae
except Exception as e:
sys.exit(e)
def controlnet(self):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
check_compilation(compiled_stencil_adaptor, "Stencil")
if self.return_mlir:
return controlnet_mlir
return compiled_stencil_adaptor
except Exception as e:
sys.exit(e)
def __call__(self):
from utils import get_vmfb_path_name
from stable_args import args
import traceback, functools, operator, os
model_name = ["clip", "base_vae" if self.base_vae else "vae", "unet"]
vmfb_path = [
get_vmfb_path_name(model + self.model_name)[0]
for model in model_name
]
for model_id in base_models:
self.inputs = self.get_input_info_for(
base_models[model_id],
)
try:
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = functools.reduce(
operator.__and__, vmfb_present
)
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(len(vmfb_path)):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
print("Retrying with a different base model configuration")
continue
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please use `enable_stack_trace` and create an issue at https://github.com/nod-ai/SHARK/issues"
)

View File

@@ -0,0 +1,123 @@
import sys
from transformers import CLIPTokenizer
from apps.stable_diffusion.src.utils import (
models_db,
args,
get_shark_model,
get_opt_flags,
)
hf_model_variant_map = {
"Linaqruf/anything-v3.0": ["anythingv3", "v1_4"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"],
"prompthero/openjourney": ["openjourney", "v1_4"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
bucket_key = "gs://shark_tank/prashant_nod"
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 and args.width != 512 and args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
def get_params(bucket_key, model_key, model, is_tuned, precision):
try:
bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
except KeyError:
raise Exception(
f"{bucket_key}/{model_key} is not present in the models database"
)
iree_flags = get_opt_flags(model, precision="fp16")
return bucket, model_name, iree_flags
def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
# TODO: Get the quantize model from model_db.json
if args.use_quantize == "int8":
bk, mk, flags = get_quantize_model()
return get_shark_model(bk, mk, flags)
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "unet", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae_encode():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
is_base = "/base" if args.use_base_vae else ""
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
variant, version = get_variant_version(args.hf_model_id)
bucket_key = f"{variant}/untuned"
model_key = (
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
)
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "clip", "untuned", "fp32"
)
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
)
return tokenizer

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
InpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import (
OutpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import (
StencilPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_upscaler import (
UpscalerPipeline,
)

View File

@@ -0,0 +1,200 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class Image2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_image_latents(
self,
image,
batch_size,
height,
width,
generator,
num_inference_steps,
strength,
dtype,
):
# Pre process image -> get image encoded -> process latents
# TODO: process with variable HxW combos
# Pre process image
image = image.resize((width, height))
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
image_arr = 2 * (image_arr - 0.5)
# set scheduler steps
self.scheduler.set_timesteps(num_inference_steps)
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps
)
t_start = max(num_inference_steps - init_timestep, 0)
# timesteps reduced as per strength
timesteps = self.scheduler.timesteps[t_start:]
# new number of steps to be used as per strength will be
# num_inference_steps = num_inference_steps - t_start
# image encode
latents = self.encode_image((image_arr,))
latents = torch.from_numpy(latents).to(dtype)
# add noise to data
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
latents = self.scheduler.add_noise(
latents, noise, timesteps[0].repeat(1)
)
return latents, timesteps
def encode_image(self, input_image):
self.load_vae_encode()
vae_encode_start = time.time()
latents = self.vae_encode("forward", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
if self.ondemand:
self.unload_vae_encode()
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
return latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
strength,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
use_stencil,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare input image latent
image_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=image_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -0,0 +1,473 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from PIL import Image, ImageOps
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class InpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
latents = latents * self.scheduler.init_noise_sigma
return latents
def get_crop_region(self, mask, pad=0):
h, w = mask.shape
crop_left = 0
for i in range(w):
if not (mask[:, i] == 0).all():
break
crop_left += 1
crop_right = 0
for i in reversed(range(w)):
if not (mask[:, i] == 0).all():
break
crop_right += 1
crop_top = 0
for i in range(h):
if not (mask[i] == 0).all():
break
crop_top += 1
crop_bottom = 0
for i in reversed(range(h)):
if not (mask[i] == 0).all():
break
crop_bottom += 1
return (
int(max(crop_left - pad, 0)),
int(max(crop_top - pad, 0)),
int(min(w - crop_right + pad, w)),
int(min(h - crop_bottom + pad, h)),
)
def expand_crop_region(
self,
crop_region,
processing_width,
processing_height,
image_width,
image_height,
):
x1, y1, x2, y2 = crop_region
ratio_crop_region = (x2 - x1) / (y2 - y1)
ratio_processing = processing_width / processing_height
if ratio_crop_region > ratio_processing:
desired_height = (x2 - x1) / ratio_processing
desired_height_diff = int(desired_height - (y2 - y1))
y1 -= desired_height_diff // 2
y2 += desired_height_diff - desired_height_diff // 2
if y2 >= image_height:
diff = y2 - image_height
y2 -= diff
y1 -= diff
if y1 < 0:
y2 -= y1
y1 -= y1
if y2 >= image_height:
y2 = image_height
else:
desired_width = (y2 - y1) * ratio_processing
desired_width_diff = int(desired_width - (x2 - x1))
x1 -= desired_width_diff // 2
x2 += desired_width_diff - desired_width_diff // 2
if x2 >= image_width:
diff = x2 - image_width
x2 -= diff
x1 -= diff
if x1 < 0:
x2 -= x1
x1 -= x1
if x2 >= image_width:
x2 = image_width
return x1, y1, x2, y2
def resize_image(self, resize_mode, im, width, height):
"""
resize_mode:
0: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
1: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
"""
if resize_mode == 0:
ratio = width / height
src_ratio = im.width / im.height
src_w = (
width if ratio > src_ratio else im.width * height // im.height
)
src_h = (
height if ratio <= src_ratio else im.height * width // im.width
)
resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (width, height))
res.paste(
resized,
box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
)
else:
ratio = width / height
src_ratio = im.width / im.height
src_w = (
width if ratio < src_ratio else im.width * height // im.height
)
src_h = (
height if ratio >= src_ratio else im.height * width // im.width
)
resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (width, height))
res.paste(
resized,
box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
)
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
res.paste(
resized.resize((width, fill_height), box=(0, 0, width, 0)),
box=(0, 0),
)
res.paste(
resized.resize(
(width, fill_height),
box=(0, resized.height, width, resized.height),
),
box=(0, fill_height + src_h),
)
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
res.paste(
resized.resize(
(fill_width, height), box=(0, 0, 0, height)
),
box=(0, 0),
)
res.paste(
resized.resize(
(fill_width, height),
box=(resized.width, 0, resized.width, height),
),
box=(fill_width + src_w, 0),
)
return res
def prepare_mask_and_masked_image(
self,
image,
mask,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
):
# preprocess image
image = image.resize((width, height))
mask = mask.resize((width, height))
paste_to = ()
overlay_image = None
if inpaint_full_res:
# prepare overlay image
overlay_image = Image.new("RGB", (image.width, image.height))
overlay_image.paste(
image.convert("RGB"),
mask=ImageOps.invert(mask.convert("L")),
)
# prepare mask
mask = mask.convert("L")
crop_region = self.get_crop_region(
np.array(mask), inpaint_full_res_padding
)
crop_region = self.expand_crop_region(
crop_region, width, height, mask.width, mask.height
)
x1, y1, x2, y2 = crop_region
mask = mask.crop(crop_region)
mask = self.resize_image(1, mask, width, height)
paste_to = (x1, y1, x2 - x1, y2 - y1)
# prepare image
image = image.crop(crop_region)
image = self.resize_image(1, image, width, height)
if isinstance(image, (Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image, paste_to, overlay_image
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
):
mask = torch.nn.functional.interpolate(
mask, size=(height // 8, width // 8)
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
return mask, masked_image_latents
def apply_overlay(self, image, paste_loc, overlay):
x, y, w, h = paste_loc
image = self.resize_image(0, image, w, h)
overlay.paste(image, (x, y))
return overlay
def generate_images(
self,
prompts,
neg_prompts,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Preprocess mask and image
(
mask,
masked_image,
paste_to,
overlay_image,
) = self.prepare_mask_and_masked_image(
image,
mask_image,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
)
# Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask=mask,
masked_image=masked_image,
batch_size=batch_size,
height=height,
width=width,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
mask=mask,
masked_image_latents=masked_image_latents,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
if inpaint_full_res:
output_image = self.apply_overlay(
all_imgs[0], paste_to, overlay_image
)
return [output_image]
return all_imgs

View File

@@ -0,0 +1,567 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from PIL import Image, ImageDraw, ImageFilter
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
import math
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class OutpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_mask_and_masked_image(
self, image, mask, mask_blur, width, height
):
if mask_blur > 0:
mask = mask.filter(ImageFilter.GaussianBlur(mask_blur))
image = image.resize((width, height))
mask = mask.resize((width, height))
# preprocess image
if isinstance(image, (Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
):
mask = torch.nn.functional.interpolate(
mask, size=(height // 8, width // 8)
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
return mask, masked_image_latents
def get_matched_noise(
self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
):
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data):
if data.ndim > 2: # has channels
out_fft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]),
dtype=np.complex128,
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_fft[:, :, c] = np.fft.fft2(
np.fft.fftshift(c_data), norm="ortho"
)
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
else: # one channel
out_fft = np.zeros(
(data.shape[0], data.shape[1]), dtype=np.complex128
)
out_fft[:, :] = np.fft.fft2(
np.fft.fftshift(data), norm="ortho"
)
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
return out_fft
def _ifft2(data):
if data.ndim > 2: # has channels
out_ifft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]),
dtype=np.complex128,
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_ifft[:, :, c] = np.fft.ifft2(
np.fft.fftshift(c_data), norm="ortho"
)
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
else: # one channel
out_ifft = np.zeros(
(data.shape[0], data.shape[1]), dtype=np.complex128
)
out_ifft[:, :] = np.fft.ifft2(
np.fft.fftshift(data), norm="ortho"
)
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
return out_ifft
def _get_gaussian_window(width, height, std=3.14, mode=0):
window_scale_x = float(width / min(width, height))
window_scale_y = float(height / min(width, height))
window = np.zeros((width, height))
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
for y in range(height):
fy = (y / height * 2.0 - 1.0) * window_scale_y
if mode == 0:
window[:, y] = np.exp(-(x**2 + fy**2) * std)
else:
window[:, y] = (
1 / ((x**2 + 1.0) * (fy**2 + 1.0))
) ** (std / 3.14)
return window
def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
np_mask_rgb = np.zeros(
(np_mask_grey.shape[0], np_mask_grey.shape[1], 3)
)
if hardness != 1.0:
hardened = np_mask_grey[:] ** hardness
else:
hardened = np_mask_grey[:]
for c in range(3):
np_mask_rgb[:, :, c] = hardened[:]
return np_mask_rgb
def _match_cumulative_cdf(source, template):
src_values, src_unique_indices, src_counts = np.unique(
source.ravel(), return_inverse=True, return_counts=True
)
tmpl_values, tmpl_counts = np.unique(
template.ravel(), return_counts=True
)
# calculate normalized quantiles for each array
src_quantiles = np.cumsum(src_counts) / source.size
tmpl_quantiles = np.cumsum(tmpl_counts) / template.size
interp_a_values = np.interp(
src_quantiles, tmpl_quantiles, tmpl_values
)
return interp_a_values[src_unique_indices].reshape(source.shape)
def _match_histograms(image, reference):
if image.ndim != reference.ndim:
raise ValueError(
"Image and reference must have the same number of channels."
)
if image.shape[-1] != reference.shape[-1]:
raise ValueError(
"Number of channels in the input image and reference image must match!"
)
matched = np.empty(image.shape, dtype=image.dtype)
for channel in range(image.shape[-1]):
matched_channel = _match_cumulative_cdf(
image[..., channel], reference[..., channel]
)
matched[..., channel] = matched_channel
matched = matched.astype(np.float64, copy=False)
return matched
width = _np_src_image.shape[0]
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
np_src_image = _np_src_image[:] * (1.0 - np_mask_rgb)
np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
# rather than leave the masked area black, we get better results from fft by filling the average unmasked color
windowed_image = _np_src_image * (
1.0 - _get_masked_window_rgb(np_mask_grey)
)
windowed_image /= np.max(windowed_image)
windowed_image += np.average(_np_src_image) * np_mask_rgb
src_fft = _fft2(
windowed_image
) # get feature statistics from masked src img
src_dist = np.absolute(src_fft)
src_phase = src_fft / src_dist
# create a generator with a static seed to make outpainting deterministic / only follow global seed
rng = np.random.default_rng(0)
noise_window = _get_gaussian_window(
width, height, mode=1
) # start with simple gaussian noise
noise_rgb = rng.random((width, height, num_channels))
noise_grey = np.sum(noise_rgb, axis=2) / 3.0
# the colorfulness of the starting noise is blended to greyscale with a parameter
noise_rgb *= color_variation
for c in range(num_channels):
noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey
noise_fft = _fft2(noise_rgb)
for c in range(num_channels):
noise_fft[:, :, c] *= noise_window
noise_rgb = np.real(_ifft2(noise_fft))
shaped_noise_fft = _fft2(noise_rgb)
shaped_noise_fft[:, :, :] = (
np.absolute(shaped_noise_fft[:, :, :]) ** 2
* (src_dist**noise_q)
* src_phase
) # perform the actual shaping
# color_variation
brightness_variation = 0.0
contrast_adjusted_np_src = (
_np_src_image[:] * (brightness_variation + 1.0)
- brightness_variation * 2.0
)
shaped_noise = np.real(_ifft2(shaped_noise_fft))
shaped_noise -= np.min(shaped_noise)
shaped_noise /= np.max(shaped_noise)
shaped_noise[img_mask, :] = _match_histograms(
shaped_noise[img_mask, :] ** 1.0,
contrast_adjusted_np_src[ref_mask, :],
)
shaped_noise = (
_np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
)
matched_noise = shaped_noise[:]
return np.clip(matched_noise, 0.0, 1.0)
def generate_images(
self,
prompts,
neg_prompts,
image,
pixels,
mask_blur,
is_left,
is_right,
is_top,
is_bottom,
noise_q,
color_variation,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
process_width = width
process_height = height
left = pixels if is_left else 0
right = pixels if is_right else 0
up = pixels if is_top else 0
down = pixels if is_bottom else 0
target_w = math.ceil((image.width + left + right) / 64) * 64
target_h = math.ceil((image.height + up + down) / 64) * 64
if left > 0:
left = left * (target_w - image.width) // (left + right)
if right > 0:
right = target_w - image.width - left
if up > 0:
up = up * (target_h - image.height) // (up + down)
if down > 0:
down = target_h - image.height - up
def expand(
init_img,
expand_pixels,
is_left=False,
is_right=False,
is_top=False,
is_bottom=False,
):
is_horiz = is_left or is_right
is_vert = is_top or is_bottom
pixels_horiz = expand_pixels if is_horiz else 0
pixels_vert = expand_pixels if is_vert else 0
res_w = init_img.width + pixels_horiz
res_h = init_img.height + pixels_vert
process_res_w = math.ceil(res_w / 64) * 64
process_res_h = math.ceil(res_h / 64) * 64
img = Image.new("RGB", (process_res_w, process_res_h))
img.paste(
init_img,
(pixels_horiz if is_left else 0, pixels_vert if is_top else 0),
)
msk = Image.new("RGB", (process_res_w, process_res_h), "white")
draw = ImageDraw.Draw(msk)
draw.rectangle(
(
expand_pixels + mask_blur if is_left else 0,
expand_pixels + mask_blur if is_top else 0,
msk.width - expand_pixels - mask_blur
if is_right
else res_w,
msk.height - expand_pixels - mask_blur
if is_bottom
else res_h,
),
fill="black",
)
np_image = (np.asarray(img) / 255.0).astype(np.float64)
np_mask = (np.asarray(msk) / 255.0).astype(np.float64)
noised = self.get_matched_noise(
np_image, np_mask, noise_q, color_variation
)
output_image = Image.fromarray(
np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8),
mode="RGB",
)
target_width = (
min(width, init_img.width + pixels_horiz)
if is_horiz
else img.width
)
target_height = (
min(height, init_img.height + pixels_vert)
if is_vert
else img.height
)
crop_region = (
0 if is_left else output_image.width - target_width,
0 if is_top else output_image.height - target_height,
target_width if is_left else output_image.width,
target_height if is_top else output_image.height,
)
mask_to_process = msk.crop(crop_region)
image_to_process = output_image.crop(crop_region)
# Preprocess mask and image
mask, masked_image = self.prepare_mask_and_masked_image(
image_to_process, mask_to_process, mask_blur, width, height
)
# Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask=mask,
masked_image=masked_image,
batch_size=batch_size,
height=height,
width=width,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
mask=mask,
masked_image_latents=masked_image_latents,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
res_img = all_imgs[0].resize(
(image_to_process.width, image_to_process.height)
)
output_image.paste(
res_img,
(
0 if is_left else output_image.width - res_img.width,
0 if is_top else output_image.height - res_img.height,
),
)
output_image = output_image.crop((0, 0, res_w, res_h))
return output_image
img = image.resize((width, height))
if left > 0:
img = expand(img, left, is_left=True)
if right > 0:
img = expand(img, right, is_right=True)
if up > 0:
img = expand(img, up, is_top=True)
if down > 0:
img = expand(img, down, is_bottom=True)
return [img]

View File

@@ -0,0 +1,274 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class StencilPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
def load_controlnet(self):
if self.controlnet is not None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_stencil_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
self.load_controlnet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.ondemand:
self.unload_unet()
self.unload_controlnet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
strength,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
use_stencil,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
)
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare initial latent.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
final_timesteps = self.scheduler.timesteps
# Get Image latents
latents = self.produce_stencil_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -0,0 +1,144 @@
import torch
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -0,0 +1,319 @@
import inspect
import torch
import time
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from PIL import Image
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, Image.Image):
image = [image]
if isinstance(image[0], Image.Image):
w, h = image[0].size
w, h = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
class UpscalerPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
low_res_scheduler: Union[
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.low_res_scheduler = low_res_scheduler
def prepare_extra_step_kwargs(self, generator, eta):
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
latents = 1 / 0.08333 * (latents.float())
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height,
width,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_img_latents(
self,
latents,
image,
text_embeddings,
guidance_scale,
noise_level,
total_timesteps,
dtype,
cpu_scheduling,
extra_step_kwargs,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = torch.cat([latent_model_input, image], dim=1)
timestep = torch.tensor([t]).to(dtype).detach().numpy()
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if cpu_scheduling:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
else:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
noise_level,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# 4. Preprocess image
image = preprocess(image).to(dtype)
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long)
noise = torch.randn(
image.shape,
generator=generator,
).to(dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
image = torch.cat([image] * 2)
noise_level = torch.cat([noise_level] * image.shape[0])
height, width = image.shape[2:]
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
eta = 0.0
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# guidance scale as a float32 tensor.
# guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
image=image,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
noise_level=noise_level,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
extra_step_kwargs=extra_step_kwargs,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -0,0 +1,839 @@
import torch
import numpy as np
from transformers import CLIPTokenizer
from PIL import Image
from tqdm.auto import tqdm
import time
from typing import Union
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
get_clip,
get_unet,
get_tokenizer,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
import sys
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
class StableDiffusionPipeline:
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
self.vae = None
self.text_encoder = None
self.unet = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
try:
self.tokenizer = get_tokenizer()
except:
self.load_unet()
self.unload_unet()
self.tokenizer = get_tokenizer()
def load_clip(self):
if self.text_encoder is not None:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
try:
self.text_encoder = get_clip()
except:
print("download pipeline failed, falling back to import_mlir")
self.text_encoder = self.sd_model.clip()
def unload_clip(self):
del self.text_encoder
self.text_encoder = None
def load_unet(self):
if self.unet is not None:
return
if self.import_mlir or self.use_lora:
self.unet = self.sd_model.unet()
else:
try:
self.unet = get_unet()
except:
print("download pipeline failed, falling back to import_mlir")
self.unet = self.sd_model.unet()
def unload_unet(self):
del self.unet
self.unet = None
def load_vae(self):
if self.vae is not None:
return
if self.import_mlir or self.use_lora:
self.vae = self.sd_model.vae()
else:
try:
self.vae = get_vae()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae = self.sd_model.vae()
def unload_vae(self):
del self.vae
self.vae = None
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
text_input = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# Get unconditional embeddings as well
uncond_input = self.tokenizer(
neg_prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
self.load_clip()
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
if use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
if use_base_vae:
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def produce_img_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
@classmethod
def from_pretrained(
cls,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
import_mlir: bool,
model_id: str,
ckpt_loc: str,
custom_vae: str,
precision: str,
max_length: int,
batch_size: int,
height: int,
width: int,
use_base_vae: bool,
use_tuned: bool,
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
if (
not import_mlir
and not use_lora
and cls.__name__ == "StencilPipeline"
):
sys.exit("StencilPipeline not supported with SharkTank currently.")
is_inpaint = cls.__name__ in [
"InpaintPipeline",
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
sd_model = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
use_quantize=use_quantize,
)
if cls.__name__ in ["UpscalerPipeline"]:
return cls(
scheduler,
ddpm_scheduler,
sd_model,
import_mlir,
use_lora,
ondemand,
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
# #####################################################
# Implements text embeddings with weights from prompts
# https://huggingface.co/AlanB/lpw_stable_diffusion_mod
# #####################################################
def encode_prompts_weight(
self,
prompt,
negative_prompt,
model_max_length,
do_classifier_free_guidance=True,
max_embeddings_multiples=1,
num_images_per_prompt=1,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
"""
# SHARK: Save model_max_length, load the clip and init inference time
self.model_max_length = model_max_length
self.load_clip()
clip_inf_start = time.time()
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt
if do_classifier_free_guidance
else None,
max_embeddings_multiples=max_embeddings_multiples,
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()
from typing import List, Optional, Union
import re
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(
pipe: StableDiffusionPipeline, prompt: List[str], max_length: int
):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens,
weights,
max_length,
bos,
eos,
no_boseos_middle=True,
chunk_length=77,
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
max_length
if no_boseos_middle
else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
tokens[i] = (
[bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
)
if no_boseos_middle:
weights[i] = (
[1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
)
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j
* (chunk_length - 2) : min(
len(weights[i]), (j + 1) * (chunk_length - 2)
)
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
# text_embedding = pipe.text_encoder(text_input_chunk)[0]
# SHARK: deplicate the text_input as Shark runner expects tokens and neg tokens
formatted_text_input_chunk = torch.cat(
[text_input_chunk, text_input_chunk]
)
text_embedding = pipe.text_encoder(
"forward", (formatted_text_input_chunk,)
)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
else:
# SHARK: deplicate the text_input as Shark runner expects tokens and neg tokens
# Convert the result to tensor
# text_embeddings = pipe.text_encoder(text_input)[0]
formatted_text_input = torch.cat([text_input, text_input])
text_embeddings = pipe.text_encoder(
"forward", (formatted_text_input,)
)[0]
text_embeddings = torch.from_numpy(text_embeddings)[None, :]
return text_embeddings
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
Args:
pipe (`StableDiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`):
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
ending token in each of the chunk in the middle.
skip_parsing (`bool`, *optional*, defaults to `False`):
Skip the parsing of brackets.
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True.
"""
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(
max_length, max([len(token) for token in uncond_tokens])
)
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(
uncond_tokens, dtype=torch.long, device="cpu"
)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(
prompt_weights, dtype=torch.float, device="cpu"
)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(
uncond_weights, dtype=torch.float, device="cpu"
)
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None

View File

@@ -0,0 +1,4 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)

View File

@@ -0,0 +1,66 @@
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
def get_schedulers(model_id):
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerDiscrete"
] = SharkEulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers

View File

@@ -0,0 +1,156 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_shark_model,
args,
)
import torch
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
):
super().__init__(
num_train_timesteps,
beta_start,
beta_end,
beta_schedule,
trained_betas,
prediction_type,
)
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
}
example_latent = model_input["euler"]["latent"]
example_output = model_input["euler"]["output"]
if args.precision == "fp16":
example_latent = example_latent.half()
example_output = example_output.half()
example_sigma = model_input["euler"]["sigma"]
example_dt = model_input["euler"]["dt"]
class ScalingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma, latent, dt):
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
return latent + derivative * dt
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
def _import(self):
scaling_model = ScalingModel()
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
self.step_model, _ = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
if args.import_mlir:
_import(self)
else:
try:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_step_" + args.precision,
iree_flags,
)
except:
print(
"failed to download model, falling back and using import_mlir"
)
args.import_mlir = True
_import(self)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
return self.scaling_model(
"forward",
(
sample,
sigma,
),
send_to_host=False,
)
def step(self, noise_pred, timestep, latent):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
dt = self.sigmas[step_index + 1] - sigma
return self.step_model(
"forward",
(
noise_pred,
sigma,
latent,
dt,
),
send_to_host=False,
)

View File

@@ -0,0 +1,37 @@
from apps.stable_diffusion.src.utils.profiler import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.utils.resources import (
prompt_examples,
models_db,
base_models,
opt_flags,
resource_path,
)
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
controlnet_hint_conversion,
get_stencil_model_id,
)
from apps.stable_diffusion.src.utils.utils import (
get_shark_model,
compile_through_fx,
set_iree_runtime_flags,
map_device_to_name_path,
set_init_device_flags,
get_available_devices,
get_opt_flags,
preprocessCKPT,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
get_path_stem,
get_extended_name,
clear_all,
save_output_img,
get_generation_text_info,
update_lora_weight,
resize_stencil,
)

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.utils.stable_args import args
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")
vulkan_device = iree.runtime.get_device(args.device)
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
return vulkan_device
return None
def end_profiling(device):
if device:
return device.end_profiling()

View File

@@ -0,0 +1,37 @@
import os
import json
import sys
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def get_json_file(path):
json_var = []
loc_json = resource_path(path)
if os.path.exists(loc_json):
with open(loc_json, encoding="utf-8") as fopen:
json_var = json.load(fopen)
if not json_var:
print(f"Unable to fetch {path}")
return json_var
# TODO: This shouldn't be called from here, every time the file imports
# it will run all the global vars.
prompt_examples = get_json_file("resources/prompts.json")
models_db = get_json_file("resources/model_db.json")
# The base_model contains the input configuration for the different
# models and also helps in providing information for the variants.
base_models = get_json_file("resources/base_model.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")

View File

@@ -0,0 +1,296 @@
{
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"vae_upscaler": {
"latents" : {
"shape" : [
"1*batch_size",4,"8*height","8*width"
],
"dtype":"f32"
}
}
},
"unet": {
"stabilityai/stable-diffusion-2-1": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-2-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"runwayml/stable-diffusion-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-x4-upscaler": {
"latents": {
"shape": [
"2*batch_size",
7,
"8*height",
"8*width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"noise_level": {
"shape": [2],
"dtype": "i64"
}
}
},
"stencil_adaptor": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
"stencil_unet": {
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
},
"control1": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control2": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control3": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control4": {
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"control5": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control6": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control7": {
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"control8": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control9": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control10": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control11": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control12": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
}
}
}

View File

@@ -0,0 +1,23 @@
[
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting",
"stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]

View File

@@ -0,0 +1,19 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/nightly"
},
{
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan"
}
]

View File

@@ -0,0 +1,84 @@
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
"fp32": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
}
}

View File

@@ -0,0 +1,8 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]

View File

@@ -0,0 +1,255 @@
import os
import io
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
download_model,
download_public_file,
WORKDIR,
)
from shark.parser import shark_args
from apps.stable_diffusion.src.utils.stable_args import args
def get_device():
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
return device
def get_device_args():
device = get_device()
device_spec_args = []
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
gpu_flags = get_iree_gpu_args()
for flag in gpu_flags:
device_spec_args.append(flag)
elif device == "vulkan":
device_spec_args.append(
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
return device, device_spec_args
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
from apps.stable_diffusion.src.models import (
get_params,
get_variant_version,
)
variant, version = get_variant_version(args.hf_model_id)
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{variant}/untuned"
if args.annotation_model == "unet":
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
is_base = "/base" if args.use_base_vae else ""
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
)
return mlir_model, model_name
# Download the tuned config files from shark_tank
def load_winograd_configs():
device = get_device()
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
def load_lower_configs(base_model_id=None):
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
if not base_model_id:
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
if version == "inpaint_v1":
version = "v1_4"
elif version == "inpaint_v2":
version = "v2_1base"
config_bucket = "gs://shark_tank/sd_tuned_configs/"
device, device_spec_args = get_device_args()
spec = ""
if device_spec_args:
spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
spec = spec.split("-")[0]
if args.annotation_model == "vae":
if not spec or spec in ["rdna3", "sm_80"]:
config_name = (
f"{args.annotation_model}_{args.precision}_{device}.json"
)
else:
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
and args.width == 768
):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
with create_context() as ctx:
winograd_model = model_annotation(
ctx,
input_contents=input_mlir,
config_path=winograd_config_dir,
search_op="conv",
winograd=True,
)
bytecode_stream = io.BytesIO()
winograd_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
if args.save_annotation:
if model_name.split("_")[-1] != "tuned":
out_file_path = os.path.join(
args.annotation_output, model_name + "_tuned_torch.mlir"
)
else:
out_file_path = os.path.join(
args.annotation_output, model_name + "_torch.mlir"
)
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return bytecode
def dump_after_mlir(input_mlir, use_winograd):
import iree.compiler as ireec
device, device_spec_args = get_device_args()
if use_winograd:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
else:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
dump_module = ireec.compile_str(
input_mlir,
target_backends=[iree_target_map(device)],
extra_args=device_spec_args
+ [
preprocess_flag,
"--compile-to=preprocessing",
],
)
return dump_module
# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
# Dump IR after padding/img2col/winograd passes
dump_module = dump_after_mlir(input_mlir, use_winograd)
print("Applying tuned configs on", model_name)
# Annotate the model with lowering configs in the config file
with create_context() as ctx:
tuned_model = model_annotation(
ctx,
input_contents=dump_module,
config_path=lowering_config_dir,
search_op="all",
)
bytecode_stream = io.BytesIO()
tuned_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
if args.save_annotation:
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return bytecode
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
tuned_model = mlir_model
else:
use_winograd = False
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
return tuned_model
if __name__ == "__main__":
mlir_model, model_name = load_model_from_tank()
sd_model_annotation(mlir_model, model_name)

View File

@@ -0,0 +1,572 @@
import argparse
import os
from pathlib import Path
def path_expand(s):
return Path(s).expanduser().resolve()
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
)
p.add_argument(
"-p",
"--prompts",
nargs="+",
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=["trees, green"],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--img_path",
type=str,
help="Path to the image input for img2img/inpainting",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
default=-1,
help="the seed to use. -1 for a random one.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 769, 8),
help="the height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 769, 8),
help="the width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="the value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="the strength of change applied on the given input image for img2img",
)
##############################################################################
### Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The no. of steps to train",
)
##############################################################################
### Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture",
)
p.add_argument(
"--inpaint_full_res_padding",
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding",
)
p.add_argument(
"--pixels",
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting",
)
p.add_argument(
"--mask_blur",
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend left for outpainting",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend right for outpainting",
)
p.add_argument(
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend top for outpainting",
)
p.add_argument(
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend bottom for outpainting",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0)",
)
##############################################################################
### Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="number of batch to be generated with random seeds in single execution",
)
p.add_argument(
"--ckpt_loc",
type=str,
default="",
help="Path to SD's .ckpt file.",
)
p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
)
p.add_argument(
"--hf_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
)
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble"],
help="Enable the stencil feature.",
)
p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
)
p.add_argument(
"--ondemand",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
)
##############################################################################
### Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save a generation information json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
)
##############################################################################
### Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the progress bar animation during image generation",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for enabling rest API",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file",
)
##############################################################################
### SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all",
)
args, unknown = p.parse_known_args()
if args.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
os.getcwd(), args.hf_model_id.replace("/", "_")
)

View File

@@ -0,0 +1,2 @@
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector

View File

@@ -0,0 +1,6 @@
import cv2
class CannyDetector:
def __call__(self, img, low_threshold, high_threshold):
return cv2.Canny(img, low_threshold, high_threshold)

View File

@@ -0,0 +1,62 @@
import requests
from pathlib import Path
import torch
import numpy as np
# from annotator.util import annotator_ckpts_path
from apps.stable_diffusion.src.utils.stencils.openpose.body import Body
from apps.stable_diffusion.src.utils.stencils.openpose.hand import Hand
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
draw_bodypose,
draw_handpose,
handDetect,
)
body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"
class OpenposeDetector:
def __init__(self):
cwd = Path.cwd()
ckpt_path = Path(cwd, "stencil_annotator")
ckpt_path.mkdir(parents=True, exist_ok=True)
body_modelpath = ckpt_path / "body_pose_model.pth"
hand_modelpath = ckpt_path / "hand_pose_model.pth"
if not body_modelpath.is_file():
r = requests.get(body_model_path, allow_redirects=True)
open(body_modelpath, "wb").write(r.content)
if not hand_modelpath.is_file():
r = requests.get(hand_model_path, allow_redirects=True)
open(hand_modelpath, "wb").write(r.content)
self.body_estimation = Body(body_modelpath)
self.hand_estimation = Hand(hand_modelpath)
def __call__(self, oriImg, hand=False):
oriImg = oriImg[:, :, ::-1].copy()
with torch.no_grad():
candidate, subset = self.body_estimation(oriImg)
canvas = np.zeros_like(oriImg)
canvas = draw_bodypose(canvas, candidate, subset)
if hand:
hands_list = handDetect(candidate, subset, oriImg)
all_hand_peaks = []
for x, y, w, is_left in hands_list:
peaks = self.hand_estimation(
oriImg[y : y + w, x : x + w, :]
)
peaks[:, 0] = np.where(
peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x
)
peaks[:, 1] = np.where(
peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y
)
all_hand_peaks.append(peaks)
canvas = draw_handpose(canvas, all_hand_peaks)
return canvas, dict(
candidate=candidate.tolist(), subset=subset.tolist()
)

View File

@@ -0,0 +1,499 @@
import cv2
import numpy as np
import math
from scipy.ndimage.filters import gaussian_filter
import torch
import torch.nn as nn
from collections import OrderedDict
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
make_layers,
transfer,
padRightDownCorner,
)
class BodyPoseModel(nn.Module):
def __init__(self):
super(BodyPoseModel, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv5_5_CPM_L1",
"conv5_5_CPM_L2",
"Mconv7_stage2_L1",
"Mconv7_stage2_L2",
"Mconv7_stage3_L1",
"Mconv7_stage3_L2",
"Mconv7_stage4_L1",
"Mconv7_stage4_L2",
"Mconv7_stage5_L1",
"Mconv7_stage5_L2",
"Mconv7_stage6_L1",
"Mconv7_stage6_L1",
]
blocks = {}
block0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3_CPM", [512, 256, 3, 1, 1]),
("conv4_4_CPM", [256, 128, 3, 1, 1]),
]
)
# Stage 1
block1_1 = OrderedDict(
[
("conv5_1_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L1", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L1", [512, 38, 1, 1, 0]),
]
)
block1_2 = OrderedDict(
[
("conv5_1_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L2", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L2", [512, 19, 1, 1, 0]),
]
)
blocks["block1_1"] = block1_1
blocks["block1_2"] = block1_2
self.model0 = make_layers(block0, no_relu_layers)
# Stages 2 - 6
for i in range(2, 7):
blocks["block%d_1" % i] = OrderedDict(
[
("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]),
]
)
blocks["block%d_2" % i] = OrderedDict(
[
("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_1 = blocks["block1_1"]
self.model2_1 = blocks["block2_1"]
self.model3_1 = blocks["block3_1"]
self.model4_1 = blocks["block4_1"]
self.model5_1 = blocks["block5_1"]
self.model6_1 = blocks["block6_1"]
self.model1_2 = blocks["block1_2"]
self.model2_2 = blocks["block2_2"]
self.model3_2 = blocks["block3_2"]
self.model4_2 = blocks["block4_2"]
self.model5_2 = blocks["block5_2"]
self.model6_2 = blocks["block6_2"]
def forward(self, x):
out1 = self.model0(x)
out1_1 = self.model1_1(out1)
out1_2 = self.model1_2(out1)
out2 = torch.cat([out1_1, out1_2, out1], 1)
out2_1 = self.model2_1(out2)
out2_2 = self.model2_2(out2)
out3 = torch.cat([out2_1, out2_2, out1], 1)
out3_1 = self.model3_1(out3)
out3_2 = self.model3_2(out3)
out4 = torch.cat([out3_1, out3_2, out1], 1)
out4_1 = self.model4_1(out4)
out4_2 = self.model4_2(out4)
out5 = torch.cat([out4_1, out4_2, out1], 1)
out5_1 = self.model5_1(out5)
out5_2 = self.model5_2(out5)
out6 = torch.cat([out5_1, out5_2, out1], 1)
out6_1 = self.model6_1(out6)
out6_2 = self.model6_2(out6)
return out6_1, out6_2
class Body(object):
def __init__(self, model_path):
self.model = BodyPoseModel()
if torch.cuda.is_available():
self.model = self.model.cuda()
model_dict = transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def __call__(self, oriImg):
scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre1 = 0.1
thre2 = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = cv2.resize(
oriImg,
(0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_CUBIC,
)
imageToTest_padded, pad = padRightDownCorner(
imageToTest, stride, padValue
)
im = (
np.transpose(
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
(3, 2, 0, 1),
)
/ 256
- 0.5
)
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
if torch.cuda.is_available():
data = data.cuda()
with torch.no_grad():
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
# extract outputs, resize, and remove padding
heatmap = np.transpose(
np.squeeze(Mconv7_stage6_L2), (1, 2, 0)
) # output 1 is heatmaps
heatmap = cv2.resize(
heatmap,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
heatmap = heatmap[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
heatmap = cv2.resize(
heatmap,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
paf = np.transpose(
np.squeeze(Mconv7_stage6_L1), (1, 2, 0)
) # output 0 is PAFs
paf = cv2.resize(
paf,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
paf = paf[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
paf = cv2.resize(
paf,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
paf_avg += +paf / len(multiplier)
all_peaks = []
peak_counter = 0
for part in range(18):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
map_left = np.zeros(one_heatmap.shape)
map_left[1:, :] = one_heatmap[:-1, :]
map_right = np.zeros(one_heatmap.shape)
map_right[:-1, :] = one_heatmap[1:, :]
map_up = np.zeros(one_heatmap.shape)
map_up[:, 1:] = one_heatmap[:, :-1]
map_down = np.zeros(one_heatmap.shape)
map_down[:, :-1] = one_heatmap[:, 1:]
peaks_binary = np.logical_and.reduce(
(
one_heatmap >= map_left,
one_heatmap >= map_right,
one_heatmap >= map_up,
one_heatmap >= map_down,
one_heatmap > thre1,
)
)
peaks = list(
zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])
) # note reverse
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
peak_id = range(peak_counter, peak_counter + len(peaks))
peaks_with_score_and_id = [
peaks_with_score[i] + (peak_id[i],)
for i in range(len(peak_id))
]
all_peaks.append(peaks_with_score_and_id)
peak_counter += len(peaks)
# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
[3, 17],
[6, 18],
]
# the middle joints heatmap correpondence
mapIdx = [
[31, 32],
[39, 40],
[33, 34],
[35, 36],
[41, 42],
[43, 44],
[19, 20],
[21, 22],
[23, 24],
[25, 26],
[27, 28],
[29, 30],
[47, 48],
[49, 50],
[53, 54],
[51, 52],
[55, 56],
[37, 38],
[45, 46],
]
connection_all = []
special_k = []
mid_num = 10
for k in range(len(mapIdx)):
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
candA = all_peaks[limbSeq[k][0] - 1]
candB = all_peaks[limbSeq[k][1] - 1]
nA = len(candA)
nB = len(candB)
indexA, indexB = limbSeq[k]
if nA != 0 and nB != 0:
connection_candidate = []
for i in range(nA):
for j in range(nB):
vec = np.subtract(candB[j][:2], candA[i][:2])
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
norm = max(0.001, norm)
vec = np.divide(vec, norm)
startend = list(
zip(
np.linspace(
candA[i][0], candB[j][0], num=mid_num
),
np.linspace(
candA[i][1], candB[j][1], num=mid_num
),
)
)
vec_x = np.array(
[
score_mid[
int(round(startend[I][1])),
int(round(startend[I][0])),
0,
]
for I in range(len(startend))
]
)
vec_y = np.array(
[
score_mid[
int(round(startend[I][1])),
int(round(startend[I][0])),
1,
]
for I in range(len(startend))
]
)
score_midpts = np.multiply(
vec_x, vec[0]
) + np.multiply(vec_y, vec[1])
score_with_dist_prior = sum(score_midpts) / len(
score_midpts
) + min(0.5 * oriImg.shape[0] / norm - 1, 0)
criterion1 = len(
np.nonzero(score_midpts > thre2)[0]
) > 0.8 * len(score_midpts)
criterion2 = score_with_dist_prior > 0
if criterion1 and criterion2:
connection_candidate.append(
[
i,
j,
score_with_dist_prior,
score_with_dist_prior
+ candA[i][2]
+ candB[j][2],
]
)
connection_candidate = sorted(
connection_candidate, key=lambda x: x[2], reverse=True
)
connection = np.zeros((0, 5))
for c in range(len(connection_candidate)):
i, j, s = connection_candidate[c][0:3]
if i not in connection[:, 3] and j not in connection[:, 4]:
connection = np.vstack(
[connection, [candA[i][3], candB[j][3], s, i, j]]
)
if len(connection) >= min(nA, nB):
break
connection_all.append(connection)
else:
special_k.append(k)
connection_all.append([])
# last number in each row is the total parts number of that person
# the second last number in each row is the score of the overall configuration
subset = -1 * np.ones((0, 20))
candidate = np.array(
[item for sublist in all_peaks for item in sublist]
)
for k in range(len(mapIdx)):
if k not in special_k:
partAs = connection_all[k][:, 0]
partBs = connection_all[k][:, 1]
indexA, indexB = np.array(limbSeq[k]) - 1
for i in range(len(connection_all[k])): # = 1:size(temp,1)
found = 0
subset_idx = [-1, -1]
for j in range(len(subset)): # 1:size(subset,1):
if (
subset[j][indexA] == partAs[i]
or subset[j][indexB] == partBs[i]
):
subset_idx[found] = j
found += 1
if found == 1:
j = subset_idx[0]
if subset[j][indexB] != partBs[i]:
subset[j][indexB] = partBs[i]
subset[j][-1] += 1
subset[j][-2] += (
candidate[partBs[i].astype(int), 2]
+ connection_all[k][i][2]
)
elif found == 2: # if found 2 and disjoint, merge them
j1, j2 = subset_idx
membership = (
(subset[j1] >= 0).astype(int)
+ (subset[j2] >= 0).astype(int)
)[:-2]
if len(np.nonzero(membership == 2)[0]) == 0: # merge
subset[j1][:-2] += subset[j2][:-2] + 1
subset[j1][-2:] += subset[j2][-2:]
subset[j1][-2] += connection_all[k][i][2]
subset = np.delete(subset, j2, 0)
else: # as like found == 1
subset[j1][indexB] = partBs[i]
subset[j1][-1] += 1
subset[j1][-2] += (
candidate[partBs[i].astype(int), 2]
+ connection_all[k][i][2]
)
# if find no partA in the subset, create a new subset
elif not found and k < 17:
row = -1 * np.ones(20)
row[indexA] = partAs[i]
row[indexB] = partBs[i]
row[-1] = 2
row[-2] = (
sum(
candidate[
connection_all[k][i, :2].astype(int), 2
]
)
+ connection_all[k][i][2]
)
subset = np.vstack([subset, row])
# delete some rows of subset which has few parts occur
deleteIdx = []
for i in range(len(subset)):
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
deleteIdx.append(i)
subset = np.delete(subset, deleteIdx, axis=0)
# candidate: x, y, score, id
return candidate, subset

View File

@@ -0,0 +1,205 @@
import cv2
import numpy as np
from scipy.ndimage.filters import gaussian_filter
import torch
import torch.nn as nn
from skimage.measure import label
from collections import OrderedDict
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
make_layers,
transfer,
padRightDownCorner,
npmax,
)
class HandPoseModel(nn.Module):
def __init__(self):
super(HandPoseModel, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv6_2_CPM",
"Mconv7_stage2",
"Mconv7_stage3",
"Mconv7_stage4",
"Mconv7_stage5",
"Mconv7_stage6",
]
# stage 1
block1_0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3", [512, 512, 3, 1, 1]),
("conv4_4", [512, 512, 3, 1, 1]),
("conv5_1", [512, 512, 3, 1, 1]),
("conv5_2", [512, 512, 3, 1, 1]),
("conv5_3_CPM", [512, 128, 3, 1, 1]),
]
)
block1_1 = OrderedDict(
[
("conv6_1_CPM", [128, 512, 1, 1, 0]),
("conv6_2_CPM", [512, 22, 1, 1, 0]),
]
)
blocks = {}
blocks["block1_0"] = block1_0
blocks["block1_1"] = block1_1
# stage 2-6
for i in range(2, 7):
blocks["block%d" % i] = OrderedDict(
[
("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]),
("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_0 = blocks["block1_0"]
self.model1_1 = blocks["block1_1"]
self.model2 = blocks["block2"]
self.model3 = blocks["block3"]
self.model4 = blocks["block4"]
self.model5 = blocks["block5"]
self.model6 = blocks["block6"]
def forward(self, x):
out1_0 = self.model1_0(x)
out1_1 = self.model1_1(out1_0)
concat_stage2 = torch.cat([out1_1, out1_0], 1)
out_stage2 = self.model2(concat_stage2)
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
out_stage3 = self.model3(concat_stage3)
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
out_stage4 = self.model4(concat_stage4)
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
out_stage5 = self.model5(concat_stage5)
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
out_stage6 = self.model6(concat_stage6)
return out_stage6
class Hand(object):
def __init__(self, model_path):
self.model = HandPoseModel()
if torch.cuda.is_available():
self.model = self.model.cuda()
model_dict = transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def __call__(self, oriImg):
scale_search = [0.5, 1.0, 1.5, 2.0]
# scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
# paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = cv2.resize(
oriImg,
(0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_CUBIC,
)
imageToTest_padded, pad = padRightDownCorner(
imageToTest, stride, padValue
)
im = (
np.transpose(
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
(3, 2, 0, 1),
)
/ 256
- 0.5
)
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
if torch.cuda.is_available():
data = data.cuda()
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
with torch.no_grad():
output = self.model(data).cpu().numpy()
# output = self.model(data).numpy()q
# extract outputs, resize, and remove padding
heatmap = np.transpose(
np.squeeze(output), (1, 2, 0)
) # output 1 is heatmaps
heatmap = cv2.resize(
heatmap,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
heatmap = heatmap[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
heatmap = cv2.resize(
heatmap,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
heatmap_avg += heatmap / len(multiplier)
all_peaks = []
for part in range(21):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
# 全部小于阈值
if np.sum(binary) == 0:
all_peaks.append([0, 0])
continue
label_img, label_numbers = label(
binary, return_num=True, connectivity=binary.ndim
)
max_index = (
np.argmax(
[
np.sum(map_ori[label_img == i])
for i in range(1, label_numbers + 1)
]
)
+ 1
)
label_img[label_img != max_index] = 0
map_ori[label_img == 0] = 0
y, x = npmax(map_ori)
all_peaks.append([x, y])
return np.array(all_peaks)

View File

@@ -0,0 +1,272 @@
import math
import numpy as np
import matplotlib
import cv2
from collections import OrderedDict
import torch.nn as nn
def make_layers(block, no_relu_layers):
layers = []
for layer_name, v in block.items():
if "pool" in layer_name:
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
layers.append((layer_name, layer))
else:
conv2d = nn.Conv2d(
in_channels=v[0],
out_channels=v[1],
kernel_size=v[2],
stride=v[3],
padding=v[4],
)
layers.append((layer_name, conv2d))
if layer_name not in no_relu_layers:
layers.append(("relu_" + layer_name, nn.ReLU(inplace=True)))
return nn.Sequential(OrderedDict(layers))
def padRightDownCorner(img, stride, padValue):
h = img.shape[0]
w = img.shape[1]
pad = 4 * [None]
pad[0] = 0 # up
pad[1] = 0 # left
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
img_padded = img
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
img_padded = np.concatenate((pad_up, img_padded), axis=0)
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
img_padded = np.concatenate((pad_left, img_padded), axis=1)
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
img_padded = np.concatenate((img_padded, pad_down), axis=0)
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
img_padded = np.concatenate((img_padded, pad_right), axis=1)
return img_padded, pad
# transfer caffe model to pytorch which will match the layer name
def transfer(model, model_weights):
transfered_model_weights = {}
for weights_name in model.state_dict().keys():
transfered_model_weights[weights_name] = model_weights[
".".join(weights_name.split(".")[1:])
]
return transfered_model_weights
# draw the body keypoint and lims
def draw_bodypose(canvas, candidate, subset):
stickwidth = 4
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
[3, 17],
[6, 18],
]
colors = [
[255, 0, 0],
[255, 85, 0],
[255, 170, 0],
[255, 255, 0],
[170, 255, 0],
[85, 255, 0],
[0, 255, 0],
[0, 255, 85],
[0, 255, 170],
[0, 255, 255],
[0, 170, 255],
[0, 85, 255],
[0, 0, 255],
[85, 0, 255],
[170, 0, 255],
[255, 0, 255],
[255, 0, 170],
[255, 0, 85],
]
for i in range(18):
for n in range(len(subset)):
index = int(subset[n][i])
if index == -1:
continue
x, y = candidate[index][0:2]
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
for i in range(17):
for n in range(len(subset)):
index = subset[n][np.array(limbSeq[i]) - 1]
if -1 in index:
continue
cur_canvas = canvas.copy()
Y = candidate[index.astype(int), 0]
X = candidate[index.astype(int), 1]
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly(
(int(mY), int(mX)),
(int(length / 2), stickwidth),
int(angle),
0,
360,
1,
)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
return canvas
# image drawed by opencv is not good.
def draw_handpose(canvas, all_hand_peaks, show_number=False):
edges = [
[0, 1],
[1, 2],
[2, 3],
[3, 4],
[0, 5],
[5, 6],
[6, 7],
[7, 8],
[0, 9],
[9, 10],
[10, 11],
[11, 12],
[0, 13],
[13, 14],
[14, 15],
[15, 16],
[0, 17],
[17, 18],
[18, 19],
[19, 20],
]
for peaks in all_hand_peaks:
for ie, e in enumerate(edges):
if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
x1, y1 = peaks[e[0]]
x2, y2 = peaks[e[1]]
cv2.line(
canvas,
(x1, y1),
(x2, y2),
matplotlib.colors.hsv_to_rgb(
[ie / float(len(edges)), 1.0, 1.0]
)
* 255,
thickness=2,
)
for i, keyponit in enumerate(peaks):
x, y = keyponit
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
if show_number:
cv2.putText(
canvas,
str(i),
(x, y),
cv2.FONT_HERSHEY_SIMPLEX,
0.3,
(0, 0, 0),
lineType=cv2.LINE_AA,
)
return canvas
# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
# right hand: wrist 4, elbow 3, shoulder 2
# left hand: wrist 7, elbow 6, shoulder 5
ratioWristElbow = 0.33
detect_result = []
image_height, image_width = oriImg.shape[0:2]
for person in subset.astype(int):
# if any of three not detected
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
if not (has_left or has_right):
continue
hands = []
# left hand
if has_left:
left_shoulder_index, left_elbow_index, left_wrist_index = person[
[5, 6, 7]
]
x1, y1 = candidate[left_shoulder_index][:2]
x2, y2 = candidate[left_elbow_index][:2]
x3, y3 = candidate[left_wrist_index][:2]
hands.append([x1, y1, x2, y2, x3, y3, True])
# right hand
if has_right:
(
right_shoulder_index,
right_elbow_index,
right_wrist_index,
) = person[[2, 3, 4]]
x1, y1 = candidate[right_shoulder_index][:2]
x2, y2 = candidate[right_elbow_index][:2]
x3, y3 = candidate[right_wrist_index][:2]
hands.append([x1, y1, x2, y2, x3, y3, False])
for x1, y1, x2, y2, x3, y3, is_left in hands:
x = x3 + ratioWristElbow * (x3 - x2)
y = y3 + ratioWristElbow * (y3 - y2)
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
# x-y refers to the center --> offset to topLeft point
x -= width / 2
y -= width / 2 # width = height
# overflow the image
if x < 0:
x = 0
if y < 0:
y = 0
width1 = width
width2 = width
if x + width > image_width:
width1 = image_width - x
if y + width > image_height:
width2 = image_height - y
width = min(width1, width2)
# the max hand box value is 20 pixels
if width >= 20:
detect_result.append([int(x), int(y), int(width), is_left])
"""
return value: [[x, y, w, True if left hand else False]].
width=height since the network require squared input.
x, y is the coordinate of top left
"""
return detect_result
# get max index of 2d array
def npmax(array):
arrayindex = array.argmax(1)
arrayvalue = array.max(1)
i = arrayvalue.argmax()
j = arrayindex[i]
return (i,)

View File

@@ -0,0 +1,186 @@
import numpy as np
from PIL import Image
import torch
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
)
stencil = {}
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y
def controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt=1
):
channels = 3
if isinstance(controlnet_hint, torch.Tensor):
# torch.Tensor: acceptble shape are any of chw, bchw(b==1) or bchw(b==num_images_per_prompt)
shape_chw = (channels, height, width)
shape_bchw = (1, channels, height, width)
shape_nchw = (num_images_per_prompt, channels, height, width)
if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]:
controlnet_hint = controlnet_hint.to(
dtype=dtype, device=torch.device("cpu")
)
if controlnet_hint.shape != shape_nchw:
controlnet_hint = controlnet_hint.repeat(
num_images_per_prompt, 1, 1, 1
)
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
)
elif isinstance(controlnet_hint, np.ndarray):
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot)
# hwc is opencv compatible image format. Color channel must be BGR Format.
if controlnet_hint.shape == (height, width):
controlnet_hint = np.repeat(
controlnet_hint[:, :, np.newaxis], channels, axis=2
) # hw -> hwc(c==3)
shape_hwc = (height, width, channels)
shape_bhwc = (1, height, width, channels)
shape_nhwc = (num_images_per_prompt, height, width, channels)
if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
controlnet_hint = torch.from_numpy(controlnet_hint.copy())
controlnet_hint = controlnet_hint.to(
dtype=dtype, device=torch.device("cpu")
)
controlnet_hint /= 255.0
if controlnet_hint.shape != shape_nhwc:
controlnet_hint = controlnet_hint.repeat(
num_images_per_prompt, 1, 1, 1
)
controlnet_hint = controlnet_hint.permute(
0, 3, 1, 2
) # b h w c -> b c h w
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
+ f"({height}, {width}, {channels}), "
+ f"(1, {height}, {width}, {channels}) or "
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
)
elif isinstance(controlnet_hint, Image.Image):
if controlnet_hint.size == (width, height):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
controlnet_hint = np.array(controlnet_hint) # to numpy
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
return controlnet_hint_shaping(
controlnet_hint, height, width, num_images_per_prompt
)
else:
raise ValueError(
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
)
else:
raise ValueError(
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
)
def controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
):
controlnet_hint = None
match use_stencil:
case "canny":
print("Detecting edge with canny")
controlnet_hint = hint_canny(image)
case "openpose":
print("Detecting human pose")
controlnet_hint = hint_openpose(image)
case "scribble":
print("Working with scribble")
controlnet_hint = hint_scribble(image)
case _:
return None
controlnet_hint = controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt
)
return controlnet_hint
stencil_to_model_id_map = {
"canny": "lllyasviel/sd-controlnet-canny",
"depth": "lllyasviel/sd-controlnet-depth",
"hed": "lllyasviel/sd-controlnet-hed",
"mlsd": "lllyasviel/sd-controlnet-mlsd",
"normal": "lllyasviel/sd-controlnet-normal",
"openpose": "lllyasviel/sd-controlnet-openpose",
"scribble": "lllyasviel/sd-controlnet-scribble",
"seg": "lllyasviel/sd-controlnet-seg",
}
def get_stencil_model_id(use_stencil):
if use_stencil in stencil_to_model_id_map:
return stencil_to_model_id_map[use_stencil]
return None
# Stencil 1. Canny
def hint_canny(
image: Image.Image,
low_threshold=100,
high_threshold=200,
):
with torch.no_grad():
input_image = np.array(image)
if not "canny" in stencil:
stencil["canny"] = CannyDetector()
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
detected_map = HWC3(detected_map)
return detected_map
# Stencil 2. OpenPose.
def hint_openpose(
image: Image.Image,
):
with torch.no_grad():
input_image = np.array(image)
if not "openpose" in stencil:
stencil["openpose"] = OpenposeDetector()
detected_map, _ = stencil["openpose"](input_image)
detected_map = HWC3(detected_map)
return detected_map
# Stencil 3. Scribble.
def hint_scribble(image: Image.Image):
with torch.no_grad():
input_image = np.array(image)
detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
return detected_map

View File

@@ -0,0 +1,795 @@
import os
import gc
import json
import re
from PIL import PngImagePlugin
from PIL import Image
from datetime import datetime as dt
from csv import DictWriter
from pathlib import Path
import numpy as np
from random import randint
import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
def get_extended_name(model_name):
device = args.device.split("://", 1)[0]
extended_name = "{}_{}".format(model_name, device)
return extended_name
def get_vmfb_path_name(model_name):
vmfb_path = os.path.join(os.getcwd(), model_name + ".vmfb")
return vmfb_path
def _load_vmfb(shark_module, vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
vmfb_path = get_vmfb_path_name(model_name)
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), model_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.parser import shark_args
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
from shark.shark_downloader import download_model
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="tm_tensor"
)
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into a shark_module.
def compile_through_fx(
model,
inputs,
extended_model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
debug=False,
generate_vmfb=True,
extra_args=[],
base_model_id=None,
model_name=None,
precision=None,
return_mlir=False,
):
if not return_mlir and model_name is not None:
vmfb_path = get_vmfb_path_name(extended_model_name)
if os.path.isfile(vmfb_path):
shark_module = SharkInference(mlir_module=None, device=args.device)
return (
_load_vmfb(shark_module, vmfb_path, model_name, precision),
None,
)
from shark.parser import shark_args
if "cuda" in args.device:
shark_args.enable_tf32 = True
(
mlir_module,
func_name,
) = import_with_fx(
model=model,
inputs=inputs,
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
)
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
if "unet" in model_name.split("_")[0]:
args.annotation_model = "unet"
mlir_module = sd_model_annotation(
mlir_module, extended_model_name, base_model_id
)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="tm_tensor",
)
if generate_vmfb:
return (
_compile_module(shark_module, extended_model_name, extra_args),
mlir_module,
)
del mlir_module
gc.collect()
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def set_init_device_flags():
if "vulkan" in args.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "cpu" in args.device:
args.device = "cpu"
# set max_length based on availability.
if args.hf_model_id in [
"Linaqruf/anything-v3.0",
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]:
args.max_length = 77
elif args.hf_model_id == "prompthero/openjourney":
args.max_length = 64
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
if (
args.precision != "fp16"
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif base_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
]:
args.use_tuned = False
elif "vulkan" in args.device and not any(
x in args.iree_vulkan_target_triple for x in ["rdna2", "rdna3"]
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.use_tuned = False
elif (
args.height == 768
and args.width == 768
and (
base_model_id
not in [
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
]
or "rdna3" not in args.iree_vulkan_target_triple
)
):
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
else:
print("Tuned models are currently not supported for this setting.")
# set import_mlir to True for unuploaded models.
if args.ckpt_loc != "":
args.import_mlir = True
elif args.hf_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.import_mlir = True
elif args.height != 512 or args.width != 512 or args.batch_size != 1:
args.import_mlir = True
elif args.use_tuned and args.hf_model_id in [
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"stabilityai/stable-diffusion-2-1",
]:
args.import_mlir = True
elif (
args.use_tuned
and "vulkan" in args.device
and "rdna2" in args.iree_vulkan_target_triple
):
args.import_mlir = True
elif (
args.use_tuned
and "cuda" in args.device
and get_cuda_sm_cc() == "sm_89"
):
args.import_mlir = True
# Utility to get list of devices available.
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{device['name']} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("device => cpu")
return available_devices
def disk_space_check(path, lim=20):
from shutil import disk_usage
du = disk_usage(path)
free = du.free / (1024 * 1024 * 1024)
if free <= lim:
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
def get_opt_flags(model, precision="fp16"):
iree_flags = []
is_tuned = "tuned" if args.use_tuned else "untuned"
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
"default_compilation_flags"
]
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
if (
device
not in opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
]
):
device = "default_device"
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
return iree_flags
def get_path_stem(path):
path = Path(path)
return path.stem
def get_path_to_diffusers_checkpoint(custom_weights):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers
def preprocessCKPT(custom_weights, is_inpaint=False):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
return
else:
print(
"Diffusers' checkpoint will be identified here : ",
path_to_diffusers,
)
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
extract_ema = False
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
)
pipe.save_pretrained(path_to_diffusers)
print("Loading complete")
def processLoRA(model, use_lora, splitting_prefix):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
state_dict = torch.load(use_lora)
alpha = 0.75
visited = []
# directly update weight in model
process_unet = "te" not in splitting_prefix
for key in state_dict:
if ".alpha" in key or key in visited:
continue
curr_layer = model
if ("text" not in key and process_unet) or (
"text" in key and not process_unet
):
layer_infos = (
key.split(".")[0].split(splitting_prefix)[-1].split("_")
)
else:
continue
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = (
state_dict[pair_keys[0]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
weight_down = (
state_dict[pair_keys[1]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
curr_layer.weight.data += alpha * torch.mm(
weight_up, weight_down
).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
return model
def update_lora_weight_for_unet(unet, use_lora):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume if it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
return unet
main_file_name = get_path_stem(use_lora)
if ".bin" in use_lora:
main_file_name += ".bin"
elif ".safetensors" in use_lora:
main_file_name += ".safetensors"
elif ".pt" in use_lora:
main_file_name += ".pt"
else:
sys.exit("Only .bin and .safetensors format for LoRA is supported")
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_")
def update_lora_weight(model, use_lora, model_name):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora)
try:
return processLoRA(model, use_lora, "lora_te_")
except:
return None
# `fetch_and_update_base_model_id` is a resource utility function which
# helps maintaining mapping of the model to run with its base model.
# If `base_model` is "", then this function tries to fetch the base model
# info for the `model_to_run`.
def fetch_and_update_base_model_id(model_to_run, base_model=""):
variants_path = os.path.join(os.getcwd(), "variants.json")
data = {model_to_run: base_model}
json_data = {}
if os.path.exists(variants_path):
with open(variants_path, "r", encoding="utf-8") as jsonFile:
json_data = json.load(jsonFile)
# Return with base_model's info if base_model is "".
if base_model == "":
if model_to_run in json_data:
base_model = json_data[model_to_run]
return base_model
elif base_model == "":
return base_model
# Update JSON data to contain an entry mapping model_to_run with base_model.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
json.dump(json_data, jsonFile)
# Generate and return a new seed if the provided one is not in the supported range (including -1)
def sanitize_seed(seed):
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# clear all the cached objects to recompile cleanly.
def clear_all():
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
if os.path.exists(yaml):
os.remove(yaml)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(
os.path.join(home, ".local/shark_tank"), ignore_errors=True
)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info={}):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
)
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = Path(os.path.basename(args.ckpt_loc)).stem
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
)
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SEED": img_seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
}
new_entry.update(extra_info)
with open(csv_path, "a", encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
dictwriter_obj.writerow(new_entry)
csv_obj.close()
if args.save_metadata_to_json:
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
return text_output
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
width, height = image.size
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 128:
n_size = 128
if width == min_size:
width = n_size
height = n_size / aspect_ratio
else:
height = n_size
width = n_size * aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
min_size = min(width, height)
if min_size > 768:
n_size = 768
if width == min_size:
height = n_size
width = n_size * aspect_ratio
else:
width = n_size
height = n_size / aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
new_image = image.resize((n_width, n_height))
return new_image, n_width, n_height

View File

@@ -0,0 +1,15 @@
You need to pre-create your bot (https://core.telegram.org/bots#how-do-i-create-a-bot)
Then create in the directory web file .env
In it the record:
TG_TOKEN="your_token"
specifying your bot's token from previous step.
Then run telegram_bot.py with the same parameters that you use when running index.py, for example:
python telegram_bot.py --max_length=77 --vulkan_large_heap_block_size=0 --use_base_vae --local_tank_cache h:\shark\TEMP
Bot commands:
/select_model
/select_scheduler
/set_steps "integer number of steps"
/set_guidance_scale "integer number"
/set_negative_prompt "negative text"
Any other text triggers the creation of an image based on it.

View File

@@ -0,0 +1,219 @@
import os
import sys
import transformers
from apps.stable_diffusion.src import args, clear_all
import apps.stable_diffusion.web.utils.global_obj as global_obj
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
if args.clear_all:
clear_all()
if __name__ == "__main__":
if args.api:
from apps.stable_diffusion.web.ui import txt2img_inf, img2img_api
from fastapi import FastAPI, APIRouter
import uvicorn
# init global sd pipeline and config
global_obj._init()
app = FastAPI()
app.add_api_route("/sdapi/v1/txt2img", txt2img_inf, methods=["post"])
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
app.include_router(APIRouter())
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
sys.exit(0)
import gradio as gr
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
)
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create the custom model folder if it doesn't already exist
dir = ["models", "vae", "lora"]
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
img2img_web,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
inpaint_web,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
outpaint_web,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
upscaler_web,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
)
# init global sd pipeline and config
global_obj._init()
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
img2img_web.render()
with gr.TabItem(label="Inpainting", id=2):
inpaint_web.render()
with gr.TabItem(label="Outpainting", id=3):
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.Tabs(visible=False) as experimental_tabs:
with gr.TabItem(label="LoRA Training", id=5):
lora_train_web.render()
register_button_click(
txt2img_sendto_img2img,
1,
[txt2img_gallery],
[img2img_init_image, tabs],
)
register_button_click(
txt2img_sendto_inpaint,
2,
[txt2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_outpaint,
3,
[txt2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_upscaler,
4,
[txt2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
img2img_sendto_inpaint,
2,
[img2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_outpaint,
3,
[img2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_upscaler,
4,
[img2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
inpaint_sendto_img2img,
1,
[inpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
inpaint_sendto_outpaint,
3,
[inpaint_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
inpaint_sendto_upscaler,
4,
[inpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
outpaint_sendto_img2img,
1,
[outpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
outpaint_sendto_inpaint,
2,
[outpaint_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
outpaint_sendto_upscaler,
4,
[outpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
upscaler_sendto_img2img,
1,
[upscaler_gallery],
[img2img_init_image, tabs],
)
register_button_click(
upscaler_sendto_inpaint,
2,
[upscaler_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
upscaler_sendto_outpaint,
3,
[upscaler_gallery],
[outpaint_init_image, tabs],
)
sd_web.queue()
sd_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

Some files were not shown because too many files have changed in this diff Show More