Compare commits

..

3 Commits

184 changed files with 14871 additions and 6679 deletions

View File

@@ -2,4 +2,4 @@
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
exclude = lit.cfg.py

View File

@@ -50,13 +50,27 @@ jobs:
shell: powershell
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
python process_skipfiles.py
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
# GHA windows VM OOMs so disable for now
#- name: Build and validate the SHARK Runtime package
# shell: powershell
# run: |
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
#- uses: actions/upload-artifact@v2
# with:
# path: dist/*
- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1
@@ -64,7 +78,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./dist/nodai*
assets_path: ./dist/*
#asset_content_type: application/vnd.microsoft.portable-executable
- name: Publish Release
@@ -104,7 +118,7 @@ jobs:
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -144,7 +158,7 @@ jobs:
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models

161
.github/workflows/test-models.yml vendored Normal file
View File

@@ -0,0 +1,161 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Validate Models on Shark Runtime
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
suite: [cpu,cuda,vulkan]
python-version: ["3.11"]
include:
- os: ubuntu-latest
suite: lint
exclude:
- os: ubuntu-latest
suite: vulkan
- os: ubuntu-latest
suite: cuda
- os: ubuntu-latest
suite: cpu
- os: MacStudio
suite: cuda
- os: MacStudio
suite: cpu
- os: icelake
suite: vulkan
- os: icelake
suite: cuda
- os: a100
suite: cpu
- os: 7950x
suite: cpu
- os: 7950x
suite: cuda
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'
- name: Set Environment Variables
if: matrix.os != '7950x'
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
run: |
# See https://github.com/actions/setup-python/issues/433
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
#cache: 'pip'
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- uses: actions/checkout@v2
if: matrix.os == '7950x'
- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check .
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
- name: Validate Models on NVIDIA GPU
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan --update_tank
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest -k vulkan -s
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python build_tools/stable_diffusion_testing.py --device=vulkan

View File

@@ -1,86 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Validate Shark Studio
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [nodai-ubuntu-builder-large]
suite: [cpu] #,cuda,vulkan]
python-version: ["3.11"]
include:
- os: nodai-ubuntu-builder-large
suite: lint
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set Environment Variables
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
run: |
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check apps/shark_studio
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
python${{ matrix.python-version }} -m venv shark.venv
source shark.venv/bin/activate
pip install -r requirements.txt --no-cache-dir
pip install -e .
pip uninstall -y torch
pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python apps/shark_studio/tests/api_test.py

19
.gitignore vendored
View File

@@ -2,8 +2,6 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb
# C extensions
*.so
@@ -159,7 +157,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
#.idea/
# vscode related
.vscode
@@ -182,23 +180,10 @@ generated_imgs/
# Custom model related artefacts
variants.json
/models/
models/
# models folder
apps/stable_diffusion/web/models/
# Stencil annotators.
stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData
# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/
# Webview2 runtime artefacts
EBWebView/

2
.gitmodules vendored
View File

@@ -1,4 +1,4 @@
[submodule "inference/thirdparty/shark-runtime"]
path = inference/thirdparty/shark-runtime
url =https://github.com/nod-ai/SRT.git
url =https://github.com/nod-ai/SHARK-Runtime.git
branch = shark-06032022

View File

@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
@@ -114,12 +114,12 @@ source shark.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
```shell
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
python3.11 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
@@ -170,7 +170,7 @@ python -m pip install --upgrade pip
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
### Run shark tank model tests.
@@ -254,6 +254,7 @@ if you want to instead incorporate this into a python script, you can pass the `
```
shark_module = SharkInference(
mlir_model,
func_name,
device=args.device,
mlir_dialect="tm_tensor",
dispatch_benchmarks="all",
@@ -296,7 +297,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
from shark.shark_inference import SharkInference
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input))
@@ -319,17 +320,12 @@ mhlo_ir = r"""builtin.module {
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
shark_module.compile()
result = shark_module.forward((arg0, arg1))
```
</details>
## Examples Using the REST API
* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md)
* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md)
## Supported and Validated Models
SHARK is maintained to support the latest innovations in ML Models:

View File

@@ -1,179 +0,0 @@
from turbine_models.custom_models import stateless_llama
import time
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
)
from apps.shark_studio.api.utils import get_resource_path
import iree.runtime as ireert
from itertools import chain
import gc
import os
import torch
from transformers import AutoTokenizer
llm_model_map = {
"llama2_7b": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
}
class LanguageModel:
def __init__(
self,
model_name,
hf_auth_token=None,
device=None,
precision="fp32",
external_weights=None,
use_system_prompt=True,
):
print(llm_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.tempfile_name = get_resource_path("llm.torch.tempfile")
self.vmfb_name = get_resource_path("llm.vmfb.tempfile")
self.device = device
self.precision = precision
self.safe_name = self.hf_model_name.strip("/").replace("/", "_")
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.external_weight_file = None
if external_weights is not None:
self.external_weight_file = get_resource_path(
self.safe_name + "." + external_weights
)
self.use_system_prompt = use_system_prompt
self.global_iter = 0
if os.path.exists(self.vmfb_name) and (
external_weights is None or os.path.exists(str(self.external_weight_file))
):
self.iree_module_dict = dict()
(
self.iree_module_dict["vmfb"],
self.iree_module_dict["config"],
self.iree_module_dict["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
self.vmfb_name,
device,
device_idx=0,
rt_flags=[],
external_weight_file=self.external_weight_file,
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
external_weights=external_weights,
external_weight_file=self.external_weight_file,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.compile()
else:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
self.compile()
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(
self.tempfile_name,
device=self.device,
mmap=True,
frontend="torch",
external_weight_file=self.external_weight_file,
write_to=self.vmfb_name,
extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"],
)
# TODO: delete the temp file
def sanitize_prompt(self, prompt):
print(prompt)
if isinstance(prompt, list):
prompt = list(chain.from_iterable(prompt))
prompt = " ".join([x for x in prompt if isinstance(x, str)])
prompt = prompt.replace("\n", " ")
prompt = prompt.replace("\t", " ")
prompt = prompt.replace("\r", " ")
if self.use_system_prompt and self.global_iter == 0:
prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt
prompt += " [/INST]"
print(prompt)
return prompt
def chat(self, prompt):
prompt = self.sanitize_prompt(prompt)
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
def format_out(results):
return torch.tensor(results.to_host()[0][0])
history = []
for iter in range(self.max_tokens):
st_time = time.time()
if iter == 0:
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device, input_tensor
)
]
token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
else:
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device,
token,
)
]
token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs)
total_time = time.time() - st_time
history.append(format_out(token))
yield self.tokenizer.decode(history), total_time
if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
self.global_iter += 1
return result_output, total_time
if __name__ == "__main__":
lm = LanguageModel(
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
)
print("model loaded")
for i in lm.chat("hi, what are you?"):
print(i)

View File

@@ -1,12 +0,0 @@
import os
import sys
def get_available_devices():
return ["cpu-task"]
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)

View File

@@ -1,34 +0,0 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import logging
import unittest
from apps.shark_studio.api.llm import LanguageModel
class LLMAPITest(unittest.TestCase):
def testLLMSimple(self):
lm = LanguageModel(
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
)
count = 0
for msg, _ in lm.chat("hi, what are you?"):
# skip first token output
if count == 0:
count += 1
continue
assert (
msg.strip(" ") == "Hello"
), f"LLM API failed to return correct response, expected 'Hello', received {msg}"
break
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()

View File

@@ -1,426 +0,0 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging
from ui.chat import chat_element
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
# import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
# from apps.stable_diffusion.src import args, clear_all
# import apps.stable_diffusion.web.utils.global_obj as global_obj
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
if __name__ == "__main__":
# if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.stable_diffusion.web.ui import (
# txt2img_api,
# img2img_api,
# upscaler_api,
# inpaint_api,
# outpaint_api,
# llm_chat_api,
# )
#
# from fastapi import FastAPI, APIRouter
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
#
# app = FastAPI()
# app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# app.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/completions", llm_chat_api, methods=["post"])
# app.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# app.include_router(APIRouter())
# uvicorn.run(app, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
#
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
# from apps.stable_diffusion.web.utils.gradio_configs import (
# config_gradio_tmp_imgs_folder,
# )
# config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
# from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# create_custom_models_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
# from apps.stable_diffusion.web.ui import (
# txt2img_web,
# txt2img_custom_model,
# txt2img_gallery,
# txt2img_png_info_img,
# txt2img_status,
# txt2img_sendto_img2img,
# txt2img_sendto_inpaint,
# txt2img_sendto_outpaint,
# txt2img_sendto_upscaler,
## h2ogpt_upload,
## h2ogpt_web,
# img2img_web,
# img2img_custom_model,
# img2img_gallery,
# img2img_init_image,
# img2img_status,
# img2img_sendto_inpaint,
# img2img_sendto_outpaint,
# img2img_sendto_upscaler,
# inpaint_web,
# inpaint_custom_model,
# inpaint_gallery,
# inpaint_init_image,
# inpaint_status,
# inpaint_sendto_img2img,
# inpaint_sendto_outpaint,
# inpaint_sendto_upscaler,
# outpaint_web,
# outpaint_custom_model,
# outpaint_gallery,
# outpaint_init_image,
# outpaint_status,
# outpaint_sendto_img2img,
# outpaint_sendto_inpaint,
# outpaint_sendto_upscaler,
# upscaler_web,
# upscaler_custom_model,
# upscaler_gallery,
# upscaler_init_image,
# upscaler_status,
# upscaler_sendto_img2img,
# upscaler_sendto_inpaint,
# upscaler_sendto_outpaint,
## lora_train_web,
## model_web,
## model_config_web,
# hf_models,
# modelmanager_sendto_txt2img,
# modelmanager_sendto_img2img,
# modelmanager_sendto_inpaint,
# modelmanager_sendto_outpaint,
# modelmanager_sendto_upscaler,
# stablelm_chat,
# minigpt4_web,
# outputgallery_web,
# outputgallery_tab_select,
# outputgallery_watch,
# outputgallery_filename,
# outputgallery_sendto_txt2img,
# outputgallery_sendto_img2img,
# outputgallery_sendto_inpaint,
# outputgallery_sendto_outpaint,
# outputgallery_sendto_upscaler,
# )
# init global sd pipeline and config
# global_obj._init()
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
# with gr.TabItem(label="Text-to-Image", id=0):
# txt2img_web.render()
# with gr.TabItem(label="Image-to-Image", id=1):
# img2img_web.render()
# with gr.TabItem(label="Inpainting", id=2):
# inpaint_web.render()
# with gr.TabItem(label="Outpainting", id=3):
# outpaint_web.render()
# with gr.TabItem(label="Upscaler", id=4):
# upscaler_web.render()
# if args.output_gallery:
# with gr.TabItem(label="Output Gallery", id=5) as og_tab:
# outputgallery_web.render()
# # extra output gallery configuration
# outputgallery_tab_select(og_tab.select)
# outputgallery_watch(
# [
# txt2img_status,
# img2img_status,
# inpaint_status,
# outpaint_status,
# upscaler_status,
# ]
# )
## with gr.TabItem(label="Model Manager", id=6):
## model_web.render()
## with gr.TabItem(label="LoRA Training (Experimental)", id=7):
## lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=0):
chat_element.render()
## with gr.TabItem(
## label="Generate Sharding Config (Experimental)", id=9
## ):
## model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
# send to buttons
# register_button_click(
# txt2img_sendto_img2img,
# 1,
# [txt2img_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_inpaint,
# 2,
# [txt2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_outpaint,
# 3,
# [txt2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_upscaler,
# 4,
# [txt2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_inpaint,
# 2,
# [img2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_outpaint,
# 3,
# [img2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_upscaler,
# 4,
# [img2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_img2img,
# 1,
# [inpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_outpaint,
# 3,
# [inpaint_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_upscaler,
# 4,
# [inpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_img2img,
# 1,
# [outpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_inpaint,
# 2,
# [outpaint_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_upscaler,
# 4,
# [outpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_img2img,
# 1,
# [upscaler_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_inpaint,
# 2,
# [upscaler_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_outpaint,
# 3,
# [upscaler_gallery],
# [outpaint_init_image, tabs],
# )
# if args.output_gallery:
# register_outputgallery_button(
# outputgallery_sendto_txt2img,
# 0,
# [outputgallery_filename],
# [txt2img_png_info_img, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_img2img,
# 1,
# [outputgallery_filename],
# [img2img_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_inpaint,
# 2,
# [outputgallery_filename],
# [inpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_outpaint,
# 3,
# [outputgallery_filename],
# [outpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_upscaler,
# 4,
# [outputgallery_filename],
# [upscaler_init_image, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_txt2img,
# 0,
# [hf_models],
# [txt2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_img2img,
# 1,
# [hf_models],
# [img2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_inpaint,
# 2,
# [hf_models],
# [inpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_outpaint,
# 3,
# [hf_models],
# [outpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_upscaler,
# 4,
# [hf_models],
# [upscaler_custom_model, tabs],
# )
sd_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
sd_web.launch(
share=True,
inbrowser=True,
server_name="0.0.0.0",
server_port=11911, # args.server_port,
)

View File

@@ -1,298 +0,0 @@
import gradio as gr
import time
import os
from pathlib import Path
from datetime import datetime as dt
import json
import sys
from apps.shark_studio.api.utils import (
get_available_devices,
)
from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
language_model = None
def create_prompt(model_name, history, prompt_prefix):
return ""
def get_default_config():
return False
# model_vmfb_key = ""
def chat_fn(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
):
global language_model
if language_model is None:
history[-1][-1] = "Getting the model ready..."
yield history, ""
language_model = LanguageModel(
model,
device=device,
precision=precision,
external_weights="safetensors",
external_weight_file="llama2_7b.safetensors",
use_system_prompt=prompt_prefix,
)
history[-1][-1] = "Getting the model ready... Done"
yield history, ""
history[-1][-1] = ""
token_count = 0
total_time = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, exec_time in language_model.chat(history):
history[-1][-1] = text
if is_first:
prefill_time = exec_time
is_first = False
yield history, f"Prefill: {prefill_time:.2f}"
else:
total_time += exec_time
token_count += 1
tokens_per_sec = token_count / total_time
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
def llm_chat_api(InputData: dict):
return None
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
# if is_chat_completion_api:
# print(f"message -> role : {InputData['messages'][0]['role']}")
# print(f"message -> content : {InputData['messages'][0]['content']}")
# else:
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = InputData["model"] if "model" in InputData.keys() else "codegen"
model_path = llm_model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"]
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512
# make it working for codegen first
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add funtionality for multiple messages
prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")])
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
res = vicuna_model.generate(prompt)
res_op = None
for op in res:
res_op = op
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion" if is_chat_completion_api else "text_completion",
"created": int(end_time),
"choices": choices,
}
def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
content = fopen.read()
return content
with gr.Blocks(title="Chat") as chat_element:
with gr.Row():
model_choices = list(llm_model_map.keys())
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = get_available_devices()
enabled = True
if len(supported_devices) == 0:
supported_devices = ["cpu-task"]
supported_devices = [x for x in supported_devices if "sync" not in x]
device = gr.Dropdown(
label="Device",
value=supported_devices[0],
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
value="int4",
choices=[
# "int4",
# "int8",
# "fp16",
"fp32",
],
visible=False,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
interactive=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(label="Upload sharding configuration", visible=False)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)

View File

@@ -0,0 +1,87 @@
Compile / Run Instructions:
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=mhlo for tf models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
```
Run / Benchmark Command (FP32 - NCHW):
(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.)
```shell
## Vulkan AMD:
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CUDA:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CPU:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
Run via vulkan_gui for RGP Profiling:
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
```shell
./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
</details>
<details>
<summary>Debug Commands</summary>
## Debug commands and other advanced usage follows.
```shell
python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
```
## dump all dispatch .spv and isa using amdllpc
```shell
python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
```
## Compile and save the .vmfb (using vulkan fp16 as an example):
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
```
## Capture an RGP trace
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
```
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
```shell
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
```
## Run the unet module with iree-benchmark-module (same config as above):
```shell
##if you want to use .npz inputs:
unzip ~/.local/shark_tank/<your unet>/inputs.npz
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy
```
</details>

View File

@@ -0,0 +1,6 @@
from apps.stable_diffusion.scripts.txt2img import txt2img_inf
from apps.stable_diffusion.scripts.img2img import img2img_inf
from apps.stable_diffusion.scripts.inpaint import inpaint_inf
from apps.stable_diffusion.scripts.outpaint import outpaint_inf
from apps.stable_diffusion.scripts.upscaler import upscaler_inf
from apps.stable_diffusion.scripts.train_lora_word import lora_train

View File

@@ -0,0 +1,382 @@
import sys
import torch
import time
from PIL import Image
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
StencilPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# Both width and height should be > 384 and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
width, height = image.size
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 384:
n_size = 384
if width == min_size:
width = n_size
height = n_size / aspect_ratio
else:
height = n_size
width = n_size * aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
new_image = image.resize((n_width, n_height))
return new_image, n_width, n_height
# Exposed to UI.
def img2img_inf(
prompt: str,
negative_prompt: str,
init_image,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.strength = strength
args.scheduler = scheduler
args.img_path = "not none"
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB")
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
)
args.scheduler = "PNDM"
else:
sys.exit(
"Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
)
cpu_scheduling = not args.scheduler.startswith("Shark")
args.precision = precision
dtype = torch.float32 if precision == "fp32" else torch.half
new_config_obj = Config(
"img2img",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=use_stencil,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
if use_stencil is not None:
args.use_tuned = False
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
)
else:
global_obj.set_sd_obj(
Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"STRENGTH": strength}
text_output = ""
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
batch_size,
height,
width,
steps,
strength,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
image = Image.open(args.img_path).convert("RGB")
# When the models get uploaded, it should be default to False.
args.import_mlir = True
use_stencil = args.use_stencil
if use_stencil:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, args.width, args.height = resize_stencil(image)
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
)
args.scheduler = "PNDM"
else:
sys.exit(
"Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
)
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model
if use_stencil:
img2img_obj = StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
)
start_time = time.time()
generated_imgs = img2img_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.batch_size,
args.height,
args.width,
args.steps,
args.strength,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, strength={args.strength}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += img2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
extra_info = {"STRENGTH": args.strength}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)

View File

@@ -0,0 +1,287 @@
import torch
import time
from PIL import Image
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def inpaint_inf(
prompt: str,
negative_prompt: str,
image_dict,
height: int,
width: int,
inpaint_full_res: bool,
inpaint_full_res_padding: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.mask_path = "not none"
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"inpaint",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
if args.mask_path is None:
print("Flag --mask_path is required.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
model_id = (
args.hf_model_id
if "inpaint" in args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
image = Image.open(args.img_path)
mask_image = Image.open(args.mask_path)
inpaint_obj = InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = inpaint_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
mask_image,
args.batch_size,
args.height,
args.width,
args.inpaint_full_res,
args.inpaint_full_res_padding,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += inpaint_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)

View File

@@ -0,0 +1,312 @@
import torch
import time
from PIL import Image
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def outpaint_inf(
prompt: str,
negative_prompt: str,
init_image,
pixels: int,
mask_blur: int,
directions: list,
noise_q: float,
color_variation: float,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"outpaint",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
left = True if "left" in directions else False
right = True if "right" in directions else False
top = True if "up" in directions else False
bottom = True if "down" in directions else False
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
init_image,
pixels,
mask_blur,
left,
right,
top,
bottom,
noise_q,
color_variation,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
model_id = (
args.hf_model_id
if "inpaint" in args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
image = Image.open(args.img_path)
outpaint_obj = OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = outpaint_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.pixels,
args.mask_blur,
args.left,
args.right,
args.top,
args.bottom,
args.noise_q,
args.color_variation,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += outpaint_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
# save this information as metadata of output generated image.
directions = []
if args.left:
directions.append("left")
if args.right:
directions.append("right")
if args.top:
directions.append("up")
if args.bottom:
directions.append("down")
extra_info = {
"PIXELS": args.pixels,
"MASK_BLUR": args.mask_blur,
"DIRECTIONS": directions,
"NOISE_Q": args.noise_q,
"COLOR_VARIATION": args.color_variation,
}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)

View File

@@ -0,0 +1,240 @@
import logging
import os
from models.stable_diffusion.main import stable_diff_inf
from models.stable_diffusion.utils import get_available_devices
from dotenv import load_dotenv
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram import BotCommand
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
from io import BytesIO
import random
log = logging.getLogger("TG.Bot")
logging.basicConfig()
log.warning("Start")
load_dotenv()
os.environ["AMD_ENABLE_LLPC"] = "0"
TG_TOKEN = os.getenv("TG_TOKEN")
SELECTED_MODEL = "stablediffusion"
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
STEPS = 30
NEGATIVE_PROMPT = (
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
" legs, fused fingers, too many fingers"
)
GUIDANCE_SCALE = 6
available_devices = get_available_devices()
models_list = [
"stablediffusion",
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]
sheds_list = [
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
]
def image_to_bytes(image):
bio = BytesIO()
bio.name = "image.jpeg"
image.save(bio, "JPEG")
bio.seek(0)
return bio
def get_try_again_markup():
keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]]
reply_markup = InlineKeyboardMarkup(keyboard)
return reply_markup
def generate_image(prompt):
seed = random.randint(1, 10000)
log.warning(SELECTED_MODEL)
log.warning(STEPS)
image, text = stable_diff_inf(
prompt=prompt,
negative_prompt=NEGATIVE_PROMPT,
steps=STEPS,
guidance_scale=GUIDANCE_SCALE,
seed=seed,
scheduler_key=SELECTED_SCHEDULER,
variant=SELECTED_MODEL,
device_key=available_devices[0],
)
return image, seed
async def generate_and_send_photo(
update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
progress_msg = await update.message.reply_text(
"Generating image...", reply_to_message_id=update.message.message_id
)
im, seed = generate_image(prompt=update.message.text)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{update.message.text}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=update.message.message_id,
)
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
query = update.callback_query
if query.data in models_list:
global SELECTED_MODEL
SELECTED_MODEL = query.data
await query.answer()
await query.edit_message_text(text=f"Selected model: {query.data}")
return
if query.data in sheds_list:
global SELECTED_SCHEDULER
SELECTED_SCHEDULER = query.data
await query.answer()
await query.edit_message_text(text=f"Selected scheduler: {query.data}")
return
replied_message = query.message.reply_to_message
await query.answer()
progress_msg = await query.message.reply_text(
"Generating image...", reply_to_message_id=replied_message.message_id
)
if query.data == "TRYAGAIN":
prompt = replied_message.text
im, seed = generate_image(prompt)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{prompt}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=replied_message.message_id,
)
async def select_model_handler(update, context):
text = "Select model"
keyboard = []
for model in models_list:
keyboard.append(
[
InlineKeyboardButton(text=model, callback_data=model),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def select_scheduler_handler(update, context):
text = "Select schedule"
keyboard = []
for shed in sheds_list:
keyboard.append(
[
InlineKeyboardButton(text=shed, callback_data=shed),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def set_steps_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_steps ")[1]
global STEPS
STEPS = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_steps 30"
)
await update.message.reply_text(input_args)
async def set_negative_prompt_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_negative_prompt ")[1]
global NEGATIVE_PROMPT
NEGATIVE_PROMPT = input_args
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_negative_prompt ugly, bad art, mutated"
)
await update.message.reply_text(input_args)
async def set_guidance_scale_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_guidance_scale ")[1]
global GUIDANCE_SCALE
GUIDANCE_SCALE = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_guidance_scale 7"
)
await update.message.reply_text(input_args)
async def setup_bot_commands(application: Application) -> None:
await application.bot.set_my_commands(
[
BotCommand("select_model", "to select model"),
BotCommand("select_scheduler", "to select scheduler"),
BotCommand("set_steps", "to set steps"),
BotCommand("set_guidance_scale", "to set guidance scale"),
BotCommand("set_negative_prompt", "to set negative prompt"),
]
)
app = (
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
)
app.add_handler(CommandHandler("select_model", select_model_handler))
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
app.add_handler(CommandHandler("set_steps", set_steps_handler))
app.add_handler(
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
)
app.add_handler(
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
)
app.add_handler(
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
)
app.add_handler(CallbackQueryHandler(button))
log.warning("Start bot")
app.run_polling()

View File

@@ -0,0 +1,674 @@
# Install the required libs
# pip install -U git+https://github.com/huggingface/diffusers.git
# pip install accelerate transformers ftfy
# HuggingFace Token
# YOUR_TOKEN = "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"
# Import required libraries
import itertools
import math
import os
from typing import List
import random
import torch_mlir
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
import logging
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from shark.shark_inference import SharkInference
torch._dynamo.config.verbose = True
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
)
from io import BytesIO
from dataclasses import dataclass
from apps.stable_diffusion.src import (
args,
get_schedulers,
set_init_device_flags,
clear_all,
)
# Setup the dataset
class LoraDataset(Dataset):
def __init__(
self,
data_root,
tokenizer,
size=512,
repeats=100,
interpolation="bicubic",
set="train",
prompt="myloraprompt",
center_crop=False,
):
self.data_root = data_root
self.tokenizer = tokenizer
self.size = size
self.center_crop = center_crop
self.prompt = prompt
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
self.num_images = len(self.image_paths)
self._length = self.num_images
if set == "train":
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
example["input_ids"] = self.tokenizer(
self.prompt,
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids[0]
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = np.array(image).astype(np.uint8)
image = (image / 127.5 - 1.0).astype(np.float32)
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
return example
schedulers = None
########## Setting up the model ##########
def lora_train(
prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
training_images_dir: str,
lora_save_dir: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
print(
"Note LoRA training is not compatible with the latest torch-mlir branch"
)
print(
"To run LoRA training you'll need this to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
)
torch.manual_seed(seed)
args.prompts = [prompt]
args.steps = steps
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = custom_model
else:
args.hf_model_id = custom_model
args.training_images_dir = training_images_dir
args.lora_save_dir = lora_save_dir
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(
args.hf_model_id, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(
args.hf_model_id, subfolder="unet"
)
def freeze_params(params):
for param in params:
param.requires_grad = False
# Freeze everything but LoRA
freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
# Move vae and unet to device
vae.to(args.device)
unet.to(args.device)
text_encoder.to(args.device)
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[
block_id
]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = vae
def forward(self, input):
x = self.vae.encode(input, return_dict=False)[0]
return x
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = unet
def forward(self, x, y, z):
return self.unet.forward(x, y, z, return_dict=False)[0]
shark_vae = VaeModel()
shark_unet = UnetModel()
####### Creating our training data ########
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id,
subfolder="tokenizer",
)
# Let's create the Dataset and Dataloader
train_dataset = LoraDataset(
data_root=args.training_images_dir,
tokenizer=tokenizer,
size=vae.sample_size,
prompt=args.prompts[0],
repeats=100,
center_crop=False,
set="train",
)
def create_dataloader(train_batch_size=1):
return torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, shuffle=True
)
# Create noise_scheduler for training
noise_scheduler = DDPMScheduler.from_config(
args.hf_model_id, subfolder="scheduler"
)
######## Training ###########
# Define hyperparameters for our training. If you are not happy with your results,
# you can tune the `learning_rate` and the `max_train_steps`
# Setting up all training args
hyperparameters = {
"learning_rate": 5e-04,
"scale_lr": True,
"max_train_steps": steps,
"train_batch_size": batch_size,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"mixed_precision": "fp16",
"seed": 42,
"output_dir": "sd-concept-output",
}
# creating output directory
cwd = os.getcwd()
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
while not os.path.exists(str(out_dir)):
try:
os.mkdir(out_dir)
except OSError as error:
print("Output directory not created")
###### Torch-MLIR Compilation ######
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
@make_simple_dynamo_backend
def refbackend_torchdynamo_backend(
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
):
# handling usage of empty tensor without initializing
transform_fx(fx_graph)
fx_graph.recompile()
if _returns_nothing(fx_graph):
return fx_graph
removed_none_indexes = _remove_nones(fx_graph)
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
mlir_module = torch_mlir.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors"
)
bytecode_stream = BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
def compiled_callable(*inputs):
inputs = [x.numpy() for x in inputs]
result = shark_module("forward", inputs)
if was_unwrapped:
result = [
result,
]
if not isinstance(result, list):
result = torch.from_numpy(result)
else:
result = tuple(torch.from_numpy(x) for x in result)
result = list(result)
for removed_index in removed_none_indexes:
result.insert(removed_index, None)
result = tuple(result)
return result
return compiled_callable
def predictions(torch_func, jit_func, batchA, batchB):
res = jit_func(batchA.numpy(), batchB.numpy())
if res is not None:
# prediction = torch.from_numpy(res)
prediction = res
else:
prediction = None
return prediction
logger = logging.getLogger(__name__)
train_batch_size = hyperparameters["train_batch_size"]
gradient_accumulation_steps = hyperparameters[
"gradient_accumulation_steps"
]
learning_rate = hyperparameters["learning_rate"]
if hyperparameters["scale_lr"]:
learning_rate = (
learning_rate
* gradient_accumulation_steps
* train_batch_size
# * accelerator.num_processes
)
# Initialize the optimizer
optimizer = torch.optim.AdamW(
lora_layers.parameters(), # only optimize the embeddings
lr=learning_rate,
)
# Training function
def train_func(batch_pixel_values, batch_input_ids):
# Convert images to latent space
latents = shark_vae(batch_pixel_values).sample().detach()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.num_train_timesteps,
(bsz,),
device=latents.device,
).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch_input_ids)[0]
# Predict the noise residual
noise_pred = shark_unet(
noisy_latents,
timesteps,
encoder_hidden_states,
)
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
)
loss = (
F.mse_loss(noise_pred, target, reduction="none")
.mean([1, 2, 3])
.mean()
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return loss
def training_function():
max_train_steps = hyperparameters["max_train_steps"]
output_dir = hyperparameters["output_dir"]
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
train_dataloader = create_dataloader(train_batch_size)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
num_train_epochs = math.ceil(
max_train_steps / num_update_steps_per_epoch
)
# Train!
total_batch_size = (
train_batch_size
* gradient_accumulation_steps
# train_batch_size * accelerator.num_processes * gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(
f" Instantaneous batch size per device = {train_batch_size}"
)
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(
# range(max_train_steps), disable=not accelerator.is_local_main_process
range(max_train_steps)
)
progress_bar.set_description("Steps")
global_step = 0
params__ = [
i for i in text_encoder.get_input_embeddings().parameters()
]
for epoch in range(num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
dynamo_callable = dynamo.optimize(
refbackend_torchdynamo_backend
)(train_func)
lam_func = lambda x, y: dynamo_callable(
torch.from_numpy(x), torch.from_numpy(y)
)
loss = predictions(
train_func,
lam_func,
batch["pixel_values"],
batch["input_ids"],
)
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
if global_step >= max_train_steps:
break
training_function()
# Save the lora weights
unet.save_attn_procs(args.lora_save_dir)
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
if param.grad is not None:
del param.grad # free some memory
torch.cuda.empty_cache()
if __name__ == "__main__":
if args.clear_all:
clear_all()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
if len(args.prompts) != 1:
print("Need exactly one prompt for the LoRA word")
lora_train(
args.prompts[0],
args.height,
args.width,
args.training_steps,
args.guidance_scale,
args.seed,
args.batch_count,
args.batch_size,
args.scheduler,
"None",
args.hf_model_id,
args.precision,
args.device,
args.max_length,
args.training_images_dir,
args.lora_save_dir,
)

View File

@@ -0,0 +1,258 @@
import torch
import time
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":
if args.clear_all:
clear_all()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
use_lora = args.use_lora
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)

View File

@@ -0,0 +1,269 @@
import torch
import time
from PIL import Image
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def upscaler_inf(
prompt: str,
negative_prompt: str,
init_image,
height: int,
width: int,
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB").resize((height, width))
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
args.height = 128
args.width = 128
new_config_obj = Config(
"upscaler",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
args.height,
args.width,
device,
use_lora=None,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_size = batch_size
args.max_length = max_length
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.get_sd_obj().low_res_scheduler = schedulers["DDPM"]
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"NOISE LEVEL": noise_level}
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
low_res_img = image
high_res_img = Image.new("RGB", (height * 4, width * 4))
for i in range(0, width, 128):
for j in range(0, height, 128):
box = (j, i, j + 128, i + 128)
upscaled_image = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
low_res_img.crop(box),
batch_size,
args.height,
args.width,
steps,
noise_level,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
save_output_img(high_res_img, img_seed, extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
# When the models get uploaded, it should be default to False.
args.import_mlir = True
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
image = (
Image.open(args.img_path)
.convert("RGB")
.resize((args.height, args.width))
)
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model
upscaler_obj = UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
ddpm_scheduler=schedulers["DDPM"],
)
start_time = time.time()
generated_imgs = upscaler_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.batch_size,
args.height,
args.width,
args.steps,
args.noise_level,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, noise_level={args.noise_level}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += upscaler_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
extra_info = {"NOISE LEVEL": args.noise_level}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)

View File

@@ -0,0 +1,84 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
( 'web/ui/css/*', 'ui/css' ),
( 'web/ui/logos/*', 'logos' )
]
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
a = Analysis(
['web/index.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -0,0 +1,82 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
]
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
a = Analysis(
['scripts/txt2img.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -0,0 +1,17 @@
from apps.stable_diffusion.src.utils import (
args,
set_init_device_flags,
prompt_examples,
get_available_devices,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,
StencilPipeline,
UpscalerPipeline,
)
from apps.stable_diffusion.src.schedulers import get_schedulers

View File

@@ -0,0 +1,12 @@
from apps.stable_diffusion.src.models.model_wrappers import (
SharkifyStableDiffusionModel,
)
from apps.stable_diffusion.src.models.opt_params import (
get_vae_encode,
get_vae,
get_unet,
get_clip,
get_tokenizer,
get_params,
get_variant_version,
)

View File

@@ -0,0 +1,665 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from collections import defaultdict
import torch
import safetensors.torch
import traceback
import sys
import os
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
base_models,
args,
fetch_or_delete_vmfbs,
preprocessCKPT,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
get_stencil_model_id,
)
# These shapes are parameter dependent.
def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape = []
for i in range(len(shape)):
if shape[i] == "max_len":
new_shape.append(max_len)
elif shape[i] == "height":
new_shape.append(height)
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
if "*" in shape[i]:
mul_val = int(shape[i].split("*")[0])
if "batch_size" in shape[i]:
new_shape.append(batch_size * mul_val)
elif "height" in shape[i]:
new_shape.append(height * mul_val)
elif "width" in shape[i]:
new_shape.append(width * mul_val)
elif "/" in shape[i]:
import math
div_val = int(shape[i].split("/")[1])
if "batch_size" in shape[i]:
new_shape.append(math.ceil(batch_size / div_val))
elif "height" in shape[i]:
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
else:
new_shape.append(shape[i])
return new_shape
# Get the input info for various models i.e. "unet", "clip", "vae", "vae_encode".
def get_input_info(model_info, max_len, width, height, batch_size):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
for inp in model_info[k]:
shape = model_info[k][inp]["shape"]
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, max_len, width, height, batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
custom_vae: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
low_cpu_mem_usage: bool = False,
debug: bool = False,
sharktank_dir: str = "",
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = ""
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
self.model_id = "stabilityai/stable-diffusion-2-1-base"
self.custom_vae = custom_vae
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
"_"
+ str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
print(f'use_tuned? sharkify: {use_tuned}')
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
print(self.model_name)
self.debug = debug
self.sharktank_dir = sharktank_dir
self.generate_vmfb = generate_vmfb
def get_extended_name_for_all_model(self, mask_to_fetch):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
index = 0
for model in sub_model_list:
if mask_to_fetch[index] == False:
index += 1
continue
sub_model = model
model_config = self.model_name
if "vae" == model:
if self.custom_vae != "":
model_config = model_config + get_path_stem(self.custom_vae)
if self.base_vae:
sub_model = "base_vae"
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 128):
sys.exit("width should be greater than 128 and multiple of 8")
if not (height % 8 == 0 and height >= 128):
sys.exit("height should be greater than 128 and multiple of 8")
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
latents = self.vae.encode(input).latent_dist.sample()
return 0.18215 * latents
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = True if self.precision == "fp16" else False
shark_vae_encode = compile_through_fx(
vae_encode,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=self.model_name["vae_encode"],
extra_args=get_opt_flags("vae", precision=self.precision),
)
return shark_vae_encode
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
super().__init__()
self.vae = None
if custom_vae == "":
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae
def forward(self, input):
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0
return x.round()
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
)
return shark_vae
def get_vae_upscaler(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
return x
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
shark_vae = compile_through_fx(
vae,
inputs,
use_tuned=self.use_tuned,
model_name=self.model_name["vae"],
extra_args=get_opt_flags("vae", precision="fp32"),
)
return shark_vae
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
control2, control3, control4, control5, control6, control7,
control8, control9, control10, control11, control12, control13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
mb_res_samples = control13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
down_block_additional_residuals=db_res_samples,
mid_block_additional_residual=mb_res_samples,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = ControlledUnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
shark_controlled_unet = compile_through_fx(
unet,
inputs,
model_name=self.model_name["stencil_unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_controlled_unet
def get_control_net(self):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.cnet.in_channels
self.train(False)
def forward(
self,
latent,
timestep,
text_embedding,
stencil_image_input,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
# TODO: guidance NOT NEEDED change in `get_input_info` later
latents = torch.cat(
[latent] * 2
) # needs to be same as controlledUNET latents
stencil_image = torch.cat(
[stencil_image_input] * 2
) # needs to be same as controlledUNET latents
down_block_res_samples, mid_block_res_sample = self.cnet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
controlnet_cond=stencil_image,
return_dict=False,
)
return tuple(list(down_block_res_samples) + [mid_block_res_sample])
scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
input_mask = [True, True, True, True]
shark_cnet = compile_through_fx(
scnet,
inputs,
model_name=self.model_name["stencil_adaptor"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_cnet
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
self.unet.load_attn_procs(use_lora)
self.in_channels = self.unet.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
if(args.attention_slicing.isdigit()):
self.unet.set_attention_slice(int(args.attention_slicing))
else:
self.unet.set_attention_slice(args.attention_slicing)
# TODO: Instead of flattening the `control` try to use the list.
def forward(
self, latent, timestep, text_embedding, guidance_scale,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet = compile_through_fx(
unet,
inputs,
model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_unet
def get_unet_upscaler(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, latent, timestep, text_embedding, noise_level):
unet_out = self.unet.forward(
latent,
timestep,
text_embedding,
noise_level,
return_dict=False,
)[0]
return unet_out
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet = compile_through_fx(
unet,
inputs,
model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_unet
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name=self.model_name["clip"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
)
return shark_clip
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
return self.custom_vae
try:
preprocessCKPT(self.custom_vae)
return get_path_to_diffusers_checkpoint(self.custom_vae)
except:
print("Processing standalone Vae checkpoint")
vae_checkpoint = None
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
if custom_vae.endswith(".ckpt"):
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
else:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
# configiration.
def compile_all(self, base_model_id, need_vae_encode, need_stencil):
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
if self.is_upscaler:
return self.get_clip(), self.get_unet_upscaler(), self.get_vae_upscaler()
compiled_controlnet = None
compiled_controlled_unet = None
compiled_unet = None
if need_stencil:
compiled_controlnet = self.get_control_net()
compiled_controlled_unet = self.get_controlled_unet()
else:
compiled_unet = self.get_unet()
if self.custom_vae != "":
print("Plugging in custom Vae")
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
if need_stencil:
return compiled_clip, compiled_controlled_unet, compiled_vae, compiled_controlnet
if need_vae_encode:
compiled_vae_encode = self.get_vae_encode()
return compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode
return compiled_clip, compiled_unet, compiled_vae
def __call__(self):
# Step 1:
# -- Fetch all vmfbs for the model, if present, else delete the lot.
need_vae_encode, need_stencil = False, False
if not self.is_upscaler and args.img_path is not None:
if self.use_stencil is not None:
need_stencil = True
else:
need_vae_encode = True
# `mask_to_fetch` prepares a mask to pick a combination out of :-
# ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
mask_to_fetch = [True, True, False, True, False, False]
if need_vae_encode:
mask_to_fetch = [True, True, False, True, True, False]
elif need_stencil:
mask_to_fetch = [True, False, True, True, False, True]
self.model_name = self.get_extended_name_for_all_model(mask_to_fetch)
vmfbs = fetch_or_delete_vmfbs(self.model_name, self.precision)
if vmfbs[0]:
# -- If all vmfbs are indeed present, we also try and fetch the base
# model configuration for running SD with custom checkpoints.
if self.custom_weights != "":
args.hf_model_id = fetch_and_update_base_model_id(self.custom_weights)
if args.hf_model_id == "":
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
return vmfbs
# Step 2:
# -- If vmfbs weren't found, we try to see if the base model configuration
# for the required SD run is known to us and bypass the retry mechanism.
model_to_run = ""
if self.custom_weights != "":
model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
else:
model_to_run = args.hf_model_id
# For custom Vae user can provide either the repo-id or a checkpoint file,
# and for a checkpoint file we'd need to process it via Diffusers' script.
self.custom_vae = self.process_custom_vae()
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
if base_model_fetched != "":
print("Compiling all the models with the fetched base model configuration.")
if args.ckpt_loc != "":
args.hf_model_id = base_model_fetched
return self.compile_all(base_model_fetched, need_vae_encode, need_stencil)
# Step 3:
# -- This is the retry mechanism where the base model's configuration is not
# known to us and figure that out by trial and error.
print("Inferring base model configuration.")
for model_id in base_models:
try:
if need_vae_encode:
compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, need_vae_encode, need_stencil)
elif need_stencil:
compiled_clip, compiled_unet, compiled_vae, compiled_controlnet = self.compile_all(model_id, need_vae_encode, need_stencil)
else:
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode, need_stencil)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
if need_vae_encode:
return (
compiled_clip,
compiled_unet,
compiled_vae,
compiled_vae_encode,
)
if need_stencil:
return (
compiled_clip,
compiled_unet,
compiled_vae,
compiled_controlnet,
)
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
)

View File

@@ -0,0 +1,108 @@
import sys
from transformers import CLIPTokenizer
from apps.stable_diffusion.src.utils import (
models_db,
args,
get_shark_model,
get_opt_flags,
)
hf_model_variant_map = {
"Linaqruf/anything-v3.0": ["anythingv3", "v1_4"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"],
"prompthero/openjourney": ["openjourney", "v1_4"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
}
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
def get_params(bucket_key, model_key, model, is_tuned, precision):
try:
bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
except KeyError:
raise Exception(
f"{bucket_key}/{model_key} is not present in the models database"
)
iree_flags = get_opt_flags(model, precision="fp16")
return bucket, model_name, iree_flags
def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "unet", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae_encode():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
is_base = "/base" if args.use_base_vae else ""
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
variant, version = get_variant_version(args.hf_model_id)
bucket_key = f"{variant}/untuned"
model_key = (
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
)
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "clip", "untuned", "fp32"
)
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
)
return tokenizer

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
InpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import (
OutpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import (
StencilPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_upscaler import (
UpscalerPipeline,
)

View File

@@ -0,0 +1,172 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
class Image2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
def prepare_image_latents(
self,
image,
batch_size,
height,
width,
generator,
num_inference_steps,
strength,
dtype,
):
# Pre process image -> get image encoded -> process latents
# TODO: process with variable HxW combos
# Pre process image
image = image.resize((width, height))
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
image_arr = 2 * (image_arr - 0.5)
# set scheduler steps
self.scheduler.set_timesteps(num_inference_steps)
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps
)
t_start = max(num_inference_steps - init_timestep, 0)
# timesteps reduced as per strength
timesteps = self.scheduler.timesteps[t_start:]
# new number of steps to be used as per strength will be
# num_inference_steps = num_inference_steps - t_start
# image encode
latents = self.encode_image((image_arr,))
latents = torch.from_numpy(latents).to(dtype)
# add noise to data
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
latents = self.scheduler.add_noise(
latents, noise, timesteps[0].repeat(1)
)
return latents, timesteps
def encode_image(self, input_image):
vae_encode_start = time.time()
latents = self.vae_encode("forward", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
return latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
strength,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
use_stencil,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare input image latent
image_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=image_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
return all_imgs

View File

@@ -0,0 +1,445 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from PIL import Image, ImageOps
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
class InpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
latents = latents * self.scheduler.init_noise_sigma
return latents
def get_crop_region(self, mask, pad=0):
h, w = mask.shape
crop_left = 0
for i in range(w):
if not (mask[:, i] == 0).all():
break
crop_left += 1
crop_right = 0
for i in reversed(range(w)):
if not (mask[:, i] == 0).all():
break
crop_right += 1
crop_top = 0
for i in range(h):
if not (mask[i] == 0).all():
break
crop_top += 1
crop_bottom = 0
for i in reversed(range(h)):
if not (mask[i] == 0).all():
break
crop_bottom += 1
return (
int(max(crop_left - pad, 0)),
int(max(crop_top - pad, 0)),
int(min(w - crop_right + pad, w)),
int(min(h - crop_bottom + pad, h)),
)
def expand_crop_region(
self,
crop_region,
processing_width,
processing_height,
image_width,
image_height,
):
x1, y1, x2, y2 = crop_region
ratio_crop_region = (x2 - x1) / (y2 - y1)
ratio_processing = processing_width / processing_height
if ratio_crop_region > ratio_processing:
desired_height = (x2 - x1) / ratio_processing
desired_height_diff = int(desired_height - (y2 - y1))
y1 -= desired_height_diff // 2
y2 += desired_height_diff - desired_height_diff // 2
if y2 >= image_height:
diff = y2 - image_height
y2 -= diff
y1 -= diff
if y1 < 0:
y2 -= y1
y1 -= y1
if y2 >= image_height:
y2 = image_height
else:
desired_width = (y2 - y1) * ratio_processing
desired_width_diff = int(desired_width - (x2 - x1))
x1 -= desired_width_diff // 2
x2 += desired_width_diff - desired_width_diff // 2
if x2 >= image_width:
diff = x2 - image_width
x2 -= diff
x1 -= diff
if x1 < 0:
x2 -= x1
x1 -= x1
if x2 >= image_width:
x2 = image_width
return x1, y1, x2, y2
def resize_image(self, resize_mode, im, width, height):
"""
resize_mode:
0: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
1: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
"""
if resize_mode == 0:
ratio = width / height
src_ratio = im.width / im.height
src_w = (
width if ratio > src_ratio else im.width * height // im.height
)
src_h = (
height if ratio <= src_ratio else im.height * width // im.width
)
resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (width, height))
res.paste(
resized,
box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
)
else:
ratio = width / height
src_ratio = im.width / im.height
src_w = (
width if ratio < src_ratio else im.width * height // im.height
)
src_h = (
height if ratio >= src_ratio else im.height * width // im.width
)
resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (width, height))
res.paste(
resized,
box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
)
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
res.paste(
resized.resize((width, fill_height), box=(0, 0, width, 0)),
box=(0, 0),
)
res.paste(
resized.resize(
(width, fill_height),
box=(0, resized.height, width, resized.height),
),
box=(0, fill_height + src_h),
)
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
res.paste(
resized.resize(
(fill_width, height), box=(0, 0, 0, height)
),
box=(0, 0),
)
res.paste(
resized.resize(
(fill_width, height),
box=(resized.width, 0, resized.width, height),
),
box=(fill_width + src_w, 0),
)
return res
def prepare_mask_and_masked_image(
self,
image,
mask,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
):
# preprocess image
image = image.resize((width, height))
mask = mask.resize((width, height))
paste_to = ()
overlay_image = None
if inpaint_full_res:
# prepare overlay image
overlay_image = Image.new("RGB", (image.width, image.height))
overlay_image.paste(
image.convert("RGB"),
mask=ImageOps.invert(mask.convert("L")),
)
# prepare mask
mask = mask.convert("L")
crop_region = self.get_crop_region(
np.array(mask), inpaint_full_res_padding
)
crop_region = self.expand_crop_region(
crop_region, width, height, mask.width, mask.height
)
x1, y1, x2, y2 = crop_region
mask = mask.crop(crop_region)
mask = self.resize_image(1, mask, width, height)
paste_to = (x1, y1, x2 - x1, y2 - y1)
# prepare image
image = image.crop(crop_region)
image = self.resize_image(1, image, width, height)
if isinstance(image, (Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image, paste_to, overlay_image
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
):
mask = torch.nn.functional.interpolate(
mask, size=(height // 8, width // 8)
)
mask = mask.to(dtype)
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
return mask, masked_image_latents
def apply_overlay(self, image, paste_loc, overlay):
x, y, w, h = paste_loc
image = self.resize_image(0, image, w, h)
overlay.paste(image, (x, y))
return overlay
def generate_images(
self,
prompts,
neg_prompts,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Preprocess mask and image
(
mask,
masked_image,
paste_to,
overlay_image,
) = self.prepare_mask_and_masked_image(
image,
mask_image,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
)
# Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask=mask,
masked_image=masked_image,
batch_size=batch_size,
height=height,
width=width,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
mask=mask,
masked_image_latents=masked_image_latents,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if inpaint_full_res:
output_image = self.apply_overlay(
all_imgs[0], paste_to, overlay_image
)
return [output_image]
return all_imgs

View File

@@ -0,0 +1,541 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from PIL import Image, ImageDraw, ImageFilter
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
import math
class OutpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_mask_and_masked_image(
self, image, mask, mask_blur, width, height
):
if mask_blur > 0:
mask = mask.filter(ImageFilter.GaussianBlur(mask_blur))
image = image.resize((width, height))
mask = mask.resize((width, height))
# preprocess image
if isinstance(image, (Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
):
mask = torch.nn.functional.interpolate(
mask, size=(height // 8, width // 8)
)
mask = mask.to(dtype)
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
return mask, masked_image_latents
def get_matched_noise(
self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
):
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data):
if data.ndim > 2: # has channels
out_fft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]),
dtype=np.complex128,
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_fft[:, :, c] = np.fft.fft2(
np.fft.fftshift(c_data), norm="ortho"
)
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
else: # one channel
out_fft = np.zeros(
(data.shape[0], data.shape[1]), dtype=np.complex128
)
out_fft[:, :] = np.fft.fft2(
np.fft.fftshift(data), norm="ortho"
)
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
return out_fft
def _ifft2(data):
if data.ndim > 2: # has channels
out_ifft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]),
dtype=np.complex128,
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_ifft[:, :, c] = np.fft.ifft2(
np.fft.fftshift(c_data), norm="ortho"
)
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
else: # one channel
out_ifft = np.zeros(
(data.shape[0], data.shape[1]), dtype=np.complex128
)
out_ifft[:, :] = np.fft.ifft2(
np.fft.fftshift(data), norm="ortho"
)
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
return out_ifft
def _get_gaussian_window(width, height, std=3.14, mode=0):
window_scale_x = float(width / min(width, height))
window_scale_y = float(height / min(width, height))
window = np.zeros((width, height))
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
for y in range(height):
fy = (y / height * 2.0 - 1.0) * window_scale_y
if mode == 0:
window[:, y] = np.exp(-(x**2 + fy**2) * std)
else:
window[:, y] = (
1 / ((x**2 + 1.0) * (fy**2 + 1.0))
) ** (std / 3.14)
return window
def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
np_mask_rgb = np.zeros(
(np_mask_grey.shape[0], np_mask_grey.shape[1], 3)
)
if hardness != 1.0:
hardened = np_mask_grey[:] ** hardness
else:
hardened = np_mask_grey[:]
for c in range(3):
np_mask_rgb[:, :, c] = hardened[:]
return np_mask_rgb
def _match_cumulative_cdf(source, template):
src_values, src_unique_indices, src_counts = np.unique(
source.ravel(), return_inverse=True, return_counts=True
)
tmpl_values, tmpl_counts = np.unique(
template.ravel(), return_counts=True
)
# calculate normalized quantiles for each array
src_quantiles = np.cumsum(src_counts) / source.size
tmpl_quantiles = np.cumsum(tmpl_counts) / template.size
interp_a_values = np.interp(
src_quantiles, tmpl_quantiles, tmpl_values
)
return interp_a_values[src_unique_indices].reshape(source.shape)
def _match_histograms(image, reference):
if image.ndim != reference.ndim:
raise ValueError(
"Image and reference must have the same number of channels."
)
if image.shape[-1] != reference.shape[-1]:
raise ValueError(
"Number of channels in the input image and reference image must match!"
)
matched = np.empty(image.shape, dtype=image.dtype)
for channel in range(image.shape[-1]):
matched_channel = _match_cumulative_cdf(
image[..., channel], reference[..., channel]
)
matched[..., channel] = matched_channel
matched = matched.astype(np.float64, copy=False)
return matched
width = _np_src_image.shape[0]
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
np_src_image = _np_src_image[:] * (1.0 - np_mask_rgb)
np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
# rather than leave the masked area black, we get better results from fft by filling the average unmasked color
windowed_image = _np_src_image * (
1.0 - _get_masked_window_rgb(np_mask_grey)
)
windowed_image /= np.max(windowed_image)
windowed_image += np.average(_np_src_image) * np_mask_rgb
src_fft = _fft2(
windowed_image
) # get feature statistics from masked src img
src_dist = np.absolute(src_fft)
src_phase = src_fft / src_dist
# create a generator with a static seed to make outpainting deterministic / only follow global seed
rng = np.random.default_rng(0)
noise_window = _get_gaussian_window(
width, height, mode=1
) # start with simple gaussian noise
noise_rgb = rng.random((width, height, num_channels))
noise_grey = np.sum(noise_rgb, axis=2) / 3.0
# the colorfulness of the starting noise is blended to greyscale with a parameter
noise_rgb *= color_variation
for c in range(num_channels):
noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey
noise_fft = _fft2(noise_rgb)
for c in range(num_channels):
noise_fft[:, :, c] *= noise_window
noise_rgb = np.real(_ifft2(noise_fft))
shaped_noise_fft = _fft2(noise_rgb)
shaped_noise_fft[:, :, :] = (
np.absolute(shaped_noise_fft[:, :, :]) ** 2
* (src_dist**noise_q)
* src_phase
) # perform the actual shaping
# color_variation
brightness_variation = 0.0
contrast_adjusted_np_src = (
_np_src_image[:] * (brightness_variation + 1.0)
- brightness_variation * 2.0
)
shaped_noise = np.real(_ifft2(shaped_noise_fft))
shaped_noise -= np.min(shaped_noise)
shaped_noise /= np.max(shaped_noise)
shaped_noise[img_mask, :] = _match_histograms(
shaped_noise[img_mask, :] ** 1.0,
contrast_adjusted_np_src[ref_mask, :],
)
shaped_noise = (
_np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
)
matched_noise = shaped_noise[:]
return np.clip(matched_noise, 0.0, 1.0)
def generate_images(
self,
prompts,
neg_prompts,
image,
pixels,
mask_blur,
is_left,
is_right,
is_top,
is_bottom,
noise_q,
color_variation,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
process_width = width
process_height = height
left = pixels if is_left else 0
right = pixels if is_right else 0
up = pixels if is_top else 0
down = pixels if is_bottom else 0
target_w = math.ceil((image.width + left + right) / 64) * 64
target_h = math.ceil((image.height + up + down) / 64) * 64
if left > 0:
left = left * (target_w - image.width) // (left + right)
if right > 0:
right = target_w - image.width - left
if up > 0:
up = up * (target_h - image.height) // (up + down)
if down > 0:
down = target_h - image.height - up
def expand(
init_img,
expand_pixels,
is_left=False,
is_right=False,
is_top=False,
is_bottom=False,
):
is_horiz = is_left or is_right
is_vert = is_top or is_bottom
pixels_horiz = expand_pixels if is_horiz else 0
pixels_vert = expand_pixels if is_vert else 0
res_w = init_img.width + pixels_horiz
res_h = init_img.height + pixels_vert
process_res_w = math.ceil(res_w / 64) * 64
process_res_h = math.ceil(res_h / 64) * 64
img = Image.new("RGB", (process_res_w, process_res_h))
img.paste(
init_img,
(pixels_horiz if is_left else 0, pixels_vert if is_top else 0),
)
msk = Image.new("RGB", (process_res_w, process_res_h), "white")
draw = ImageDraw.Draw(msk)
draw.rectangle(
(
expand_pixels + mask_blur if is_left else 0,
expand_pixels + mask_blur if is_top else 0,
msk.width - expand_pixels - mask_blur
if is_right
else res_w,
msk.height - expand_pixels - mask_blur
if is_bottom
else res_h,
),
fill="black",
)
np_image = (np.asarray(img) / 255.0).astype(np.float64)
np_mask = (np.asarray(msk) / 255.0).astype(np.float64)
noised = self.get_matched_noise(
np_image, np_mask, noise_q, color_variation
)
output_image = Image.fromarray(
np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8),
mode="RGB",
)
target_width = (
min(width, init_img.width + pixels_horiz)
if is_horiz
else img.width
)
target_height = (
min(height, init_img.height + pixels_vert)
if is_vert
else img.height
)
crop_region = (
0 if is_left else output_image.width - target_width,
0 if is_top else output_image.height - target_height,
target_width if is_left else output_image.width,
target_height if is_top else output_image.height,
)
mask_to_process = msk.crop(crop_region)
image_to_process = output_image.crop(crop_region)
# Preprocess mask and image
mask, masked_image = self.prepare_mask_and_masked_image(
image_to_process, mask_to_process, mask_blur, width, height
)
# Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask=mask,
masked_image=masked_image,
batch_size=batch_size,
height=height,
width=width,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
mask=mask,
masked_image_latents=masked_image_latents,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
res_img = all_imgs[0].resize(
(image_to_process.width, image_to_process.height)
)
output_image.paste(
res_img,
(
0 if is_left else output_image.width - res_img.width,
0 if is_top else output_image.height - res_img.height,
),
)
output_image = output_image.crop((0, 0, res_w, res_h))
return output_image
img = image.resize((width, height))
if left > 0:
img = expand(img, left, is_left=True)
if right > 0:
img = expand(img, right, is_right=True)
if up > 0:
img = expand(img, up, is_top=True)
if down > 0:
img = expand(img, down, is_bottom=True)
return [img]

View File

@@ -0,0 +1,150 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
class StencilPipeline(StableDiffusionPipeline):
def __init__(
self,
controlnet: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.controlnet = controlnet
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
strength,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
use_stencil,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
)
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare initial latent.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
final_timesteps = self.scheduler.timesteps
# Get Image latents
latents = self.produce_stencil_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
controlnet=self.controlnet,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
return all_imgs

View File

@@ -0,0 +1,139 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
return all_imgs

View File

@@ -0,0 +1,310 @@
import inspect
import torch
import time
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from PIL import Image
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, Image.Image):
image = [image]
if isinstance(image[0], Image.Image):
w, h = image[0].size
w, h = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
class UpscalerPipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
low_res_scheduler: Union[
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.low_res_scheduler = low_res_scheduler
def prepare_extra_step_kwargs(self, generator, eta):
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
latents = 1 / 0.08333 * (latents.float())
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height,
width,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_img_latents(
self,
latents,
image,
text_embeddings,
guidance_scale,
noise_level,
total_timesteps,
dtype,
cpu_scheduling,
extra_step_kwargs,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = torch.cat([latent_model_input, image], dim=1)
timestep = torch.tensor([t]).to(dtype).detach().numpy()
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if cpu_scheduling:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
else:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
noise_level,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# 4. Preprocess image
image = preprocess(image).to(dtype)
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long)
noise = torch.randn(
image.shape,
generator=generator,
).to(dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
image = torch.cat([image] * 2)
noise_level = torch.cat([noise_level] * image.shape[0])
height, width = image.shape[2:]
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
eta = 0.0
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# guidance scale as a float32 tensor.
# guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
image=image,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
noise_level=noise_level,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
extra_step_kwargs=extra_step_kwargs,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
return all_imgs

View File

@@ -0,0 +1,430 @@
import torch
import numpy as np
from transformers import CLIPTokenizer
from PIL import Image
from tqdm.auto import tqdm
import time
from typing import Union
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
get_vae,
get_clip,
get_unet,
get_tokenizer,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
class StableDiffusionPipeline:
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
text_input = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# Get unconditional embeddings as well
uncond_input = self.tokenizer(
neg_prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
if use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
if use_base_vae:
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def produce_stencil_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
controlnet=None,
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def produce_img_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
@classmethod
def from_pretrained(
cls,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
import_mlir: bool,
model_id: str,
ckpt_loc: str,
custom_vae: str,
precision: str,
max_length: int,
batch_size: int,
height: int,
width: int,
use_base_vae: bool,
use_tuned: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
):
is_inpaint = cls.__name__ in [
"InpaintPipeline",
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
if import_mlir or use_lora:
if not import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
)
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
)
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
clip, unet, vae, vae_encode = mlir_import()
return cls(
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ in ["StencilPipeline"]:
clip, unet, vae, controlnet = mlir_import()
return cls(
controlnet, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ in ["UpscalerPipeline"]:
clip, unet, vae = mlir_import()
return cls(
vae, clip, get_tokenizer(), unet, scheduler, ddpm_scheduler
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
try:
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
return cls(
get_vae_encode(),
get_vae(),
get_clip(),
get_tokenizer(),
get_unet(),
scheduler,
)
if cls.__name__ == "StencilPipeline":
import sys
sys.exit(
"StencilPipeline not supported with SharkTank currently."
)
return cls(
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
)
except:
print("download pipeline failed, falling back to import_mlir")
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
)
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
clip, unet, vae, vae_encode = mlir_import()
return cls(
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ == "StencilPipeline":
clip, unet, vae, controlnet = mlir_import()
return cls(
controlnet, vae, clip, get_tokenizer(), unet, scheduler
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)

View File

@@ -0,0 +1,4 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)

View File

@@ -0,0 +1,66 @@
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
def get_schedulers(model_id):
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerDiscrete"
] = SharkEulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers

View File

@@ -0,0 +1,156 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_shark_model,
args,
)
import torch
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
):
super().__init__(
num_train_timesteps,
beta_start,
beta_end,
beta_schedule,
trained_betas,
prediction_type,
)
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
}
example_latent = model_input["euler"]["latent"]
example_output = model_input["euler"]["output"]
if args.precision == "fp16":
example_latent = example_latent.half()
example_output = example_output.half()
example_sigma = model_input["euler"]["sigma"]
example_dt = model_input["euler"]["dt"]
class ScalingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma, latent, dt):
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
return latent + derivative * dt
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
def _import(self):
scaling_model = ScalingModel()
self.scaling_model = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
self.step_model = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
if args.import_mlir:
_import(self)
else:
try:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_step_" + args.precision,
iree_flags,
)
except:
print(
"failed to download model, falling back and using import_mlir"
)
args.import_mlir = True
_import(self)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
return self.scaling_model(
"forward",
(
sample,
sigma,
),
send_to_host=False,
)
def step(self, noise_pred, timestep, latent):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
dt = self.sigmas[step_index + 1] - sigma
return self.step_model(
"forward",
(
noise_pred,
sigma,
latent,
dt,
),
send_to_host=False,
)

View File

@@ -0,0 +1,36 @@
from apps.stable_diffusion.src.utils.profiler import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.utils.resources import (
prompt_examples,
models_db,
base_models,
opt_flags,
resource_path,
)
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
controlnet_hint_conversion,
get_stencil_model_id,
)
from apps.stable_diffusion.src.utils.utils import (
get_shark_model,
compile_through_fx,
set_iree_runtime_flags,
map_device_to_name_path,
set_init_device_flags,
get_available_devices,
get_opt_flags,
preprocessCKPT,
fetch_or_delete_vmfbs,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
get_path_stem,
get_extended_name,
clear_all,
save_output_img,
get_generation_text_info,
)

View File

@@ -0,0 +1,18 @@
from apps.stable_diffusion.src.utils.stable_args import args
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")
vulkan_device = iree.runtime.get_device(args.device)
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
return vulkan_device
return None
def end_profiling(device):
if device:
return device.end_profiling()

View File

@@ -0,0 +1,37 @@
import os
import json
import sys
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def get_json_file(path):
json_var = []
loc_json = resource_path(path)
if os.path.exists(loc_json):
with open(loc_json, encoding="utf-8") as fopen:
json_var = json.load(fopen)
if not json_var:
print(f"Unable to fetch {path}")
return json_var
# TODO: This shouldn't be called from here, every time the file imports
# it will run all the global vars.
prompt_examples = get_json_file("resources/prompts.json")
models_db = get_json_file("resources/model_db.json")
# The base_model contains the input configuration for the different
# models and also helps in providing information for the variants.
base_models = get_json_file("resources/base_model.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")

View File

@@ -0,0 +1,384 @@
{
"stabilityai/stable-diffusion-x4-upscaler": {
"unet": {
"latents": {
"shape": [
"2*batch_size",
7,
"8*height",
"8*width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"noise_level": {
"shape": [2],
"dtype": "i64"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"8*height","8*width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"stabilityai/stable-diffusion-2-1": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"CompVis/stable-diffusion-v1-4": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stencil_adaptor": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
"stencil_unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
},
"control1": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control2": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control3": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control4": {
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"control5": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control6": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control7": {
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"control8": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control9": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control10": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control11": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control12": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"stabilityai/stable-diffusion-2-inpainting": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"runwayml/stable-diffusion-inpainting": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
}
}

View File

@@ -0,0 +1,23 @@
[
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting",
"stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]

View File

@@ -0,0 +1,85 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/sd_untuned",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/unet/fp32/length_64/untuned":"unet_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/vae/fp32/length_64/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"anythingv3/v1_4/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v1_4/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v1_4/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
"anythingv3/v1_4/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v1_4/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v1_4/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v1_4/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
"anythingv3/v1_4/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v1_4/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v1_4/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v1_4/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v1_4/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v1_4/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v1_4/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
"analogdiffusion/v1_4/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v1_4/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v1_4/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v1_4/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
"analogdiffusion/v1_4/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v1_4/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v1_4/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v1_4/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v1_4/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v1_4/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v1_4/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v1_4/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v1_4/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v1_4/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v1_4/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v1_4/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v1_4/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v1_4/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v1_4/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v1_4/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v1_4/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v1_4/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
}
]

View File

@@ -0,0 +1,84 @@
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
"fp32": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
}
}

View File

@@ -0,0 +1,8 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]

View File

@@ -0,0 +1,244 @@
import os
import io
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
download_model,
download_public_file,
WORKDIR,
)
from shark.parser import shark_args
from apps.stable_diffusion.src.utils.stable_args import args
def get_device():
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
return device
def get_device_args():
device = get_device()
device_spec_args = []
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
gpu_flags = get_iree_gpu_args()
for flag in gpu_flags:
device_spec_args.append(flag)
elif device == "vulkan":
device_spec_args.append(
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
return device, device_spec_args
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
from apps.stable_diffusion.src.models import (
get_params,
get_variant_version,
)
variant, version = get_variant_version(args.hf_model_id)
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{variant}/untuned"
if args.annotation_model == "unet":
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
is_base = "/base" if args.use_base_vae else ""
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
)
return mlir_model, model_name
# Download the tuned config files from shark_tank
def load_winograd_configs():
device = get_device()
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
def load_lower_configs():
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
if version == "inpaint_v1":
version = "v1_4"
elif version == "inpaint_v2":
version = "v2_1base"
config_bucket = "gs://shark_tank/sd_tuned_configs/"
device, device_spec_args = get_device_args()
spec = ""
if device_spec_args:
spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
spec = spec.split("-")[0]
if args.annotation_model == "vae":
if not spec or spec in ["rdna3", "sm_80"]:
config_name = (
f"{args.annotation_model}_{args.precision}_{device}.json"
)
else:
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
with create_context() as ctx:
winograd_model = model_annotation(
ctx,
input_contents=input_mlir,
config_path=winograd_config_dir,
search_op="conv",
winograd=True,
)
bytecode_stream = io.BytesIO()
winograd_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
if args.save_annotation:
if model_name.split("_")[-1] != "tuned":
out_file_path = os.path.join(
args.annotation_output, model_name + "_tuned_torch.mlir"
)
else:
out_file_path = os.path.join(
args.annotation_output, model_name + "_torch.mlir"
)
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return bytecode
def dump_after_mlir(input_mlir, use_winograd):
import iree.compiler as ireec
device, device_spec_args = get_device_args()
if use_winograd:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
else:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
dump_module = ireec.compile_str(
input_mlir,
target_backends=[iree_target_map(device)],
extra_args=device_spec_args
+ [
preprocess_flag,
"--compile-to=preprocessing",
],
)
return dump_module
# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
# Dump IR after padding/img2col/winograd passes
dump_module = dump_after_mlir(input_mlir, use_winograd)
print("Applying tuned configs on", model_name)
# Annotate the model with lowering configs in the config file
with create_context() as ctx:
tuned_model = model_annotation(
ctx,
input_contents=dump_module,
config_path=lowering_config_dir,
search_op="all",
)
bytecode_stream = io.BytesIO()
tuned_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
if args.save_annotation:
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return bytecode
def sd_model_annotation(mlir_model, model_name):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
use_winograd = False
lowering_config_dir = load_lower_configs()
tuned_model = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
return tuned_model
if __name__ == "__main__":
mlir_model, model_name = load_model_from_tank()
sd_model_annotation(mlir_model, model_name)

View File

@@ -0,0 +1,520 @@
import argparse
import os
from pathlib import Path
def path_expand(s):
return Path(s).expanduser().resolve()
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-p",
"--prompts",
nargs="+",
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=["trees, green"],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--img_path",
type=str,
help="Path to the image input for img2img/inpainting",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
default=-1,
help="the seed to use. -1 for a random one.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 769, 8),
help="the height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 769, 8),
help="the width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="the value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="the strength of change applied on the given input image for img2img",
)
##############################################################################
### Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The no. of steps to train",
)
##############################################################################
### Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture",
)
p.add_argument(
"--inpaint_full_res_padding",
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding",
)
p.add_argument(
"--pixels",
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting",
)
p.add_argument(
"--mask_blur",
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend left for outpainting",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend right for outpainting",
)
p.add_argument(
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend top for outpainting",
)
p.add_argument(
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend bottom for outpainting",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0)",
)
##############################################################################
### Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="number of batch to be generated with random seeds in single execution",
)
p.add_argument(
"--ckpt_loc",
type=str,
default="",
help="Path to SD's .ckpt file.",
)
p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
)
p.add_argument(
"--hf_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
)
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble"],
help="Enable the stencil feature.",
)
p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="4147483648",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
)
##############################################################################
### Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save a generation information json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
)
##############################################################################
### Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the progress bar animation during image generation",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file",
)
args, unknown = p.parse_known_args()
if args.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
os.getcwd(), args.hf_model_id.replace("/", "_")
)

View File

@@ -0,0 +1,2 @@
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector

View File

@@ -0,0 +1,6 @@
import cv2
class CannyDetector:
def __call__(self, img, low_threshold, high_threshold):
return cv2.Canny(img, low_threshold, high_threshold)

View File

@@ -0,0 +1,62 @@
import requests
from pathlib import Path
import torch
import numpy as np
# from annotator.util import annotator_ckpts_path
from apps.stable_diffusion.src.utils.stencils.openpose.body import Body
from apps.stable_diffusion.src.utils.stencils.openpose.hand import Hand
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
draw_bodypose,
draw_handpose,
handDetect,
)
body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"
class OpenposeDetector:
def __init__(self):
cwd = Path.cwd()
ckpt_path = Path(cwd, "stencil_annotator")
ckpt_path.mkdir(parents=True, exist_ok=True)
body_modelpath = ckpt_path / "body_pose_model.pth"
hand_modelpath = ckpt_path / "hand_pose_model.pth"
if not body_modelpath.is_file():
r = requests.get(body_model_path, allow_redirects=True)
open(body_modelpath, "wb").write(r.content)
if not hand_modelpath.is_file():
r = requests.get(hand_model_path, allow_redirects=True)
open(hand_modelpath, "wb").write(r.content)
self.body_estimation = Body(body_modelpath)
self.hand_estimation = Hand(hand_modelpath)
def __call__(self, oriImg, hand=False):
oriImg = oriImg[:, :, ::-1].copy()
with torch.no_grad():
candidate, subset = self.body_estimation(oriImg)
canvas = np.zeros_like(oriImg)
canvas = draw_bodypose(canvas, candidate, subset)
if hand:
hands_list = handDetect(candidate, subset, oriImg)
all_hand_peaks = []
for x, y, w, is_left in hands_list:
peaks = self.hand_estimation(
oriImg[y : y + w, x : x + w, :]
)
peaks[:, 0] = np.where(
peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x
)
peaks[:, 1] = np.where(
peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y
)
all_hand_peaks.append(peaks)
canvas = draw_handpose(canvas, all_hand_peaks)
return canvas, dict(
candidate=candidate.tolist(), subset=subset.tolist()
)

View File

@@ -0,0 +1,499 @@
import cv2
import numpy as np
import math
from scipy.ndimage.filters import gaussian_filter
import torch
import torch.nn as nn
from collections import OrderedDict
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
make_layers,
transfer,
padRightDownCorner,
)
class BodyPoseModel(nn.Module):
def __init__(self):
super(BodyPoseModel, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv5_5_CPM_L1",
"conv5_5_CPM_L2",
"Mconv7_stage2_L1",
"Mconv7_stage2_L2",
"Mconv7_stage3_L1",
"Mconv7_stage3_L2",
"Mconv7_stage4_L1",
"Mconv7_stage4_L2",
"Mconv7_stage5_L1",
"Mconv7_stage5_L2",
"Mconv7_stage6_L1",
"Mconv7_stage6_L1",
]
blocks = {}
block0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3_CPM", [512, 256, 3, 1, 1]),
("conv4_4_CPM", [256, 128, 3, 1, 1]),
]
)
# Stage 1
block1_1 = OrderedDict(
[
("conv5_1_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L1", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L1", [512, 38, 1, 1, 0]),
]
)
block1_2 = OrderedDict(
[
("conv5_1_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L2", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L2", [512, 19, 1, 1, 0]),
]
)
blocks["block1_1"] = block1_1
blocks["block1_2"] = block1_2
self.model0 = make_layers(block0, no_relu_layers)
# Stages 2 - 6
for i in range(2, 7):
blocks["block%d_1" % i] = OrderedDict(
[
("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]),
]
)
blocks["block%d_2" % i] = OrderedDict(
[
("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_1 = blocks["block1_1"]
self.model2_1 = blocks["block2_1"]
self.model3_1 = blocks["block3_1"]
self.model4_1 = blocks["block4_1"]
self.model5_1 = blocks["block5_1"]
self.model6_1 = blocks["block6_1"]
self.model1_2 = blocks["block1_2"]
self.model2_2 = blocks["block2_2"]
self.model3_2 = blocks["block3_2"]
self.model4_2 = blocks["block4_2"]
self.model5_2 = blocks["block5_2"]
self.model6_2 = blocks["block6_2"]
def forward(self, x):
out1 = self.model0(x)
out1_1 = self.model1_1(out1)
out1_2 = self.model1_2(out1)
out2 = torch.cat([out1_1, out1_2, out1], 1)
out2_1 = self.model2_1(out2)
out2_2 = self.model2_2(out2)
out3 = torch.cat([out2_1, out2_2, out1], 1)
out3_1 = self.model3_1(out3)
out3_2 = self.model3_2(out3)
out4 = torch.cat([out3_1, out3_2, out1], 1)
out4_1 = self.model4_1(out4)
out4_2 = self.model4_2(out4)
out5 = torch.cat([out4_1, out4_2, out1], 1)
out5_1 = self.model5_1(out5)
out5_2 = self.model5_2(out5)
out6 = torch.cat([out5_1, out5_2, out1], 1)
out6_1 = self.model6_1(out6)
out6_2 = self.model6_2(out6)
return out6_1, out6_2
class Body(object):
def __init__(self, model_path):
self.model = BodyPoseModel()
if torch.cuda.is_available():
self.model = self.model.cuda()
model_dict = transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def __call__(self, oriImg):
scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre1 = 0.1
thre2 = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = cv2.resize(
oriImg,
(0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_CUBIC,
)
imageToTest_padded, pad = padRightDownCorner(
imageToTest, stride, padValue
)
im = (
np.transpose(
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
(3, 2, 0, 1),
)
/ 256
- 0.5
)
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
if torch.cuda.is_available():
data = data.cuda()
with torch.no_grad():
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
# extract outputs, resize, and remove padding
heatmap = np.transpose(
np.squeeze(Mconv7_stage6_L2), (1, 2, 0)
) # output 1 is heatmaps
heatmap = cv2.resize(
heatmap,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
heatmap = heatmap[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
heatmap = cv2.resize(
heatmap,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
paf = np.transpose(
np.squeeze(Mconv7_stage6_L1), (1, 2, 0)
) # output 0 is PAFs
paf = cv2.resize(
paf,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
paf = paf[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
paf = cv2.resize(
paf,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
paf_avg += +paf / len(multiplier)
all_peaks = []
peak_counter = 0
for part in range(18):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
map_left = np.zeros(one_heatmap.shape)
map_left[1:, :] = one_heatmap[:-1, :]
map_right = np.zeros(one_heatmap.shape)
map_right[:-1, :] = one_heatmap[1:, :]
map_up = np.zeros(one_heatmap.shape)
map_up[:, 1:] = one_heatmap[:, :-1]
map_down = np.zeros(one_heatmap.shape)
map_down[:, :-1] = one_heatmap[:, 1:]
peaks_binary = np.logical_and.reduce(
(
one_heatmap >= map_left,
one_heatmap >= map_right,
one_heatmap >= map_up,
one_heatmap >= map_down,
one_heatmap > thre1,
)
)
peaks = list(
zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])
) # note reverse
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
peak_id = range(peak_counter, peak_counter + len(peaks))
peaks_with_score_and_id = [
peaks_with_score[i] + (peak_id[i],)
for i in range(len(peak_id))
]
all_peaks.append(peaks_with_score_and_id)
peak_counter += len(peaks)
# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
[3, 17],
[6, 18],
]
# the middle joints heatmap correpondence
mapIdx = [
[31, 32],
[39, 40],
[33, 34],
[35, 36],
[41, 42],
[43, 44],
[19, 20],
[21, 22],
[23, 24],
[25, 26],
[27, 28],
[29, 30],
[47, 48],
[49, 50],
[53, 54],
[51, 52],
[55, 56],
[37, 38],
[45, 46],
]
connection_all = []
special_k = []
mid_num = 10
for k in range(len(mapIdx)):
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
candA = all_peaks[limbSeq[k][0] - 1]
candB = all_peaks[limbSeq[k][1] - 1]
nA = len(candA)
nB = len(candB)
indexA, indexB = limbSeq[k]
if nA != 0 and nB != 0:
connection_candidate = []
for i in range(nA):
for j in range(nB):
vec = np.subtract(candB[j][:2], candA[i][:2])
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
norm = max(0.001, norm)
vec = np.divide(vec, norm)
startend = list(
zip(
np.linspace(
candA[i][0], candB[j][0], num=mid_num
),
np.linspace(
candA[i][1], candB[j][1], num=mid_num
),
)
)
vec_x = np.array(
[
score_mid[
int(round(startend[I][1])),
int(round(startend[I][0])),
0,
]
for I in range(len(startend))
]
)
vec_y = np.array(
[
score_mid[
int(round(startend[I][1])),
int(round(startend[I][0])),
1,
]
for I in range(len(startend))
]
)
score_midpts = np.multiply(
vec_x, vec[0]
) + np.multiply(vec_y, vec[1])
score_with_dist_prior = sum(score_midpts) / len(
score_midpts
) + min(0.5 * oriImg.shape[0] / norm - 1, 0)
criterion1 = len(
np.nonzero(score_midpts > thre2)[0]
) > 0.8 * len(score_midpts)
criterion2 = score_with_dist_prior > 0
if criterion1 and criterion2:
connection_candidate.append(
[
i,
j,
score_with_dist_prior,
score_with_dist_prior
+ candA[i][2]
+ candB[j][2],
]
)
connection_candidate = sorted(
connection_candidate, key=lambda x: x[2], reverse=True
)
connection = np.zeros((0, 5))
for c in range(len(connection_candidate)):
i, j, s = connection_candidate[c][0:3]
if i not in connection[:, 3] and j not in connection[:, 4]:
connection = np.vstack(
[connection, [candA[i][3], candB[j][3], s, i, j]]
)
if len(connection) >= min(nA, nB):
break
connection_all.append(connection)
else:
special_k.append(k)
connection_all.append([])
# last number in each row is the total parts number of that person
# the second last number in each row is the score of the overall configuration
subset = -1 * np.ones((0, 20))
candidate = np.array(
[item for sublist in all_peaks for item in sublist]
)
for k in range(len(mapIdx)):
if k not in special_k:
partAs = connection_all[k][:, 0]
partBs = connection_all[k][:, 1]
indexA, indexB = np.array(limbSeq[k]) - 1
for i in range(len(connection_all[k])): # = 1:size(temp,1)
found = 0
subset_idx = [-1, -1]
for j in range(len(subset)): # 1:size(subset,1):
if (
subset[j][indexA] == partAs[i]
or subset[j][indexB] == partBs[i]
):
subset_idx[found] = j
found += 1
if found == 1:
j = subset_idx[0]
if subset[j][indexB] != partBs[i]:
subset[j][indexB] = partBs[i]
subset[j][-1] += 1
subset[j][-2] += (
candidate[partBs[i].astype(int), 2]
+ connection_all[k][i][2]
)
elif found == 2: # if found 2 and disjoint, merge them
j1, j2 = subset_idx
membership = (
(subset[j1] >= 0).astype(int)
+ (subset[j2] >= 0).astype(int)
)[:-2]
if len(np.nonzero(membership == 2)[0]) == 0: # merge
subset[j1][:-2] += subset[j2][:-2] + 1
subset[j1][-2:] += subset[j2][-2:]
subset[j1][-2] += connection_all[k][i][2]
subset = np.delete(subset, j2, 0)
else: # as like found == 1
subset[j1][indexB] = partBs[i]
subset[j1][-1] += 1
subset[j1][-2] += (
candidate[partBs[i].astype(int), 2]
+ connection_all[k][i][2]
)
# if find no partA in the subset, create a new subset
elif not found and k < 17:
row = -1 * np.ones(20)
row[indexA] = partAs[i]
row[indexB] = partBs[i]
row[-1] = 2
row[-2] = (
sum(
candidate[
connection_all[k][i, :2].astype(int), 2
]
)
+ connection_all[k][i][2]
)
subset = np.vstack([subset, row])
# delete some rows of subset which has few parts occur
deleteIdx = []
for i in range(len(subset)):
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
deleteIdx.append(i)
subset = np.delete(subset, deleteIdx, axis=0)
# candidate: x, y, score, id
return candidate, subset

View File

@@ -0,0 +1,205 @@
import cv2
import numpy as np
from scipy.ndimage.filters import gaussian_filter
import torch
import torch.nn as nn
from skimage.measure import label
from collections import OrderedDict
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
make_layers,
transfer,
padRightDownCorner,
npmax,
)
class HandPoseModel(nn.Module):
def __init__(self):
super(HandPoseModel, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv6_2_CPM",
"Mconv7_stage2",
"Mconv7_stage3",
"Mconv7_stage4",
"Mconv7_stage5",
"Mconv7_stage6",
]
# stage 1
block1_0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3", [512, 512, 3, 1, 1]),
("conv4_4", [512, 512, 3, 1, 1]),
("conv5_1", [512, 512, 3, 1, 1]),
("conv5_2", [512, 512, 3, 1, 1]),
("conv5_3_CPM", [512, 128, 3, 1, 1]),
]
)
block1_1 = OrderedDict(
[
("conv6_1_CPM", [128, 512, 1, 1, 0]),
("conv6_2_CPM", [512, 22, 1, 1, 0]),
]
)
blocks = {}
blocks["block1_0"] = block1_0
blocks["block1_1"] = block1_1
# stage 2-6
for i in range(2, 7):
blocks["block%d" % i] = OrderedDict(
[
("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]),
("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_0 = blocks["block1_0"]
self.model1_1 = blocks["block1_1"]
self.model2 = blocks["block2"]
self.model3 = blocks["block3"]
self.model4 = blocks["block4"]
self.model5 = blocks["block5"]
self.model6 = blocks["block6"]
def forward(self, x):
out1_0 = self.model1_0(x)
out1_1 = self.model1_1(out1_0)
concat_stage2 = torch.cat([out1_1, out1_0], 1)
out_stage2 = self.model2(concat_stage2)
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
out_stage3 = self.model3(concat_stage3)
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
out_stage4 = self.model4(concat_stage4)
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
out_stage5 = self.model5(concat_stage5)
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
out_stage6 = self.model6(concat_stage6)
return out_stage6
class Hand(object):
def __init__(self, model_path):
self.model = HandPoseModel()
if torch.cuda.is_available():
self.model = self.model.cuda()
model_dict = transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def __call__(self, oriImg):
scale_search = [0.5, 1.0, 1.5, 2.0]
# scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
# paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = cv2.resize(
oriImg,
(0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_CUBIC,
)
imageToTest_padded, pad = padRightDownCorner(
imageToTest, stride, padValue
)
im = (
np.transpose(
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
(3, 2, 0, 1),
)
/ 256
- 0.5
)
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
if torch.cuda.is_available():
data = data.cuda()
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
with torch.no_grad():
output = self.model(data).cpu().numpy()
# output = self.model(data).numpy()q
# extract outputs, resize, and remove padding
heatmap = np.transpose(
np.squeeze(output), (1, 2, 0)
) # output 1 is heatmaps
heatmap = cv2.resize(
heatmap,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
heatmap = heatmap[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
heatmap = cv2.resize(
heatmap,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
heatmap_avg += heatmap / len(multiplier)
all_peaks = []
for part in range(21):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
# 全部小于阈值
if np.sum(binary) == 0:
all_peaks.append([0, 0])
continue
label_img, label_numbers = label(
binary, return_num=True, connectivity=binary.ndim
)
max_index = (
np.argmax(
[
np.sum(map_ori[label_img == i])
for i in range(1, label_numbers + 1)
]
)
+ 1
)
label_img[label_img != max_index] = 0
map_ori[label_img == 0] = 0
y, x = npmax(map_ori)
all_peaks.append([x, y])
return np.array(all_peaks)

View File

@@ -0,0 +1,272 @@
import math
import numpy as np
import matplotlib
import cv2
from collections import OrderedDict
import torch.nn as nn
def make_layers(block, no_relu_layers):
layers = []
for layer_name, v in block.items():
if "pool" in layer_name:
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
layers.append((layer_name, layer))
else:
conv2d = nn.Conv2d(
in_channels=v[0],
out_channels=v[1],
kernel_size=v[2],
stride=v[3],
padding=v[4],
)
layers.append((layer_name, conv2d))
if layer_name not in no_relu_layers:
layers.append(("relu_" + layer_name, nn.ReLU(inplace=True)))
return nn.Sequential(OrderedDict(layers))
def padRightDownCorner(img, stride, padValue):
h = img.shape[0]
w = img.shape[1]
pad = 4 * [None]
pad[0] = 0 # up
pad[1] = 0 # left
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
img_padded = img
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
img_padded = np.concatenate((pad_up, img_padded), axis=0)
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
img_padded = np.concatenate((pad_left, img_padded), axis=1)
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
img_padded = np.concatenate((img_padded, pad_down), axis=0)
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
img_padded = np.concatenate((img_padded, pad_right), axis=1)
return img_padded, pad
# transfer caffe model to pytorch which will match the layer name
def transfer(model, model_weights):
transfered_model_weights = {}
for weights_name in model.state_dict().keys():
transfered_model_weights[weights_name] = model_weights[
".".join(weights_name.split(".")[1:])
]
return transfered_model_weights
# draw the body keypoint and lims
def draw_bodypose(canvas, candidate, subset):
stickwidth = 4
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
[3, 17],
[6, 18],
]
colors = [
[255, 0, 0],
[255, 85, 0],
[255, 170, 0],
[255, 255, 0],
[170, 255, 0],
[85, 255, 0],
[0, 255, 0],
[0, 255, 85],
[0, 255, 170],
[0, 255, 255],
[0, 170, 255],
[0, 85, 255],
[0, 0, 255],
[85, 0, 255],
[170, 0, 255],
[255, 0, 255],
[255, 0, 170],
[255, 0, 85],
]
for i in range(18):
for n in range(len(subset)):
index = int(subset[n][i])
if index == -1:
continue
x, y = candidate[index][0:2]
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
for i in range(17):
for n in range(len(subset)):
index = subset[n][np.array(limbSeq[i]) - 1]
if -1 in index:
continue
cur_canvas = canvas.copy()
Y = candidate[index.astype(int), 0]
X = candidate[index.astype(int), 1]
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly(
(int(mY), int(mX)),
(int(length / 2), stickwidth),
int(angle),
0,
360,
1,
)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
return canvas
# image drawed by opencv is not good.
def draw_handpose(canvas, all_hand_peaks, show_number=False):
edges = [
[0, 1],
[1, 2],
[2, 3],
[3, 4],
[0, 5],
[5, 6],
[6, 7],
[7, 8],
[0, 9],
[9, 10],
[10, 11],
[11, 12],
[0, 13],
[13, 14],
[14, 15],
[15, 16],
[0, 17],
[17, 18],
[18, 19],
[19, 20],
]
for peaks in all_hand_peaks:
for ie, e in enumerate(edges):
if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
x1, y1 = peaks[e[0]]
x2, y2 = peaks[e[1]]
cv2.line(
canvas,
(x1, y1),
(x2, y2),
matplotlib.colors.hsv_to_rgb(
[ie / float(len(edges)), 1.0, 1.0]
)
* 255,
thickness=2,
)
for i, keyponit in enumerate(peaks):
x, y = keyponit
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
if show_number:
cv2.putText(
canvas,
str(i),
(x, y),
cv2.FONT_HERSHEY_SIMPLEX,
0.3,
(0, 0, 0),
lineType=cv2.LINE_AA,
)
return canvas
# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
# right hand: wrist 4, elbow 3, shoulder 2
# left hand: wrist 7, elbow 6, shoulder 5
ratioWristElbow = 0.33
detect_result = []
image_height, image_width = oriImg.shape[0:2]
for person in subset.astype(int):
# if any of three not detected
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
if not (has_left or has_right):
continue
hands = []
# left hand
if has_left:
left_shoulder_index, left_elbow_index, left_wrist_index = person[
[5, 6, 7]
]
x1, y1 = candidate[left_shoulder_index][:2]
x2, y2 = candidate[left_elbow_index][:2]
x3, y3 = candidate[left_wrist_index][:2]
hands.append([x1, y1, x2, y2, x3, y3, True])
# right hand
if has_right:
(
right_shoulder_index,
right_elbow_index,
right_wrist_index,
) = person[[2, 3, 4]]
x1, y1 = candidate[right_shoulder_index][:2]
x2, y2 = candidate[right_elbow_index][:2]
x3, y3 = candidate[right_wrist_index][:2]
hands.append([x1, y1, x2, y2, x3, y3, False])
for x1, y1, x2, y2, x3, y3, is_left in hands:
x = x3 + ratioWristElbow * (x3 - x2)
y = y3 + ratioWristElbow * (y3 - y2)
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
# x-y refers to the center --> offset to topLeft point
x -= width / 2
y -= width / 2 # width = height
# overflow the image
if x < 0:
x = 0
if y < 0:
y = 0
width1 = width
width2 = width
if x + width > image_width:
width1 = image_width - x
if y + width > image_height:
width2 = image_height - y
width = min(width1, width2)
# the max hand box value is 20 pixels
if width >= 20:
detect_result.append([int(x), int(y), int(width), is_left])
"""
return value: [[x, y, w, True if left hand else False]].
width=height since the network require squared input.
x, y is the coordinate of top left
"""
return detect_result
# get max index of 2d array
def npmax(array):
arrayindex = array.argmax(1)
arrayvalue = array.max(1)
i = arrayvalue.argmax()
j = arrayindex[i]
return (i,)

View File

@@ -0,0 +1,186 @@
import numpy as np
from PIL import Image
import torch
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
)
stencil = {}
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y
def controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt=1
):
channels = 3
if isinstance(controlnet_hint, torch.Tensor):
# torch.Tensor: acceptble shape are any of chw, bchw(b==1) or bchw(b==num_images_per_prompt)
shape_chw = (channels, height, width)
shape_bchw = (1, channels, height, width)
shape_nchw = (num_images_per_prompt, channels, height, width)
if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]:
controlnet_hint = controlnet_hint.to(
dtype=dtype, device=torch.device("cpu")
)
if controlnet_hint.shape != shape_nchw:
controlnet_hint = controlnet_hint.repeat(
num_images_per_prompt, 1, 1, 1
)
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
)
elif isinstance(controlnet_hint, np.ndarray):
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot)
# hwc is opencv compatible image format. Color channel must be BGR Format.
if controlnet_hint.shape == (height, width):
controlnet_hint = np.repeat(
controlnet_hint[:, :, np.newaxis], channels, axis=2
) # hw -> hwc(c==3)
shape_hwc = (height, width, channels)
shape_bhwc = (1, height, width, channels)
shape_nhwc = (num_images_per_prompt, height, width, channels)
if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
controlnet_hint = torch.from_numpy(controlnet_hint.copy())
controlnet_hint = controlnet_hint.to(
dtype=dtype, device=torch.device("cpu")
)
controlnet_hint /= 255.0
if controlnet_hint.shape != shape_nhwc:
controlnet_hint = controlnet_hint.repeat(
num_images_per_prompt, 1, 1, 1
)
controlnet_hint = controlnet_hint.permute(
0, 3, 1, 2
) # b h w c -> b c h w
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
+ f"({height}, {width}, {channels}), "
+ f"(1, {height}, {width}, {channels}) or "
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
)
elif isinstance(controlnet_hint, Image.Image):
if controlnet_hint.size == (width, height):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
controlnet_hint = np.array(controlnet_hint) # to numpy
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
return controlnet_hint_shaping(
controlnet_hint, height, width, num_images_per_prompt
)
else:
raise ValueError(
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
)
else:
raise ValueError(
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
)
def controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
):
controlnet_hint = None
match use_stencil:
case "canny":
print("Detecting edge with canny")
controlnet_hint = hint_canny(image)
case "openpose":
print("Detecting human pose")
controlnet_hint = hint_openpose(image)
case "scribble":
print("Working with scribble")
controlnet_hint = hint_scribble(image)
case _:
return None
controlnet_hint = controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt
)
return controlnet_hint
stencil_to_model_id_map = {
"canny": "lllyasviel/sd-controlnet-canny",
"depth": "lllyasviel/sd-controlnet-depth",
"hed": "lllyasviel/sd-controlnet-hed",
"mlsd": "lllyasviel/sd-controlnet-mlsd",
"normal": "lllyasviel/sd-controlnet-normal",
"openpose": "lllyasviel/sd-controlnet-openpose",
"scribble": "lllyasviel/sd-controlnet-scribble",
"seg": "lllyasviel/sd-controlnet-seg",
}
def get_stencil_model_id(use_stencil):
if use_stencil in stencil_to_model_id_map:
return stencil_to_model_id_map[use_stencil]
return None
# Stencil 1. Canny
def hint_canny(
image: Image.Image,
low_threshold=100,
high_threshold=200,
):
with torch.no_grad():
input_image = np.array(image)
if not "canny" in stencil:
stencil["canny"] = CannyDetector()
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
detected_map = HWC3(detected_map)
return detected_map
# Stencil 2. OpenPose.
def hint_openpose(
image: Image.Image,
):
with torch.no_grad():
input_image = np.array(image)
if not "openpose" in stencil:
stencil["openpose"] = OpenposeDetector()
detected_map, _ = stencil["openpose"](input_image)
detected_map = HWC3(detected_map)
return detected_map
# Stencil 3. Scribble.
def hint_scribble(image: Image.Image):
with torch.no_grad():
input_image = np.array(image)
detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
return detected_map

View File

@@ -0,0 +1,642 @@
import os
import gc
import json
import re
from PIL import PngImagePlugin
from datetime import datetime as dt
from csv import DictWriter
from pathlib import Path
import numpy as np
from random import randint
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
def get_extended_name(model_name):
device = args.device.split("://", 1)[0]
extended_name = "{}_{}".format(model_name, device)
return extended_name
def get_vmfb_path_name(model_name):
vmfb_path = os.path.join(os.getcwd(), model_name + ".vmfb")
return vmfb_path
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
vmfb_path = get_vmfb_path_name(model_name)
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), model_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.parser import shark_args
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
from shark.shark_downloader import download_model
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
)
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into a shark_module.
def compile_through_fx(
model,
inputs,
model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
debug=False,
generate_vmfb=True,
extra_args=[],
):
from shark.parser import shark_args
if "cuda" in args.device:
shark_args.enable_tf32 = True
(
mlir_module,
func_name,
) = import_with_fx(
model=model,
inputs=inputs,
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=model_name,
save_dir=save_dir,
)
if use_tuned:
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
mlir_module = sd_model_annotation(mlir_module, model_name)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
if generate_vmfb:
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
del mlir_module
gc.collect()
return _compile_module(shark_module, model_name, extra_args)
del mlir_module
gc.collect()
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def set_init_device_flags():
if "vulkan" in args.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "cpu" in args.device:
args.device = "cpu"
# set max_length based on availability.
if args.hf_model_id in [
"Linaqruf/anything-v3.0",
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]:
args.max_length = 77
elif args.hf_model_id == "prompthero/openjourney":
args.max_length = 64
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
if (
args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif base_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
]:
args.use_tuned = False
elif "vulkan" in args.device and not any(
x in args.iree_vulkan_target_triple for x in ["rdna2", "rdna3"]
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
else:
print("Tuned models are currently not supported for this setting.")
# set import_mlir to True for unuploaded models.
if args.ckpt_loc != "":
args.import_mlir = True
elif args.hf_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.import_mlir = True
elif args.height != 512 or args.width != 512 or args.batch_size != 1:
args.import_mlir = True
elif args.use_tuned and args.hf_model_id in [
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"stabilityai/stable-diffusion-2-1",
]:
args.import_mlir = True
elif (
args.use_tuned
and "vulkan" in args.device
and "rdna2" in args.iree_vulkan_target_triple
):
args.import_mlir = True
elif (
args.use_tuned
and "cuda" in args.device
and get_cuda_sm_cc() == "sm_89"
):
args.import_mlir = True
# Utility to get list of devices available.
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{device['name']} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
return available_devices
def disk_space_check(path, lim=20):
from shutil import disk_usage
du = disk_usage(path)
free = du.free / (1024 * 1024 * 1024)
if free <= lim:
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
def get_opt_flags(model, precision="fp16"):
iree_flags = []
is_tuned = "tuned" if args.use_tuned else "untuned"
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
"default_compilation_flags"
]
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
if (
device
not in opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
]
):
device = "default_device"
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
return iree_flags
def get_path_stem(path):
path = Path(path)
return path.stem
def get_path_to_diffusers_checkpoint(custom_weights):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers
def preprocessCKPT(custom_weights, is_inpaint=False):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
return
else:
print(
"Diffusers' checkpoint will be identified here : ",
path_to_diffusers,
)
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
extract_ema = False
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
num_in_channels = 9 if is_inpaint else 4
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
)
pipe.save_pretrained(path_to_diffusers)
print("Loading complete")
def load_vmfb(vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module = SharkInference(mlir_module=None, device=args.device)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
# This utility returns vmfbs of Clip, Unet, Vae and Vae_encode, in case all of them
# are present; deletes them otherwise.
def fetch_or_delete_vmfbs(extended_model_name, precision="fp32"):
vmfb_path = [
get_vmfb_path_name(extended_model_name[model])
for model in extended_model_name
]
number_of_vmfbs = len(vmfb_path)
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = True
compiled_models = [None] * number_of_vmfbs
for i in range(number_of_vmfbs):
all_vmfb_present = all_vmfb_present and vmfb_present[i]
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(number_of_vmfbs):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
else:
model_name = [model for model in extended_model_name.keys()]
for i in range(number_of_vmfbs):
compiled_models[i] = load_vmfb(
vmfb_path[i], model_name[i], precision
)
return compiled_models
# `fetch_and_update_base_model_id` is a resource utility function which
# helps maintaining mapping of the model to run with its base model.
# If `base_model` is "", then this function tries to fetch the base model
# info for the `model_to_run`.
def fetch_and_update_base_model_id(model_to_run, base_model=""):
variants_path = os.path.join(os.getcwd(), "variants.json")
data = {model_to_run: base_model}
json_data = {}
if os.path.exists(variants_path):
with open(variants_path, "r", encoding="utf-8") as jsonFile:
json_data = json.load(jsonFile)
# Return with base_model's info if base_model is "".
if base_model == "":
if model_to_run in json_data:
base_model = json_data[model_to_run]
return base_model
elif base_model == "":
return base_model
# Update JSON data to contain an entry mapping model_to_run with base_model.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
json.dump(json_data, jsonFile)
# Generate and return a new seed if the provided one is not in the supported range (including -1)
def sanitize_seed(seed):
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# clear all the cached objects to recompile cleanly.
def clear_all():
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
if os.path.exists(yaml):
os.remove(yaml)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info={}):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
)
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = Path(os.path.basename(args.ckpt_loc)).stem
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
)
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SEED": img_seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
}
new_entry.update(extra_info)
with open(csv_path, "a", encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
dictwriter_obj.writerow(new_entry)
csv_obj.close()
if args.save_metadata_to_json:
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
return text_output

View File

@@ -0,0 +1,15 @@
You need to pre-create your bot (https://core.telegram.org/bots#how-do-i-create-a-bot)
Then create in the directory web file .env
In it the record:
TG_TOKEN="your_token"
specifying your bot's token from previous step.
Then run telegram_bot.py with the same parameters that you use when running index.py, for example:
python telegram_bot.py --max_length=77 --vulkan_large_heap_block_size=0 --use_base_vae --local_tank_cache h:\shark\TEMP
Bot commands:
/select_model
/select_scheduler
/set_steps "integer number of steps"
/set_guidance_scale "integer number"
/set_negative_prompt "negative text"
Any other text triggers the creation of an image based on it.

View File

@@ -0,0 +1,204 @@
import os
import sys
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
import gradio as gr
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src import args, clear_all
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
)
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create the custom model folder if it doesn't already exist
get_custom_model_path().mkdir(parents=True, exist_ok=True)
if args.clear_all:
clear_all()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
img2img_web,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
inpaint_web,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
outpaint_web,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
upscaler_web,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
)
# init global sd pipeline and config
global_obj.init()
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
img2img_web.render()
with gr.TabItem(label="Inpainting", id=2):
inpaint_web.render()
with gr.TabItem(label="Outpainting", id=3):
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="LoRA Training", id=5):
lora_train_web.render()
register_button_click(
txt2img_sendto_img2img,
1,
[txt2img_gallery],
[img2img_init_image, tabs],
)
register_button_click(
txt2img_sendto_inpaint,
2,
[txt2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_outpaint,
3,
[txt2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_upscaler,
4,
[txt2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
img2img_sendto_inpaint,
2,
[img2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_outpaint,
3,
[img2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_upscaler,
4,
[img2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
inpaint_sendto_img2img,
1,
[inpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
inpaint_sendto_outpaint,
3,
[inpaint_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
inpaint_sendto_upscaler,
4,
[inpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
outpaint_sendto_img2img,
1,
[outpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
outpaint_sendto_inpaint,
2,
[outpaint_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
outpaint_sendto_upscaler,
4,
[outpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
upscaler_sendto_img2img,
1,
[upscaler_gallery],
[img2img_init_image, tabs],
)
register_button_click(
upscaler_sendto_inpaint,
2,
[upscaler_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
upscaler_sendto_outpaint,
3,
[upscaler_gallery],
[outpaint_init_image, tabs],
)
sd_web.queue()
sd_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

View File

@@ -0,0 +1,41 @@
from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_web,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_web,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_web,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_web,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_web,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web

View File

@@ -0,0 +1,198 @@
/*
Apply Gradio dark theme to the default Gradio theme.
Procedure to upgrade the dark theme:
- Using your browser, visit http://localhost:8080/?__theme=dark
- Open your browser inspector, search for the .dark css class
- Copy .dark class declarations, apply them here into :root
*/
:root {
--body-background-fill: var(--background-fill-primary);
--body-text-color: var(--neutral-100);
--color-accent-soft: var(--neutral-700);
--background-fill-primary: var(--neutral-950);
--background-fill-secondary: var(--neutral-900);
--border-color-accent: var(--neutral-600);
--border-color-primary: var(--neutral-700);
--link-text-color-active: var(--secondary-500);
--link-text-color: var(--secondary-500);
--link-text-color-hover: var(--secondary-400);
--link-text-color-visited: var(--secondary-600);
--body-text-color-subdued: var(--neutral-400);
--shadow-spread: 1px;
--block-background-fill: var(--neutral-800);
--block-border-color: var(--border-color-primary);
--block_border_width: None;
--block-info-text-color: var(--body-text-color-subdued);
--block-label-background-fill: var(--background-fill-secondary);
--block-label-border-color: var(--border-color-primary);
--block_label_border_width: None;
--block-label-text-color: var(--neutral-200);
--block_shadow: None;
--block_title_background_fill: None;
--block_title_border_color: None;
--block_title_border_width: None;
--block-title-text-color: var(--neutral-200);
--panel-background-fill: var(--background-fill-secondary);
--panel-border-color: var(--border-color-primary);
--panel_border_width: None;
--checkbox-background-color: var(--neutral-800);
--checkbox-background-color-focus: var(--checkbox-background-color);
--checkbox-background-color-hover: var(--checkbox-background-color);
--checkbox-background-color-selected: var(--secondary-600);
--checkbox-border-color: var(--neutral-700);
--checkbox-border-color-focus: var(--secondary-500);
--checkbox-border-color-hover: var(--neutral-600);
--checkbox-border-color-selected: var(--secondary-600);
--checkbox-border-width: var(--input-border-width);
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
--checkbox-label-border-color: var(--border-color-primary);
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
--checkbox-label-border-width: var(--input-border-width);
--checkbox-label-text-color: var(--body-text-color);
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
--error-background-fill: var(--background-fill-primary);
--error-border-color: var(--border-color-primary);
--error_border_width: None;
--error-text-color: #ef4444;
--input-background-fill: var(--neutral-800);
--input-background-fill-focus: var(--secondary-600);
--input-background-fill-hover: var(--input-background-fill);
--input-border-color: var(--border-color-primary);
--input-border-color-focus: var(--neutral-700);
--input-border-color-hover: var(--input-border-color);
--input_border_width: None;
--input-placeholder-color: var(--neutral-500);
--input_shadow: None;
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
--loader_color: None;
--slider_color: None;
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
--table-border-color: var(--neutral-700);
--table-even-background-fill: var(--neutral-950);
--table-odd-background-fill: var(--neutral-900);
--table-row-focus: var(--color-accent-soft);
--button-border-width: var(--input-border-width);
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-cancel-border-color: #dc2626;
--button-cancel-border-color-hover: var(--button-cancel-border-color);
--button-cancel-text-color: white;
--button-cancel-text-color-hover: var(--button-cancel-text-color);
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
--button-primary-border-color: var(--primary-500);
--button-primary-border-color-hover: var(--button-primary-border-color);
--button-primary-text-color: white;
--button-primary-text-color-hover: var(--button-primary-text-color);
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-border-color: var(--neutral-600);
--button-secondary-border-color-hover: var(--button-secondary-border-color);
--button-secondary-text-color: white;
--button-secondary-text-color-hover: var(--button-secondary-text-color);
--block-border-width: 1px;
--block-label-border-width: 1px;
--form-gap-width: 1px;
--error-border-width: 1px;
--input-border-width: 1px;
}
/* SHARK theme */
/* display in full width for desktop devices */
@media (min-width: 1536px)
{
.gradio-container {
max-width: var(--size-full) !important;
}
}
.gradio-container .contain {
padding: 0 var(--size-4) !important;
}
.container {
background-color: black !important;
padding-top: var(--size-5) !important;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#top_logo {
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea, #negative_prompt_box textarea {
background-color: var(--background-fill-primary) !important;
}
#prompt_examples {
margin: 0 !important;
}
#prompt_examples svg {
display: none !important;
}
#ui_body {
padding: var(--size-2) !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}
#gallery + div {
border-radius: 0 !important;
}
/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */
#gallery .wrap.default {
pointer-events: none;
}
/* Import Png info box */
#txt2img_prompt_image .fixed-height {
height: var(--size-32);
}
/* Hide "remove buttons" from ui dropdowns */
#custom_model .token-remove.remove-all,
#scheduler .token-remove.remove-all,
#device .token-remove.remove-all,
#stencil_model .token-remove.remove-all {
display: none;
}
/* Hide selected items from ui dropdowns */
#custom_model .options .item .inner-item,
#scheduler .options .item .inner-item,
#device .options .item .inner-item,
#stencil_model .options .item .inner-item {
display:none;
}
/* Hide the download icon from the nod logo */
#top_logo .download {
display: none;
}

View File

@@ -0,0 +1,261 @@
from pathlib import Path
import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import img2img_inf
from apps.stable_diffusion.src import args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_models,
cancel_sd,
)
with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
img2img_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="Stencil Options", open=False):
with gr.Row():
use_stencil = gr.Dropdown(
elem_id="stencil_model",
label="Stencil model",
value="None",
choices=["None", "canny", "openpose", "scribble"],
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="PNDM",
choices=scheduler_list,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
strength = gr.Slider(
0,
1,
value=args.strength,
step=0.01,
label="Denoising Strength",
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
img2img_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
std_output = gr.Textbox(
value="Nothing to show.",
lines=1,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
img2img_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
img2img_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=img2img_inf,
inputs=[
prompt,
negative_prompt,
img2img_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
use_stencil,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
],
outputs=[img2img_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -0,0 +1,263 @@
from pathlib import Path
import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import inpaint_inf
from apps.stable_diffusion.src import args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_paint_models,
cancel_sd,
)
with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
choices=["None"]
+ get_custom_model_files()
+ predefined_paint_models,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
inpaint_init_image = gr.Image(
label="Masked Image",
source="upload",
tool="sketch",
type="pil",
).style(height=350)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="PNDM",
choices=scheduler_list,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
inpaint_full_res = gr.Radio(
choices=["Whole picture", "Only masked"],
type="index",
value="Whole picture",
label="Inpaint area",
)
inpaint_full_res_padding = gr.Slider(
minimum=0,
maximum=256,
step=4,
value=32,
label="Only masked padding, pixels",
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
inpaint_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
std_output = gr.Textbox(
value="Nothing to show.",
lines=1,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
inpaint_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
inpaint_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=inpaint_inf,
inputs=[
prompt,
negative_prompt,
inpaint_init_image,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
],
outputs=[inpaint_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

View File

@@ -0,0 +1,205 @@
from pathlib import Path
import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import lora_train
from apps.stable_diffusion.src import prompt_examples, args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_txt2img,
predefined_models,
)
with gr.Blocks(title="Lora Training") as lora_train_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Group(elem_id="image_dir_box_outer"):
training_images_dir = gr.Textbox(
label="ImageDirectory",
value=args.training_images_dir,
lines=1,
elem_id="prompt_box",
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=scheduler_list_txt2img,
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1,
2000,
value=args.training_steps,
step=1,
label="Training Steps",
)
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
train_lora = gr.Button("Train LoRA")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
std_output = gr.Textbox(
value="Nothing to show.",
lines=1,
show_label=False,
)
lora_save_dir = (
args.lora_save_dir if args.lora_save_dir else Path.cwd()
)
lora_save_dir = Path(lora_save_dir, "lora")
output_loc = gr.Textbox(
label="Saving Lora at",
value=lora_save_dir,
)
kwargs = dict(
fn=lora_train,
inputs=[
prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
training_images_dir,
output_loc,
],
outputs=[std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
train_click = train_lora.click(**kwargs)
stop_batch.click(fn=None, cancels=[prompt_submit, train_click])

View File

@@ -0,0 +1,283 @@
from pathlib import Path
import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import outpaint_inf
from apps.stable_diffusion.src import args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_paint_models,
cancel_sd,
)
with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
choices=["None"]
+ get_custom_model_files()
+ predefined_paint_models,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
outpaint_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="PNDM",
choices=scheduler_list,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
pixels = gr.Slider(
8,
256,
value=args.pixels,
step=8,
label="Pixels to expand",
)
mask_blur = gr.Slider(
0,
64,
value=args.mask_blur,
step=1,
label="Mask blur",
)
with gr.Row():
directions = gr.CheckboxGroup(
label="Outpainting direction",
choices=["left", "right", "up", "down"],
value=["left", "right", "up", "down"],
)
with gr.Row():
noise_q = gr.Slider(
0.0,
4.0,
value=1.0,
step=0.01,
label="Fall-off exponent (lower=higher detail)",
)
color_variation = gr.Slider(
0.0,
1.0,
value=0.05,
step=0.01,
label="Color variation",
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=20, step=1, label="Steps"
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
outpaint_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
std_output = gr.Textbox(
value="Nothing to show.",
lines=1,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
outpaint_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=outpaint_inf,
inputs=[
prompt,
negative_prompt,
outpaint_init_image,
pixels,
mask_blur,
directions,
noise_q,
color_variation,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
],
outputs=[outpaint_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -0,0 +1,279 @@
from pathlib import Path
import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import txt2img_inf
from apps.stable_diffusion.src import prompt_examples, args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_txt2img,
predefined_models,
cancel_sd,
)
with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Column(scale=1, min_width=170):
png_info_img = gr.Image(
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
tool="None",
visible=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=scheduler_list_txt2img,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
txt2img_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
std_output = gr.Textbox(
value="Nothing to show.",
lines=1,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
txt2img_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
txt2img_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
],
outputs=[txt2img_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
from apps.stable_diffusion.web.utils.png_metadata import (
import_png_metadata,
)
png_info_img.change(
fn=import_png_metadata,
inputs=[
png_info_img,
],
outputs=[
png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
custom_model,
hf_model_id,
],
)

View File

@@ -0,0 +1,239 @@
from pathlib import Path
import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import upscaler_inf
from apps.stable_diffusion.src import args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_upscaler_models,
)
with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
choices=["None"]
+ get_custom_model_files()
+ predefined_upscaler_models,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
upscaler_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="DDIM",
choices=scheduler_list,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
128,
512,
value=args.height,
step=128,
label="Height",
)
width = gr.Slider(
128,
512,
value=args.width,
step=128,
label="Width",
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
noise_level = gr.Slider(
0,
100,
value=args.noise_level,
step=1,
label="Noise Level",
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
upscaler_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
std_output = gr.Textbox(
value="Nothing to show.",
lines=1,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
upscaler_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
kwargs = dict(
fn=upscaler_inf,
inputs=[
prompt,
negative_prompt,
upscaler_init_image,
height,
width,
steps,
noise_level,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[upscaler_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
)

View File

@@ -0,0 +1,105 @@
import os
import sys
from apps.stable_diffusion.src import get_available_devices
import glob
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
@dataclass
class Config:
mode: str
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
use_lora: str
use_stencil: str
custom_model_filetypes = (
"*.ckpt",
"*.safetensors",
) # the tuple of file types
scheduler_list = [
"DDIM",
"PNDM",
"DPMSolverMultistep",
"EulerAncestralDiscrete",
]
scheduler_list_txt2img = [
"DDIM",
"PNDM",
"LMSDiscrete",
"KDPM2Discrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
]
predefined_models = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]
predefined_paint_models = [
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
]
predefined_upscaler_models = [
"stabilityai/stable-diffusion-x4-upscaler",
]
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def get_custom_model_path():
return Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd(), "models")
def get_custom_model_pathfile(custom_model_name):
return os.path.join(get_custom_model_path(), custom_model_name)
def get_custom_model_files():
ckpt_files = []
for extn in custom_model_filetypes:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_custom_model_path(), extn))
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def cancel_sd():
# Try catch it, as gc can delete global_obj.sd_obj while switching model
try:
global_obj.set_sd_status(SD_STATE_CANCEL)
except Exception:
pass
nodlogo_loc = resource_path("logos/nod-logo.png")
available_devices = get_available_devices()

View File

@@ -0,0 +1,56 @@
import gc
"""
The global objects include SD pipeline and config.
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
Also we could avoid memory leak when switching models by clearing the cache.
"""
def init():
global sd_obj
global config_obj
sd_obj = None
config_obj = None
def set_sd_obj(value):
global sd_obj
sd_obj = value
def set_cfg_obj(value):
global config_obj
config_obj = value
def set_schedulers(value):
global sd_obj
sd_obj.scheduler = value
def get_sd_obj():
return sd_obj
def get_cfg_obj():
return config_obj
def set_sd_status(value):
global sd_obj
sd_obj.status = value
def get_sd_status():
global sd_obj
return sd_obj.status
def clear_cache():
global sd_obj
global config_obj
del sd_obj
del config_obj
gc.collect()

View File

@@ -0,0 +1,31 @@
import os
import tempfile
import gradio
from os import listdir
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
# Clear all gradio tmp images
def clear_gradio_tmp_imgs_folder():
if not os.path.exists(gradio_tmp_imgs_folder):
return
for fileName in listdir(gradio_tmp_imgs_folder):
# Delete tmp png files
if fileName.startswith("tmp") and fileName.endswith(".png"):
os.remove(gradio_tmp_imgs_folder + fileName)
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
def save_pil_to_file(pil_image, dir=None):
if not os.path.exists(gradio_tmp_imgs_folder):
os.mkdir(gradio_tmp_imgs_folder)
file_obj = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
)
pil_image.save(file_obj)
return file_obj
# Register save_pil_to_file override
gradio.processing_utils.save_pil_to_file = save_pil_to_file

View File

@@ -0,0 +1,148 @@
import re
from pathlib import Path
from apps.stable_diffusion.web.ui.txt2img_ui import (
png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
custom_model,
hf_model_id,
)
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
scheduler_list_txt2img,
predefined_models,
)
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
def parse_generation_parameters(x: str):
res = {}
prompt = ""
negative_prompt = ""
done_with_prompt = False
*lines, lastline = x.strip().split("\n")
if len(re_param.findall(lastline)) < 3:
lines.append(lastline)
lastline = ""
for i, line in enumerate(lines):
line = line.strip()
if line.startswith("Negative prompt:"):
done_with_prompt = True
line = line[16:].strip()
if done_with_prompt:
negative_prompt += ("" if negative_prompt == "" else "\n") + line
else:
prompt += ("" if prompt == "" else "\n") + line
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v)
if m is not None:
res[k + "-1"] = m.group(1)
res[k + "-2"] = m.group(2)
else:
res[k] = v
# Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res:
res["Clip skip"] = "1"
hypernet = res.get("Hypernet", None)
if hypernet is not None:
res[
"Prompt"
] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
res["Hires resize-2"] = 0
return res
def import_png_metadata(pil_data):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
png_hf_model_id = ""
png_custom_model = ""
if "Model" in metadata:
# Remove extension from model info
if metadata["Model"].endswith(".safetensors") or metadata[
"Model"
].endswith(".ckpt"):
metadata["Model"] = Path(metadata["Model"]).stem
# Check for the model name match with one of the local ckpt or safetensors files
if Path(
get_custom_model_pathfile(metadata["Model"] + ".ckpt")
).is_file():
png_custom_model = metadata["Model"] + ".ckpt"
if Path(
get_custom_model_pathfile(metadata["Model"] + ".safetensors")
).is_file():
png_custom_model = metadata["Model"] + ".safetensors"
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if metadata["Model"] in predefined_models:
png_custom_model = metadata["Model"]
# If nothing had matched, check vendor/hf_model_id
if not png_custom_model and metadata["Model"].count("/"):
png_hf_model_id = metadata["Model"]
# No matching model was found
if not png_custom_model and not png_hf_model_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% metadata["Model"]
)
outputs = {
png_info_img: None,
negative_prompt: metadata["Negative prompt"],
steps: int(metadata["Steps"]),
guidance_scale: float(metadata["CFG scale"]),
seed: int(metadata["Seed"]),
width: float(metadata["Size-1"]),
height: float(metadata["Size-2"]),
}
if "Model" in metadata and png_custom_model:
outputs[custom_model] = png_custom_model
outputs[hf_model_id] = ""
if "Model" in metadata and png_hf_model_id:
outputs[custom_model] = "None"
outputs[hf_model_id] = png_hf_model_id
if "Prompt" in metadata:
outputs[prompt] = metadata["Prompt"]
if "Sampler" in metadata:
if metadata["Sampler"] in scheduler_list_txt2img:
outputs[scheduler] = metadata["Sampler"]
else:
print(
"Import PNG info: Unable to find a scheduler for %s"
% metadata["Sampler"]
)
return outputs
except Exception as ex:
if pil_data and pil_data.info.get("parameters"):
print("import_png_metadata failed with %s" % ex)
pass
return {
png_info_img: None,
}

View File

@@ -129,12 +129,12 @@ pytest_benchmark_param = pytest.mark.parametrize(
pytest.param(True, "cpu", marks=pytest.mark.skip),
pytest.param(
False,
"cuda",
"gpu",
marks=pytest.mark.skipif(
check_device_drivers("cuda"), reason="nvidia-smi not found"
check_device_drivers("gpu"), reason="nvidia-smi not found"
),
),
pytest.param(True, "cuda", marks=pytest.mark.skip),
pytest.param(True, "gpu", marks=pytest.mark.skip),
pytest.param(
False,
"vulkan",

View File

@@ -1,88 +0,0 @@
ARG IMAGE_NAME
FROM ${IMAGE_NAME}:12.2.0-runtime-ubuntu22.04 as base
ENV NV_CUDA_LIB_VERSION "12.2.0-1"
FROM base as base-amd64
ENV NV_CUDA_CUDART_DEV_VERSION 12.2.53-1
ENV NV_NVML_DEV_VERSION 12.2.81-1
ENV NV_LIBCUSPARSE_DEV_VERSION 12.1.1.53-1
ENV NV_LIBNPP_DEV_VERSION 12.1.1.14-1
ENV NV_LIBNPP_DEV_PACKAGE libnpp-dev-12-2=${NV_LIBNPP_DEV_VERSION}
ENV NV_LIBCUBLAS_DEV_VERSION 12.2.1.16-1
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME libcublas-dev-12-2
ENV NV_LIBCUBLAS_DEV_PACKAGE ${NV_LIBCUBLAS_DEV_PACKAGE_NAME}=${NV_LIBCUBLAS_DEV_VERSION}
ENV NV_CUDA_NSIGHT_COMPUTE_VERSION 12.2.0-1
ENV NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE cuda-nsight-compute-12-2=${NV_CUDA_NSIGHT_COMPUTE_VERSION}
ENV NV_NVPROF_VERSION 12.2.60-1
ENV NV_NVPROF_DEV_PACKAGE cuda-nvprof-12-2=${NV_NVPROF_VERSION}
FROM base as base-arm64
ENV NV_CUDA_CUDART_DEV_VERSION 12.2.53-1
ENV NV_NVML_DEV_VERSION 12.2.81-1
ENV NV_LIBCUSPARSE_DEV_VERSION 12.1.1.53-1
ENV NV_LIBNPP_DEV_VERSION 12.1.1.14-1
ENV NV_LIBNPP_DEV_PACKAGE libnpp-dev-12-2=${NV_LIBNPP_DEV_VERSION}
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME libcublas-dev-12-2
ENV NV_LIBCUBLAS_DEV_VERSION 12.2.1.16-1
ENV NV_LIBCUBLAS_DEV_PACKAGE ${NV_LIBCUBLAS_DEV_PACKAGE_NAME}=${NV_LIBCUBLAS_DEV_VERSION}
ENV NV_CUDA_NSIGHT_COMPUTE_VERSION 12.2.0-1
ENV NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE cuda-nsight-compute-12-2=${NV_CUDA_NSIGHT_COMPUTE_VERSION}
FROM base-${TARGETARCH}
ARG TARGETARCH
LABEL maintainer "SHARK<stdin@nod.com>"
# Register the ROCM package repository, and install rocm-dev package
ARG ROCM_VERSION=5.6
ARG AMDGPU_VERSION=5.6
ARG APT_PREF
RUN echo "$APT_PREF" > /etc/apt/preferences.d/rocm-pin-600
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends ca-certificates curl libnuma-dev gnupg \
&& curl -sL https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - \
&& printf "deb [arch=amd64] https://repo.radeon.com/rocm/apt/$ROCM_VERSION/ jammy main" | tee /etc/apt/sources.list.d/rocm.list \
&& printf "deb [arch=amd64] https://repo.radeon.com/amdgpu/$AMDGPU_VERSION/ubuntu jammy main" | tee /etc/apt/sources.list.d/amdgpu.list \
&& apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
sudo \
libelf1 \
kmod \
file \
python3 \
python3-pip \
rocm-dev \
rocm-libs \
rocm-hip-libraries \
build-essential && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN groupadd -g 109 render
RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cudart-dev-12-2=${NV_CUDA_CUDART_DEV_VERSION} \
cuda-command-line-tools-12-2=${NV_CUDA_LIB_VERSION} \
cuda-minimal-build-12-2=${NV_CUDA_LIB_VERSION} \
cuda-libraries-dev-12-2=${NV_CUDA_LIB_VERSION} \
cuda-nvml-dev-12-2=${NV_NVML_DEV_VERSION} \
${NV_NVPROF_DEV_PACKAGE} \
${NV_LIBNPP_DEV_PACKAGE} \
libcusparse-dev-12-2=${NV_LIBCUSPARSE_DEV_VERSION} \
${NV_LIBCUBLAS_DEV_PACKAGE} \
${NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE} \
&& rm -rf /var/lib/apt/lists/*
RUN apt install rocm-hip-libraries
# Keep apt from auto upgrading the cublas and nccl packages. See https://gitlab.com/nvidia/container-images/cuda/-/issues/88
RUN apt-mark hold ${NV_LIBCUBLAS_DEV_PACKAGE_NAME}
ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs

View File

@@ -1,41 +0,0 @@
On your host install your Nvidia or AMD gpu drivers.
**HOST Setup**
*Ubuntu 23.04 Nvidia*
```
sudo ubuntu-drivers install
```
Install [docker](https://docs.docker.com/engine/install/ubuntu/) and the post-install to run as a [user](https://docs.docker.com/engine/install/linux-postinstall/)
Install Nvidia [Container and register it](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). In Ubuntu 23.04 systems follow [this](https://github.com/NVIDIA/nvidia-container-toolkit/issues/72#issuecomment-1584574298)
Build docker with :
```
docker build . -f Dockerfile-ubuntu-22.04 -t shark/dev-22.04:5.6 --build-arg=ROCM_VERSION=5.6 --build-arg=AMDGPU_VERSION=5.6 --build-arg=APT_PREF="Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600" --build-arg=IMAGE_NAME=nvidia/cuda --build-arg=TARGETARCH=amd64
```
Run with:
*CPU*
```
docker run -it docker.io/shark/dev-22.04:5.6
```
*Nvidia GPU*
```
docker run --rm -it --gpus all docker.io/shark/dev-22.04:5.6
```
*AMD GPUs*
```
docker run --device /dev/kfd --device /dev/dri docker.io/shark/dev-22.04:5.6
```
More AMD instructions are [here](https://docs.amd.com/en/latest/deploy/docker.html)

View File

@@ -24,13 +24,13 @@ def get_image(url, local_filename):
shutil.copyfileobj(res.raw, f)
def compare_images(new_filename, golden_filename, upload=False):
def compare_images(new_filename, golden_filename):
new = np.array(Image.open(new_filename)) / 255.0
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.1:
if os.name != "nt" and upload == True:
if os.name != "nt":
subprocess.run(
[
"gsutil",
@@ -39,7 +39,7 @@ def compare_images(new_filename, golden_filename, upload=False):
"gs://shark_tank/testdata/builder/",
]
)
raise AssertionError("new and golden not close")
raise SystemExit("new and golden not close")
else:
print("SUCCESS")

View File

@@ -1,6 +1,5 @@
#!/bin/bash
IMPORTER=1 BENCHMARK=1 NO_BREVITAS=1 ./setup_venv.sh
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python build_tools/stable_diffusion_testing.py --gen
python tank/generate_sharktank.py
python generate_sharktank.py

View File

@@ -36,7 +36,9 @@ def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir):
metrics[val] = line.split(" ")[-1].strip("\n")
metrics["Average step"] = metrics["Average step"].strip("ms/it")
metrics["Total image generation"] = metrics["Total image generation"].strip("sec")
metrics["Total image generation"] = metrics[
"Total image generation"
].strip("sec")
metrics["device"] = device
metrics["use_tune"] = use_tune
metrics["model_name"] = model_name
@@ -61,14 +63,7 @@ def get_inpaint_inputs():
open("./test_images/inputs/mask.png", "wb").write(mask.content)
def test_loop(
device="vulkan",
beta=False,
extra_flags=[],
upload_bool=True,
exit_on_fail=True,
do_gen=False,
):
def test_loop(device="vulkan", beta=False, extra_flags=[]):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
model_metrics = []
@@ -76,49 +71,27 @@ def test_loop(
os.mkdir("./test_images/golden")
get_inpaint_inputs()
hf_model_names = model_config_dicts[0].values()
tuned_options = [
"--no-use_tuned",
"--use_tuned",
]
tuned_options = ["--no-use_tuned", "--use_tuned"]
import_options = ["--import_mlir", "--no-import_mlir"]
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
inpaint_prompt_text = (
"--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
)
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
if os.name == "nt":
prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
inpaint_prompt_text = (
'--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
)
inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
if beta:
extra_flags.append("--beta_models=True")
extra_flags.append("--no-progress_bar")
if do_gen:
extra_flags.append("--import_debug")
to_skip = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]
counter = 0
for import_opt in import_options:
for model_name in hf_model_names:
if model_name in to_skip:
continue
for use_tune in tuned_options:
if (
model_name == "stabilityai/stable-diffusion-2-1"
and use_tune == tuned_options[0]
):
continue
elif (
model_name == "stabilityai/stable-diffusion-2-1-base"
and use_tune == tuned_options[1]
):
continue
elif use_tune == tuned_options[1]:
continue
command = (
[
executable, # executable is the python from the venv used to run this
@@ -176,7 +149,9 @@ def test_loop(
)
print(command)
print("Successfully generated image")
os.makedirs("./test_images/golden/" + model_name, exist_ok=True)
os.makedirs(
"./test_images/golden/" + model_name, exist_ok=True
)
download_public_file(
"gs://shark_tank/testdata/golden/" + model_name,
"./test_images/golden/" + model_name,
@@ -191,35 +166,17 @@ def test_loop(
)
test_file = glob(test_file_path)[0]
golden_path = "./test_images/golden/" + model_name + "/*.png"
golden_path = (
"./test_images/golden/" + model_name + "/*.png"
)
golden_file = glob(golden_path)[0]
try:
compare_images(test_file, golden_file, upload=upload_bool)
except AssertionError as e:
print(e)
if exit_on_fail == True:
raise
compare_images(test_file, golden_file)
else:
print(command)
print("failed to generate image for this configuration")
with open(dumpfile_name, "r+") as f:
output = f.readlines()
print("\n".join(output))
exit(1)
if os.name == "nt":
counter += 1
if counter % 2 == 0:
extra_flags.append(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
else:
if counter != 1:
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
if do_gen:
prepare_artifacts()
if "2_1_base" in model_name:
print("failed a known successful model.")
exit(1)
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)
@@ -238,47 +195,15 @@ def test_loop(
f.write(";".join(output) + "\n")
def prepare_artifacts():
gen_path = os.path.join(os.getcwd(), "gen_shark_tank")
if not os.path.isdir(gen_path):
os.mkdir(gen_path)
for dirname in os.listdir(os.getcwd()):
for modelname in ["clip", "unet", "vae"]:
if modelname in dirname and "vmfb" not in dirname:
if not os.path.isdir(os.path.join(gen_path, dirname)):
shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
print(f"Moved dir: {dirname} to {gen_path}.")
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument(
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument("-e", "--extra_args", type=str, default=None)
parser.add_argument(
"-u", "--upload", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument("-g", "--gen", action=argparse.BooleanOptionalAction, default=False)
if __name__ == "__main__":
args = parser.parse_args()
print(args)
extra_args = []
if args.extra_args:
for arg in args.extra_args.split(","):
extra_args.append(arg)
test_loop(
args.device,
args.beta,
extra_args,
args.upload,
args.exit_on_fail,
args.gen,
)
if args.gen:
prepare_artifacts()
test_loop(args.device, args.beta, [])

View File

@@ -1,14 +0,0 @@
import os
from sys import executable
import subprocess
from apps.language_models.scripts import vicuna
def test_loop():
precisions = ["fp16", "int8", "int4"]
devices = ["cpu"]
for precision in precisions:
for device in devices:
model = vicuna.UnshardedVicuna(device=device, precision=precision)
model.compile()
del model

View File

@@ -2,11 +2,9 @@ def pytest_addoption(parser):
# Attaches SHARK command-line arguments to the pytest machinery.
parser.addoption(
"--benchmark",
action="store",
type=str,
default=None,
choices=("baseline", "native", "all"),
help="Benchmarks specified engine(s) and writes bench_results.csv.",
action="store_true",
default="False",
help="Pass option to benchmark and write results.csv",
)
parser.addoption(
"--onnx_bench",
@@ -42,13 +40,7 @@ def pytest_addoption(parser):
"--update_tank",
action="store_true",
default="False",
help="Update local shark tank with latest artifacts if model artifact hash mismatched.",
)
parser.addoption(
"--force_update_tank",
action="store_true",
default="False",
help="Force-update local shark tank with artifacts from specified shark_tank URL (defaults to nightly).",
help="Update local shark tank with latest artifacts.",
)
parser.addoption(
"--ci_sha",
@@ -59,21 +51,15 @@ def pytest_addoption(parser):
parser.addoption(
"--local_tank_cache",
action="store",
default=None,
default="",
help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
)
parser.addoption(
"--tank_url",
type=str,
default="gs://shark_tank/nightly",
default="gs://shark_tank/latest",
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
)
parser.addoption(
"--tank_prefix",
type=str,
default=None,
help="Prefix to gs://shark_tank/ model directories from which to download SHARK tank artifacts. Default is nightly.",
)
parser.addoption(
"--benchmark_dispatches",
default=None,
@@ -84,9 +70,3 @@ def pytest_addoption(parser):
default="./temp_dispatch_benchmarks",
help="Directory in which dispatch benchmarks are saved.",
)
parser.addoption(
"--batchsize",
default=1,
type=int,
help="Batch size for the tested model.",
)

View File

@@ -27,7 +27,7 @@ include(FetchContent)
FetchContent_Declare(
iree
GIT_REPOSITORY https://github.com/nod-ai/srt.git
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
GIT_TAG shark
GIT_SUBMODULES_RECURSE OFF
GIT_SHALLOW OFF

View File

@@ -40,7 +40,7 @@ cmake --build build/
*Prepare the model*
```bash
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux resnet50_tf.mlir -o resnet50_tf.vmfb
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
```
*Prepare the input*
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
```bash
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux stable_diff_tf.mlir -o stable_diff_tf.vmfb
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
```
VAE and Autoencoder are also available
```bash
# VAE
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux vae.mlir -o vae.vmfb
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
# CLIP Autoencoder
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux clip_autoencoder.mlir -o clip_autoencoder.vmfb
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
```

View File

@@ -21,7 +21,7 @@ endif()
# Compile mnist.mlir to mnist.vmfb.
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
set(_COMPILE_ARGS)
list(APPEND _COMPILE_ARGS "--iree-input-type=auto")
list(APPEND _COMPILE_ARGS "--iree-input-type=mhlo")
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
list(APPEND _COMPILE_ARGS "-o")

View File

@@ -10,7 +10,9 @@ from utils import get_datasets
shark_root = Path(__file__).parent.parent
demo_css = shark_root.joinpath("web/demo.css").resolve()
nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/nod-logo.png")
nodlogo_loc = shark_root.joinpath(
"web/models/stable_diffusion/logos/nod-logo.png"
)
with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
@@ -21,11 +23,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
value=nod_logo,
show_label=False,
interactive=False,
show_download_button=False,
elem_id="top_logo",
width=150,
height=100,
)
).style(width=150, height=100)
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
prompt_data = dict()
@@ -38,7 +37,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_body"):
# TODO: add ability to search image by typing
with gr.Column(scale=1, min_width=600):
image = gr.Image(type="filepath", height=512)
image = gr.Image(type="filepath").style(height=512)
with gr.Column(scale=1, min_width=600):
prompts = gr.Dropdown(
@@ -74,7 +73,9 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
for line in reader.iter(type=dict, skip_invalid=True):
prompt_data[line["file_name"]] = (
[line["text"]] if type(line["text"]) is str else line["text"]
[line["text"]]
if type(line["text"]) is str
else line["text"]
)
return gr.Dropdown.update(choices=images[dataset])
@@ -100,7 +101,9 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
prompt_data[image_name] = []
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Image.update(value=img), gr.Dropdown.update(choices=prompt_choices)
return gr.Image.update(value=img), gr.Dropdown.update(
choices=prompt_choices
)
image_name.change(
fn=display_image,
@@ -117,7 +120,12 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt)
def save_prompt(dataset, image_name, prompts, prompt):
if dataset is None or image_name is None or prompts is None or prompt is None:
if (
dataset is None
or image_name is None
or prompts is None
or prompt is None
):
return
if prompts == "Add new":
@@ -126,7 +134,9 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
idx = prompt_data[image_name].index(prompts)
prompt_data[image_name][idx] = prompt
prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
prompt_path = (
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
)
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
@@ -153,7 +163,9 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
return
prompt_data[image_name].remove(prompts)
prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
prompt_path = (
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
)
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
@@ -216,7 +228,9 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
# upload prompt and remove local data
dataset_path = str(shark_root) + "/dataset/" + dataset
dataset_gs_path = args.gs_url + "/" + dataset + "/"
os.system(f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"')
os.system(
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
)
os.system(f'rm -rf "{dataset_path}"')
return gr.Dropdown.update(value=None)

View File

@@ -1,3 +1,3 @@
# SHARK Annotator
gradio==3.34.0
gradio==3.15.0
jsonlines

View File

@@ -55,7 +55,7 @@ The command line for compilation will start something like this, where the `-` n
The `-o output_filename.vmfb` flag can be used to specify the location to save the compiled vmfb. Note that a dump of the
dispatches that can be compiled + run in isolation can be generated by adding `--iree-hal-dump-executable-benchmarks-to=/some/directory`. Say, if they are in the `benchmarks` directory, the following compile/run commands would work for Vulkan on RDNA3.
```
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-benchmark-module --module=benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb --function=forward --device=vulkan
```
@@ -63,8 +63,8 @@ Where `${NUM}` is the dispatch number that you want to benchmark/profile in isol
### Enabling Tracy for Vulkan profiling
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime (SRT) builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SRT/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SHARK-Runtime/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
```
$env:IREE_ENABLE_RUNTIME_TRACING="ON"
```

View File

@@ -1,75 +0,0 @@
# Overview
This document is intended to provide a starting point for using SHARK stable diffusion with Blender.
We currently make use of the [AI-Render Plugin](https://github.com/benrugg/AI-Render) to integrate with Blender.
## Setup SHARK and prerequisites:
* Download the latest SHARK SD webui .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow instructions on the [README](https://github.com/nod-ai/SHARK#readme)
* Once you have the .exe where you would like SHARK to install, run the .exe from terminal/PowerShell with the `--api` flag:
```
## Run the .exe in API mode:
.\shark_sd_<date>_<ver>.exe --api
## For example:
.\shark_sd_20230411_671.exe --api --server_port=8082
## From a the base directory of a source clone of SHARK:
./setup_venv.ps1
python apps\stable_diffusion\web\index.py --api
```
Your local SD server should start and look something like this:
![image](https://user-images.githubusercontent.com/87458719/231369758-e2c3c45a-eccc-4fe5-a788-4a3bf1ace1d1.png)
* Note: When running in api mode with `--api`, the .exe will not function as a webUI. Thus, the address in the terminal output will only be useful for API requests.
### Install AI Render
- Get AI Render on [Blender Market](https://blendermarket.com/products/ai-render) or [Gumroad](https://airender.gumroad.com/l/ai-render)
- Open Blender, then go to Edit > Preferences > Add-ons > Install and then find the zip file
- We will be using the Automatic1111 SD backend for the AI-Render plugin. Follow instructions [here](https://github.com/benrugg/AI-Render/wiki/Local-Installation) to setup local SD backend.
Your AI-Render preferences should be configured as shown; the highlighted part should match your terminal output:
![image](https://user-images.githubusercontent.com/87458719/231390322-59a54a09-520a-4a08-b658-6e37bd63e932.png)
The [AI-Render README](https://github.com/benrugg/AI-Render/blob/main/README.md) has more details on installation and usage, as well as video tutorials.
## Using AI-Render + SHARK in your Blender project
- In the Render Properties tab, in the AI-Render dropdown, enable AI-Render.
![image](https://user-images.githubusercontent.com/87458719/231392843-9bd51744-3ce2-464e-843a-0c4d4c96df0c.png)
- Select an image size (it's usually better to upscale later than go high on the img2img resolution here.)
![image](https://user-images.githubusercontent.com/87458719/231394288-0c4ab8c5-dc30-4dbe-8bc1-7520ded5efe8.png)
- From here, you can enter a prompt and configure img2img Stable Diffusion parameters, and AI-Render will run SHARK SD img2img on the rendered scene.
- AI-Render has useful presets for aesthetic styles, so you should be able to keep your subject prompt simple and focus on creating a decent Blender scene to start from.
![image](https://user-images.githubusercontent.com/87458719/231440729-2fe69586-41cb-4274-9ce7-f6c08def600b.png)
## Examples:
Scene (Input image):
![blender-sample-2](https://user-images.githubusercontent.com/87458719/231450408-0e680086-3e52-4962-a5c1-c703a94d1583.png)
Prompt:
"A bowl of tangerines in front of rocks, masterpiece, oil on canvas, by Georgia O'Keefe, trending on artstation, landscape painting by Caspar David Friedrich"
Negative Prompt (default):
"ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
Example output:
![blender-sample-2_out](https://user-images.githubusercontent.com/87458719/231451145-a0b56897-a7d0-4add-bbed-7e8af21a65df.png)

View File

@@ -1,140 +0,0 @@
# Overview
In [1.47.2](https://github.com/LostRuins/koboldcpp/releases/tag/v1.47.2) [Koboldcpp](https://github.com/LostRuins/koboldcpp) added AUTOMATIC1111 integration for image generation. Since SHARK implements a small subset of the A1111 REST api, you can also use SHARK for this. This document gives a starting point for how to get this working.
## In Action
![preview](https://user-images.githubusercontent.com/121311569/280557602-bb97bad0-fdf5-4922-a2cc-4f327f2760db.jpg)
## Memory considerations
Since both Koboldcpp and SHARK will use VRAM on your graphic card(s) running both at the same time using the same card will impose extra limitations on the model size you can fully offload to the video card in Koboldcpp. For me, on a RX 7900 XTX on Windows with 24 GiB of VRAM, the limit was about a 13 Billion parameter model with Q5_K_M quantisation.
## Performance Considerations
When using SHARK for image generation, especially with Koboldcpp, you need to be aware that it is currently designed to pay a large upfront cost in time compiling and tuning the model you select, to get an optimal individual image generation time. You need to be the judge as to whether this trade-off is going to be worth it for your OS and hardware combination.
It means that the first time you run a particular Stable Diffusion model for a particular combination of image size, LoRA, and VAE, SHARK will spend *many minutes* - even on a beefy machaine with very fast graphics card with lots of memory - building that model combination just so it can save it to disk. It may even have to go away and download the model if it doesn't already have it locally. Once it has done its build of a model combination for your hardware once, it shouldn't need to do it again until you upgrade to a newer SHARK version, install different drivers or change your graphics hardware. It will just upload the files it generated the first time to your graphics card and proceed from there.
This does mean however, that on a brand new fresh install of SHARK that has not generated any images on a model you haven't selected before, the first image Koboldcpp requests may look like it is *never* going finish and that the whole process has broken. Be forewarned, make yourself a cup of coffee, and expect a lot of messages about compilation and tuning from SHARK in the terminal you ran it from.
## Setup SHARK and prerequisites:
* Make sure you have suitable drivers for your graphics card installed. See the prerequisties section of the [README](https://github.com/nod-ai/SHARK#readme).
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK default of `8080`, also include both the `--api_accept_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the if you want to run SHARK on a different port)
```powershell
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_accept_origin="*" --server_port=7860
## Run trom the base directory of a source clone of SHARK on Windows:
.\setup_venv.ps1
python .\apps\stable_diffusion\web\index.py --api --api_accept_origin="*" --server_port=7860
## Run a the base directory of a source clone of SHARK on Linux:
./setup_venv.sh
source shark.venv/bin/activate
python ./apps/stable_diffusion/web/index.py --api --api_accept_origin="*" --server_port=7860
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860
## Since the api respects most applicable SHARK command line arguments for options not specified,
## or currently unimplemented by API, there might be some you want to set, as listed in `--help`
.\node_ai_shark_studio_20320901_2525.exe --help
## For instance, the example above, but with a a custom VAE specified
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
## An example with multiple specific CORS origins
python apps/stable_diffusion/web/index.py --api --api_accept_origin="koboldcpp.example.com:7001" --api_accept_origin="koboldcpp.example.com:7002" --server_port=7860
```
SHARK should start in server mode, and you should see something like this:
![SHARK API startup](https://user-images.githubusercontent.com/121311569/280556294-c3f7fc1a-c8e2-467d-afe6-365638d6823a.png)
* Note: When running in api mode with `--api`, the .exe will not function as a webUI. Thus, the address or port shown in the terminal output will only be useful for API requests.
## Configure Koboldcpp for local image generation:
* Get the latest [Koboldcpp](https://github.com/LostRuins/koboldcpp/releases) if you don't already have it. If you have a recent AMD card that has ROCm HIP [support for Windows](https://rocmdocs.amd.com/en/latest/release/windows_support.html#windows-supported-gpus) or [support for Linux](https://rocmdocs.amd.com/en/latest/release/gpu_os_support.html#linux-supported-gpus), you'll likely prefer [YellowRosecx's ROCm fork](https://github.com/YellowRoseCx/koboldcpp-rocm).
* Start Koboldcpp in another terminal/Powershell and setup your model configuration. Refer to the [Koboldcpp README](https://github.com/YellowRoseCx/koboldcpp-rocm) for more details on how to do this if this is your first time using Koboldcpp.
* Once the main UI has loaded into your browser click the settings button, go to the advanced tab, and then choose *Local A1111* from the generate images dropdown:
![Settings button location](https://user-images.githubusercontent.com/121311569/280556246-10692d79-e89f-4fdf-87ba-82f3d78ed49d.png)
![Advanced Settings with 'Local A1111' location](https://user-images.githubusercontent.com/121311569/280556234-6ebc8ba7-1469-442a-93a7-5626a094ddf1.png)
*if you get an error here, see the next section [below](#connecting-to-shark-on-a-different-address-or-port)*
* A list of Stable Diffusion models available to your SHARK instance should now be listed in the box below *generate images*. The default value will usually be set to `stabilityai/stable-diffusion-2-1-base`. Choose the model you want to use for image generation from the list (but see [performance considerations](#performance-considerations)).
* You should now be ready to generate images, either by clicking the 'Add Img' button above the text entry box:
![Add Image Button](https://user-images.githubusercontent.com/121311569/280556161-846c7883-4a83-4458-a56a-bd9f93ca354c.png)
...or by selecting the 'Autogenerate' option in the settings:
![Setting the autogenerate images option](https://user-images.githubusercontent.com/121311569/280556230-ae221a46-ba68-499b-a519-c8f290bbbeae.png)
*I often find that even if I have selected autogenerate I have to do an 'add img' to get things started off*
* There is one final piece of image generation configuration within Koboldcpp you might want to do. This is also in the generate images section of advanced settings. Here there is, not very obviously, a 'style' button:
![Selecting the 'styles' button](https://user-images.githubusercontent.com/121311569/280556694-55cd1c55-a059-4b54-9293-63d66a32368e.png)
This will bring up a dialog box where you can enter a short text that will sent as a prefix to the Prompt sent to SHARK:
![Entering extra image styles](https://user-images.githubusercontent.com/121311569/280556172-4aab9794-7a77-46d7-bdda-43df570ad19a.png)
## Connecting to SHARK on a different address or port
If you didn't set the port to `--server_port=7860` when starting SHARK, or you are running it on different machine on your network than you are running Koboldcpp, or to where you are running the koboldcpp's kdlite client frontend, then you very likely got the following error:
![Can't find the A1111 endpoint error](https://user-images.githubusercontent.com/121311569/280555857-601f53dc-35e9-4027-9180-baa61d2393ba.png)
As long as SHARK is running correctly, this means you need to set the url and port to the correct values in Koboldcpp. For instance. to set the port that Koboldcpp looks for an image generator to SHARK's default port of 8080:
* Select the cog icon the Generate Images section of Advanced settings:
![Selecting the endpoint cog](https://user-images.githubusercontent.com/121311569/280555866-4287ecc5-f29f-4c03-8f5a-abeaf31b0442.png)
* Then edit the port number at the end of the url in the 'A1111 Endpoint Selection' dialog box to read 8080:
![Changing the endpoint port](https://user-images.githubusercontent.com/121311569/280556170-f8848b7b-6fc9-4cf7-80eb-5c312f332fd9.png)
* Similarly, when running SHARK on a different machine you will need to change host part of the endpoint url to the hostname or ip address where SHARK is running, similarly:
![Changing the endpoint hostname](https://user-images.githubusercontent.com/121311569/280556167-c6541dea-0f85-417a-b661-fdf4dc40d05f.png)
## Examples
Here's how Koboldcpp shows an image being requested:
![An image being generated]((https://user-images.githubusercontent.com/121311569/280556210-bb1c9efd-79ac-478e-b726-b25b82ef2186.png)
The generated image in context in story mode:
![A generated image](https://user-images.githubusercontent.com/121311569/280556179-4e9f3752-f349-4cba-bc6a-f85f8dc79b10.jpg)
And the same image when clicked on:
![A selected image](https://user-images.githubusercontent.com/121311569/280556216-2ca4c0a4-3889-4ef5-8a09-30084fb34081.jpg)
## Where to find the images in SHARK
Even though Koboldcpp requests images at a size of 512x512, it resizes then to 256x256, converts them to `.jpeg`, and only shows them at 200x200 in the main text window. It does this so it can save them compactly embedded in your story as a `data://` uri.
However the images at the original size are saved by SHARK in its `output_dir` which is usually a folder named for the current date. inside `generated_imgs` folder in the SHARK installation directory.
You can browse these, either using the Output Gallery tab from within the SHARK web ui:
![SHARK web ui output gallery tab](https://user-images.githubusercontent.com/121311569/280556582-9303ca85-2594-4a8c-97a2-fbd72337980b.jpg)
...or by browsing to the `output_dir` in your operating system's file manager:
![SHARK output directory subfolder in Windows File Explorer](https://user-images.githubusercontent.com/121311569/280556297-66173030-2324-415c-a236-ef3fcd73e6ed.jpg)

278
generate_sharktank.py Normal file
View File

@@ -0,0 +1,278 @@
# Lint as: python3
"""SHARK Tank"""
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
# will generate local shark tank folder like this:
# /SHARK
# /gen_shark_tank
# /albert_lite_base
# /...model_name...
#
import os
import csv
import argparse
from shark.shark_importer import SharkImporter
import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
from apps.stable_diffusion.src.models import (
model_wrappers as mw,
)
from apps.stable_diffusion.src.utils.stable_args import (
args,
)
def create_hash(file_name):
with open(file_name, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(2**20):
file_hash.update(chunk)
return file_hash.hexdigest()
def save_torch_model(torch_model_list):
from tank.model_utils import (
get_hf_model,
get_vision_model,
get_hf_img_cls_model,
get_fp16_model,
)
with open(torch_model_list) as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")
fields = next(torch_reader)
for row in torch_reader:
torch_model_name = row[0]
tracing_required = row[1]
model_type = row[2]
is_dynamic = row[3]
tracing_required = False if tracing_required == "False" else True
is_dynamic = False if is_dynamic == "False" else True
print("generating artifacts for: " + torch_model_name)
model = None
input = None
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.use_tuned = False
args.local_tank_cache = WORKDIR
precision_values = ["fp16"]
seq_lengths = [64, 77]
for precision_value in precision_values:
args.precision = precision_value
for length in seq_lengths:
model = mw.SharkifyStableDiffusionModel(
model_id=torch_model_name,
custom_weights="",
precision=precision_value,
max_len=length,
width=512,
height=512,
use_base_vae=False,
debug=True,
sharktank_dir=WORKDIR,
generate_vmfb=False,
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
(input,),
frontend="torch",
)
mlir_importer.import_debug(
is_dynamic=False,
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name,
)
# Generate torch dynamic models.
if is_dynamic:
mlir_importer.import_debug(
is_dynamic=True,
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name + "_dynamic",
)
def save_tf_model(tf_model_list):
from tank.model_utils_tf import (
get_causal_image_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
)
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
with open(tf_model_list) as csvfile:
tf_reader = csv.reader(csvfile, delimiter=",")
fields = next(tf_reader)
for row in tf_reader:
tf_model_name = row[0]
model_type = row[1]
model = None
input = None
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_causal_lm_model(tf_model_name)
if model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
if model_type == "keras":
model, input, _ = get_keras_model(tf_model_name)
if model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name)
tf_model_name = tf_model_name.replace("/", "_")
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
os.makedirs(tf_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
inputs=input,
frontend="tf",
)
mlir_importer.import_debug(
is_dynamic=False,
dir=tf_model_dir,
model_name=tf_model_name,
)
mlir_hash = create_hash(
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
)
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
def save_tflite_model(tflite_model_list):
from shark.tflite_utils import TFLitePreprocessor
with open(tflite_model_list) as csvfile:
tflite_reader = csv.reader(csvfile, delimiter=",")
for row in tflite_reader:
print("\n")
tflite_model_name = row[0]
tflite_model_link = row[1]
print("tflite_model_name", tflite_model_name)
print("tflite_model_link", tflite_model_link)
tflite_model_name_dir = os.path.join(
WORKDIR, str(tflite_model_name) + "_tflite"
)
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
# Preprocess to get SharkImporter input args
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()
# Use SharkImporter to get SharkInference input args
my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
frontend="tflite",
raw_model_file=raw_model_file_path,
)
my_shark_importer.import_debug(
dir=tflite_model_name_dir,
model_name=tflite_model_name,
func_name="main",
)
mlir_hash = create_hash(
os.path.join(
tflite_model_name_dir,
tflite_model_name + "_tflite" + ".mlir",
)
)
np.save(
os.path.join(tflite_model_name_dir, "hash"),
np.array(mlir_hash),
)
# Validates whether the file is present or not.
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
if __name__ == "__main__":
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--torch_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/torch_model_list.csv",
# help="""Contains the file with torch_model name and args.
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
# )
# parser.add_argument(
# "--tf_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tf_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--tflite_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tflite/tflite_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--ci_tank_dir",
# type=bool,
# default=False,
# )
# parser.add_argument("--upload", type=bool, default=False)
# old_args = parser.parse_args()
home = str(Path.home())
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
torch_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tf_model_list.csv"
)
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
)
save_torch_model(
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv")
)
save_torch_model(torch_model_csv)
save_tf_model(tf_model_csv)
save_tflite_model(tflite_model_csv)

192
inference/CMakeLists.txt Normal file
View File

@@ -0,0 +1,192 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.17)
project(sharkbackend LANGUAGES C CXX)
#
# Options
#
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
#
# Dependencies
#
# FetchContent requires us to include the transitive closure of all
# repos that we depend on so that we can override the tags.
#
include(FetchContent)
FetchContent_Declare(
repo-common
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
GIT_TAG ${TRITON_COMMON_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-core
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
GIT_TAG ${TRITON_CORE_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-backend
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
#
# The backend must be built into a shared library. Use an ldscript to
# hide all symbols except for the TRITONBACKEND API.
#
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
add_library(
triton-dshark-backend SHARED
src/dshark.cc
#src/dshark_driver_module.c
)
add_library(
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
)
target_include_directories(
triton-dshark-backend
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src
)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
add_subdirectory(thirdparty/shark-runtime EXCLUDE_FROM_ALL)
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
iree_hal_hal
iree_hal_cuda_cuda
iree_hal_cuda_registration_registration
iree_hal_vmvx_registration_registration
iree_hal_dylib_registration_registration
iree_modules_hal_hal
iree_vm_vm
iree_vm_bytecode_module
iree_hal_local_loaders_system_library_loader
iree_hal_local_loaders_vmvx_module_loader
)
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
target_link_libraries(
triton-dshark-backend
PRIVATE
triton-core-serverapi # from repo-core
triton-core-backendapi # from repo-core
triton-core-serverstub # from repo-core
triton-backend-utils # from repo-backend
)
if(WIN32)
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
)
else()
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
)
endif()
#
# Install
#
include(GNUInstallDirs)
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
install(
TARGETS
triton-dshark-backend
EXPORT
triton-dshark-backend-targets
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
)
install(
EXPORT
triton-dshark-backend-targets
FILE
SharkBackendTargets.cmake
NAMESPACE
SharkBackend::
DESTINATION
${INSTALL_CONFIGDIR}
)
include(CMakePackageConfigHelpers)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
)
install(
FILES
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
DESTINATION ${INSTALL_CONFIGDIR}
)
#
# Export from build tree
#
export(
EXPORT triton-dshark-backend-targets
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
NAMESPACE SharkBackend::
)
export(PACKAGE SharkBackend)

100
inference/README.md Normal file
View File

@@ -0,0 +1,100 @@
# SHARK Triton Backend
The triton backend for shark.
# Build
Install SHARK
```
git clone https://github.com/nod-ai/SHARK.git
# skip above step if dshark is already installed
cd SHARK/inference
```
install dependancies
```
apt-get install patchelf rapidjson-dev python3-dev
git submodule update --init
```
update the submodules of iree
```
cd thirdparty/shark-runtime
git submodule update --init
```
Next, make the backend and install it
```
cd ../..
mkdir build && cd build
cmake -DTRITON_ENABLE_GPU=ON \
-DIREE_HAL_DRIVER_CUDA=ON \
-DIREE_TARGET_BACKEND_CUDA=ON \
-DMLIR_ENABLE_CUDA_RUNNER=ON \
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
-DTRITON_BACKEND_REPO_TAG=r22.02 \
-DTRITON_CORE_REPO_TAG=r22.02 \
-DTRITON_COMMON_REPO_TAG=r22.02 ..
make install
```
# Incorporating into Triton
There are much more in depth explenations for the following steps in triton's documentation:
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
There should be a file at /build/install/backends/dshark/libtriton_dshark.so. You will need to copy it into your triton server image.
More documentation is in the link above, but to create the docker image, you need to run the compose.py command in the triton-backend server repo
To first build your image, clone the tritonserver repo.
```
git clone https://github.com/triton-inference-server/server.git
```
then run `compose.py` to build a docker compose file
```
cd server
python3 compose.py --repoagent checksum --dry-run
```
Because dshark is a third party backend, you will need to manually modify the `Dockerfile.compose` to include the dshark backend. To do this, in the Dockerfile.compose file produced, copy this line.
the dshark backend will be located in the build folder from earlier under `/build/install/backends`
```
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
```
Next run
```
docker build -t tritonserver_custom -f Dockerfile.compose .
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
where `path/to/model_repos` is where you are storing the models you want to run
if your not using gpus, omit `--gpus=1`
```
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
# Setting up a model
to include a model in your backend, add a directory with your model name to your model repository directory. examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
make sure to adjust the input correctly in the config.pbtxt file, and save a vmfb file under 1/model.vmfb
# CUDA
if you're having issues with cuda, make sure your correct drivers are installed, and that `nvidia-smi` works, and also make sure that the nvcc compiler is on the path.

View File

@@ -0,0 +1,39 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include(CMakeFindDependencyMacro)
get_filename_component(
SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
)
list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})
if(NOT TARGET SharkBackend::triton-dshark-backend)
include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
endif()
set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)

1409
inference/src/dshark.cc Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,30 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
{
global:
TRITONBACKEND_*;
local: *;
};

View File

@@ -6,15 +6,36 @@ from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Temporary workaround for transformers/__init__.py.
path_to_transformers_hook = Path(
get_python_lib() + "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
# Diffusers 0.13.1 fails with transformers __init.py errros in BLIP. So remove it for now until we fork it
pix2pix_init = Path(get_python_lib() + "/diffusers/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(get_python_lib() + "/diffusers/pipelines/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(
get_python_lib() + "/diffusers/pipelines/stable_diffusion/__init__.py"
)
if path_to_transformers_hook.is_file():
pass
else:
with open(path_to_transformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
for line in fileinput.input(pix2pix_init, inplace=True):
if "StableDiffusionPix2PixZeroPipeline" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
@@ -55,12 +76,3 @@ for line in fileinput.input(path_to_lazy_loader, inplace=True):
)
else:
print(line, end="")
# For getting around timm's packaging.
# Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505
path_to_timm_activations = Path(get_python_lib() + "/timm/layers/activations_jit.py")
for line in fileinput.input(path_to_timm_activations, inplace=True):
if "@torch.jit.script" in line:
print("@torch.jit._script_if_tracing", end="\n")
else:
print(line, end="")

View File

@@ -5,25 +5,13 @@ requires = [
"packaging",
"numpy>=1.22.4",
"torch-mlir>=20221021.633",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
]
build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79
include = '\.pyi?$'
exclude = '''
(
/(
| apps/stable_diffusion
| apps/language_models
| shark
| benchmarks
| tank
| build
| generated_imgs
| shark.venv
)/
| setup.py
)
'''

View File

@@ -1,3 +1,3 @@
[pytest]
addopts = --verbose -s -p no:warnings
norecursedirs = inference tank/tflite examples benchmarks shark apps/shark_studio
addopts = --verbose -p no:warnings
norecursedirs = inference tank/tflite examples benchmarks shark

View File

@@ -8,8 +8,19 @@ torchvision
tqdm
#iree-compiler | iree-runtime should already be installed
#these dont work ok osx
#iree-tools-tflite
#iree-tools-xla
#iree-tools-tf
# TensorFlow and JAX.
gin-config
tensorflow-macos
tensorflow-metal
#tf-models-nightly
#tensorflow-text-nightly
transformers
tensorflow-probability
#jax[cpu]
# tflitehub dependencies.

View File

@@ -2,20 +2,30 @@
--pre
numpy>1.22.4
torchvision
pytorch-triton
torchvision
tabulate
tqdm
#iree-compiler | iree-runtime should already be installed
iree-tools-tflite
iree-tools-xla
iree-tools-tf
# Modelling and JAX.
# TensorFlow and JAX.
gin-config
tf-nightly
keras>=2.10
#tf-models-nightly
#tensorflow-text-nightly
transformers
diffusers
#tensorflow-probability
#jax[cpu]
# tflitehub dependencies.
Pillow
# Testing and support.
@@ -23,10 +33,9 @@ lit
pyyaml
python-dateutil
sacremoses
sentencepiece
# web dependecies.
gradio==3.44.3
gradio
altair
scipy

Some files were not shown because too many files have changed in this diff Show More