Compare commits

..

1 Commits

Author SHA1 Message Date
Anush Elangovan
6d6a9dcae8 Revert "Revert "Enable --device_allocator=caching""
This reverts commit 41ee65b377.
2023-02-09 23:00:32 -08:00
119 changed files with 1849 additions and 17230 deletions

View File

@@ -1,5 +0,0 @@
[flake8]
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py

View File

@@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
python-version: ["3.10"]
steps:
- uses: actions/checkout@v2
@@ -44,20 +44,18 @@ jobs:
body: |
Automatic snapshot release of nod.ai SHARK.
draft: true
prerelease: true
prerelease: false
- name: Build Package
shell: powershell
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
python process_skipfiles.py
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
# GHA windows VM OOMs so disable for now
@@ -67,9 +65,9 @@ jobs:
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
#- uses: actions/upload-artifact@v2
# with:
# path: dist/*
- uses: actions/upload-artifact@v2
with:
path: dist/*
- name: Upload Release Assets
id: upload-release-assets
@@ -79,7 +77,6 @@ jobs:
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./dist/*
#asset_content_type: application/vnd.microsoft.portable-executable
- name: Publish Release
id: publish_release
@@ -95,7 +92,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
python-version: ["3.10"]
backend: [IREE, SHARK]
steps:
@@ -134,7 +131,7 @@ jobs:
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://iree-org.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models

View File

@@ -31,7 +31,7 @@ jobs:
matrix:
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
suite: [cpu,cuda,vulkan]
python-version: ["3.11"]
python-version: ["3.10"]
include:
- os: ubuntu-latest
suite: lint
@@ -99,12 +99,11 @@ jobs:
run: |
# black format check
black --version
black --check .
black --line-length 79 --check .
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
@@ -112,7 +111,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -120,9 +119,9 @@ jobs:
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
@@ -137,7 +136,7 @@ jobs:
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan --update_tank
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
@@ -145,19 +144,19 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest -k vulkan -s --ci
pytest --benchmark -k vulkan -s
type bench_results.csv
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
./shark.venv/Scripts/activate
python build_tools/stable_diffusion_testing.py --device=vulkan

5
.gitignore vendored
View File

@@ -168,8 +168,6 @@ shark_tmp/
*.vmfb
.use-iree
tank/dict_configs.py
*.csv
reproducers/
# ORT related artefacts
cache_models/
@@ -184,6 +182,3 @@ models/
# models folder
apps/stable_diffusion/web/models/
# Stencil annotators.
stencil_annotator/

3
.style.yapf Normal file
View File

@@ -0,0 +1,3 @@
[style]
based_on_style = google
column_limit = 80

View File

@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [AMD RDNA Users] Download this specific driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mril-iree). Latest drivers may not work.
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
@@ -25,32 +25,18 @@ Other users please ensure you have your latest vendor drivers and Vulkan SDK fro
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Install Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Download the [stable release](https://github.com/nod-ai/shark/releases/latest)
Download the latest .exe https://github.com/nod-ai/SHARK/releases.
Double click the .exe and you should have the [UI](http://localhost:8080/) in the browser.
Double click the .exe and you should have the [UI]( http://localhost:8080/?__theme=dark) in the browser.
If you have custom models put them in a `models/` directory where the .exe is.
If you have custom models (ckpt, safetensors) put in a `models/` directory where the .exe is.
Enjoy.
<details>
<summary>More installation notes</summary>
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
Some known AMD Driver quirks and fixes with cursors are documented [here](https://github.com/nod-ai/SHARK/blob/main/apps/stable_diffusion/stable_diffusion_amd.md ).
## Running
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE)
* The first run may take few minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
</details>
<details>
<summary>Advanced Installation (Only for developers)</summary>
@@ -68,7 +54,7 @@ cd SHARK
### Windows 10/11 Users
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
* Install Git for Windows from [here](https://git-scm.com/download/win)
@@ -114,20 +100,21 @@ source shark.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
```shell
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
python3.10 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
</details>
The output on a AMD 7900XTX would look something like:
The output on a 7900XTX would like:
```shell
```shell
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
@@ -153,7 +140,7 @@ Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any
This step sets up a new VirtualEnv for Python
```shell
python --version #Check you have 3.11 on Linux, macOS or Windows Powershell
python --version #Check you have 3.10 on Linux, macOS or Windows Powershell
python -m venv shark_venv
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
@@ -167,7 +154,7 @@ python -m pip install --upgrade pip
### Install SHARK
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10
```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
@@ -202,10 +189,10 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
<details>
<summary>Development, Testing and Benchmarks</summary>
If you want to use Python3.11 and with TF Import tools you can use the environment variables like:
If you want to use Python3.10 and with TF Import tools you can use the environment variables like:
Set `USE_IREE=1` to use upstream IREE
```
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
```
### Run any of the hundreds of SHARK tank models via the test framework
@@ -215,14 +202,14 @@ python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use g
pytest tank/test_models.py -k "MiniLM"
```
### How to use your locally built IREE / Torch-MLIR with SHARK
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
for Torch-MLIR.
How to use your locally built Torch-MLIR with SHARK:
### How to use your locally built Torch-MLIR with SHARK
```shell
1.) Run `./setup_venv.sh in SHARK` and activate `shark.venv` virtual env.
2.) Run `pip uninstall torch-mlir`.
@@ -240,15 +227,9 @@ Now the SHARK will use your locally build Torch-MLIR repo.
## Benchmarking Dispatches
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your command line argument.
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
```
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
```
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
if you want to instead incorporate this into a python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `SharkInference`, and the benchmarks will be generated when compiled. E.G:
```
@@ -272,7 +253,7 @@ Output will include:
- A .txt file containing benchmark output
See tank/README.md for further instructions on how to run model tests and benchmarks from the SHARK tank.
See tank/README.md for instructions on how to run model tests and benchmarks from the SHARK tank.
</details>

View File

@@ -1,303 +0,0 @@
import torch
import torch_mlir
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
)
import time
import numpy as np
from torch.nn import functional as F
import os
from threading import Thread
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_inference import SharkInference
from pathlib import Path
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
def shouldStop(tokens):
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if tokens[0][-1] == stop_id:
return True
return False
MAX_SEQUENCE_LENGTH = 256
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
def get_torch_mlir_module_bytecode(model, model_inputs):
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
# tracing_mode='symbolic',
)(*model_inputs)
print("Got FX_G")
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
transform_fx(fx_g)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
print("FX_G recompile")
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
print("Got TS_G")
return ts_g
def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
# ADD Device Arg
from shark.shark_inference import SharkInference
vmfb_path = Path(model_vmfb_name + ".vmfb")
if vmfb_path.exists():
print("Loading ", vmfb_path)
shark_module = SharkInference(
None, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
return shark_module
mlir_path = Path(model_name + ".mlir")
print(
f"[DEBUG] mlir path { mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path) as f:
bytecode = f.read("rb")
else:
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.compile()
import os
path = shark_module.save_module(os.getcwd(), model_vmfb_name, [])
print("Saved vmfb at ", str(path))
return shark_module
class StableLMModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
combine_input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
output = self.model(**combine_input_dict)
return output.logits
# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
def get_tokenizer():
model_path = "stabilityai/stablelm-tuned-alpha-3b"
tok = AutoTokenizer.from_pretrained(model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
print(f"Sucessfully loaded the tokenizer to the memory")
return tok
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
def generate(
new_text,
streamer,
max_new_tokens,
do_sample,
top_p,
top_k,
temperature,
num_beams,
stopping_criteria,
sharkStableLM,
tok=None,
input_ids=torch.randint(3, (1, 256)),
attention_mask=torch.randint(3, (1, 256)),
):
if tok == None:
tok = get_tokenizer
# Construct the input message string for the model by concatenating the current system message and conversation history
# Tokenize the messages string
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
words_list = []
for i in range(max_new_tokens):
numWords = len(new_text.split())
# if(numWords>220):
# break
model_inputs = tok(
[new_text],
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
return_tensors="pt",
)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
output = sharkStableLM(
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
)
output = torch.from_numpy(output)
next_toks = torch.topk(output, 1)
if shouldStop(next_toks.indices):
break
# streamer.put(next_toks.indices[0][int(sum_attentionmask)-1])
new_word = tok.decode(
next_toks.indices[0][int(sum_attentionmask) - 1],
skip_special_tokens=True,
)
print(new_word, end="", flush=True)
words_list.append(new_word)
if new_word == "":
break
new_text = new_text + new_word
return words_list

View File

@@ -1,656 +0,0 @@
import sys
import warnings
warnings.filterwarnings("ignore")
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_compiler")
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_runtime")
import torch
import torch_mlir
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import transform_fx as transform_fx_
import re
from shark.shark_inference import SharkInference
from tqdm import tqdm
from torch_mlir import TensorPlaceholder
class FirstVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states, attention_mask, position_ids):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class SecondVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
past_key_value0,
past_key_value1,
),
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class CompiledFirstVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class CompiledSecondVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
if is_first:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers0
)
return self.model.forward(input_ids, attention_mask=attention_mask)
else:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers1
)
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
def write_in_dynamic_inputs0(module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def write_in_dynamic_inputs1(module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
if "dim_42 =" in line:
continue
if f"%c{dynamic_input_size}_i64 =" in line:
new_lines.append(
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
)
new_lines.append(
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
)
continue
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim_42)", line)
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim_42\)",
"tensor.empty(%dim_42, %dim_42)",
line,
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def compile_vicuna_layer(
vicuna_layer,
hidden_states,
attention_mask,
position_ids,
past_key_value0=None,
past_key_value1=None,
):
hidden_states_placeholder = TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
attention_mask, dynamic_axes=[2, 3]
)
position_ids_placeholder = TensorPlaceholder.like(
position_ids, dynamic_axes=[1]
)
if past_key_value0 is None and past_key_value1 is None:
fx_g = make_fx(
vicuna_layer,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(hidden_states, attention_mask, position_ids)
else:
fx_g = make_fx(
vicuna_layer,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
)
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
transform_fx(fx_g)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
print("FX_G recompile")
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
return ts_g
def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
return vicuna_model, tokenizer
def compile_to_vmfb(inputs, layers, is_first=True):
mlirs, modules = [], []
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
if is_first:
mlir_path = Path(f"{idx}_0.mlir")
vmfb_path = Path(f"{idx}_0.vmfb")
else:
mlir_path = Path(f"{idx}_1.mlir")
vmfb_path = Path(f"{idx}_1.vmfb")
if vmfb_path.exists():
continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states_placeholder = TensorPlaceholder.like(
inputs[0], dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
inputs[1], dynamic_axes=[3]
)
position_ids_placeholder = TensorPlaceholder.like(
inputs[2], dynamic_axes=[1]
)
if not is_first:
pkv0_placeholder = TensorPlaceholder.like(
inputs[3], dynamic_axes=[2]
)
pkv1_placeholder = TensorPlaceholder.like(
inputs[4], dynamic_axes=[2]
)
print(f"Compiling layer {idx} mlir")
if is_first:
ts_g = compile_vicuna_layer(
layer, inputs[0], inputs[1], inputs[2]
)
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
else:
ts_g = compile_vicuna_layer(
layer,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4],
)
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
if is_first:
module = write_in_dynamic_inputs0(str(module), 137)
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
else:
module = write_in_dynamic_inputs1(str(module), 138)
if idx in [0, 5, 6, 7]:
module_str = module
module_str = module_str.splitlines()
new_lines = []
for line in module_str:
if len(line) < 1000:
new_lines.append(line)
else:
new_lines.append(line[:999])
module_str = "\n".join(new_lines)
f1_ = open(f"{idx}_1_test.mlir", "w+")
f1_.write(module_str)
f1_.close()
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
mlirs.append(bytecode)
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
if is_first:
vmfb_path = Path(f"{idx}_0.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_0",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
else:
vmfb_path = Path(f"{idx}_1.vmfb")
if idx < 25:
device = "vulkan"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_1",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
def get_sharded_model():
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
# please don't change it
SAMPLE_INPUT_LEN = 137
vicuna_model = get_model_and_tokenizer()[0]
placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
)
placeholder_input1 = (
torch.zeros([1, 1, 4096]),
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
torch.zeros([1, 1], dtype=torch.int64),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
layers0 = [FirstVicunaLayer(layer) for layer in vicuna_model.model.layers]
_, modules0 = compile_to_vmfb(placeholder_input0, layers0, is_first=True)
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
layers1 = [SecondVicunaLayer(layer) for layer in vicuna_model.model.layers]
_, modules1 = compile_to_vmfb(placeholder_input1, layers1, is_first=False)
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
sharded_model = ShardedVicunaModel(
vicuna_model, shark_layers0, shark_layers1
)
return sharded_model
if __name__ == "__main__":
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"
sharded_model = get_sharded_model()
tokenizer = get_model_and_tokenizer()[1]
past_key_values = None
while True:
print("\n\n")
user_prompt = input("User: ")
prompt_history = (
prompt_history + "USER:\n" + user_prompt + prologue_prompt
)
prompt = prompt_history.strip()
input_ids = tokenizer(prompt).input_ids
tokens = input_ids
prompt = print("Robot:", end=" ")
new_sentence = []
max_response_len = 1000
for iteration in range(max_response_len):
original_input_ids = input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
if iteration == 0:
output = sharded_model.forward(input_ids, is_first=True)
else:
output = sharded_model.forward(
input_ids, past_key_values=past_key_values, is_first=False
)
logits = output["logits"]
past_key_values = output["past_key_values"]
new_token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
if new_token == 2:
break
new_sentence += [new_token]
tokens.append(new_token)
original_input_ids.append(new_token)
input_ids = [new_token]
for i in range(len(tokens)):
if type(tokens[i]) != int:
tokens[i] = int(tokens[i][0])
new_sentence_str = tokenizer.decode(new_sentence)
print(new_sentence_str)
prompt_history += f"\n{new_sentence_str}\n"

View File

@@ -1,777 +0,0 @@
import sys
import warnings
import gradio as gr
import time
warnings.filterwarnings("ignore")
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_compiler")
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_runtime")
import torch
import torch_mlir
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import transform_fx as transform_fx_
import re
from shark.shark_inference import SharkInference
from tqdm import tqdm
from torch_mlir import TensorPlaceholder
from apps.stable_diffusion.web.ui.utils import available_devices
class FirstVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states, attention_mask, position_ids):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class SecondVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
past_key_value0,
past_key_value1,
),
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class CompiledFirstVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class CompiledSecondVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
if is_first:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers0
)
return self.model.forward(input_ids, attention_mask=attention_mask)
else:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers1
)
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
def write_in_dynamic_inputs0(module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def write_in_dynamic_inputs1(module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
if "dim_42 =" in line:
continue
if f"%c{dynamic_input_size}_i64 =" in line:
new_lines.append(
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
)
new_lines.append(
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
)
continue
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim_42)", line)
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim_42\)",
"tensor.empty(%dim_42, %dim_42)",
line,
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def compile_vicuna_layer(
vicuna_layer,
hidden_states,
attention_mask,
position_ids,
past_key_value0=None,
past_key_value1=None,
):
hidden_states_placeholder = TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
attention_mask, dynamic_axes=[2, 3]
)
position_ids_placeholder = TensorPlaceholder.like(
position_ids, dynamic_axes=[1]
)
if past_key_value0 is None and past_key_value1 is None:
fx_g = make_fx(
vicuna_layer,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(hidden_states, attention_mask, position_ids)
else:
fx_g = make_fx(
vicuna_layer,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
)
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
transform_fx(fx_g)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
print("FX_G recompile")
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
return ts_g
path = "TheBloke/vicuna-7B-1.1-HF"
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
def compile_to_vmfb(inputs, layers, is_first=True):
mlirs, modules = [], []
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
if is_first:
mlir_path = Path(f"{idx}_0.mlir")
vmfb_path = Path(f"{idx}_0.vmfb")
else:
mlir_path = Path(f"{idx}_1.mlir")
vmfb_path = Path(f"{idx}_1.vmfb")
if vmfb_path.exists():
continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states_placeholder = TensorPlaceholder.like(
inputs[0], dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
inputs[1], dynamic_axes=[3]
)
position_ids_placeholder = TensorPlaceholder.like(
inputs[2], dynamic_axes=[1]
)
if not is_first:
pkv0_placeholder = TensorPlaceholder.like(
inputs[3], dynamic_axes=[2]
)
pkv1_placeholder = TensorPlaceholder.like(
inputs[4], dynamic_axes=[2]
)
print(f"Compiling layer {idx} mlir")
if is_first:
ts_g = compile_vicuna_layer(
layer, inputs[0], inputs[1], inputs[2]
)
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
else:
ts_g = compile_vicuna_layer(
layer,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4],
)
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
if is_first:
module = write_in_dynamic_inputs0(str(module), 137)
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
else:
module = write_in_dynamic_inputs1(str(module), 138)
if idx in [0, 5, 6, 7]:
module_str = module
module_str = module_str.splitlines()
new_lines = []
for line in module_str:
if len(line) < 1000:
new_lines.append(line)
else:
new_lines.append(line[:999])
module_str = "\n".join(new_lines)
f1_ = open(f"{idx}_1_test.mlir", "w+")
f1_.write(module_str)
f1_.close()
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
mlirs.append(bytecode)
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
if is_first:
vmfb_path = Path(f"{idx}_0.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_0",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
else:
vmfb_path = Path(f"{idx}_1.vmfb")
if idx < 25:
device = "vulkan"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_1",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
def get_sharded_model():
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
# please don't change it
SAMPLE_INPUT_LEN = 137
global vicuna_model
placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
)
placeholder_input1 = (
torch.zeros([1, 1, 4096]),
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
torch.zeros([1, 1], dtype=torch.int64),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
layers0 = [FirstVicunaLayer(layer) for layer in vicuna_model.model.layers]
_, modules0 = compile_to_vmfb(placeholder_input0, layers0, is_first=True)
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
layers1 = [SecondVicunaLayer(layer) for layer in vicuna_model.model.layers]
_, modules1 = compile_to_vmfb(placeholder_input1, layers1, is_first=False)
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
sharded_model = ShardedVicunaModel(
vicuna_model, shark_layers0, shark_layers1
)
return sharded_model
sharded_model = get_sharded_model()
def user(message, history):
print("msg=", message)
print("history=", history)
# Append the user's message to the conversation history
return "", history + [[message, ""]]
def chat(curr_system_message, history):
global sharded_model
past_key_values = None
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
print(messages)
prompt = messages.strip()
input_ids = tokenizer(prompt).input_ids
tokens = input_ids
new_sentence = []
max_response_len = 1000
partial_sentence = []
partial_text = ""
start_time = time.time()
for iteration in range(max_response_len):
original_input_ids = input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
if iteration == 0:
output = sharded_model.forward(input_ids, is_first=True)
else:
output = sharded_model.forward(
input_ids, past_key_values=past_key_values, is_first=False
)
logits = output["logits"]
past_key_values = output["past_key_values"]
new_token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
if new_token == 2:
break
new_sentence += [new_token]
partial_sentence += [new_token]
if iteration > 0 and iteration % 2 == 0:
new_text = tokenizer.decode(partial_sentence)
partial_sentence = []
print(new_text, " ")
partial_text += new_text + " "
history[-1][1] = partial_text
yield history
tokens.append(new_token)
original_input_ids.append(new_token)
input_ids = [new_token]
end_time = time.time()
print(
f"Total time taken to generated response is {end_time-start_time} seconds"
)
for i in range(len(tokens)):
if type(tokens[i]) != int:
tokens[i] = int(tokens[i][0])
new_sentence_str = tokenizer.decode(new_sentence)
print(new_sentence_str)
history[-1][1] = new_sentence_str
return history
system_msg = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
# history_eg = [["hi hello how are you", ""]]
# print(chat(system_msg, history_eg))
with gr.Blocks(title="Chatbot") as vicuna_chat:
with gr.Row():
model = gr.Dropdown(
label="Select Model",
value="TheBloke/vicuna-7B-1.1-HF",
choices=[
"TheBloke/vicuna-7B-1.1-HF",
],
)
device_value = None
for d in available_devices:
if "vulkan" in d:
device_value = d
break
device = gr.Dropdown(
label="Device",
value=device_value if device_value else available_devices[0],
interactive=False,
choices=available_devices,
)
chatbot = gr.Chatbot().style(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
).style(container=False)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit")
stop = gr.Button("Stop")
clear = gr.Button("Clear")
system_msg = gr.Textbox(
system_msg, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot],
outputs=[chatbot],
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot],
outputs=[chatbot],
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
args, unknown = p.parse_known_args()
vicuna_chat.queue()
vicuna_chat.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

View File

@@ -1 +1 @@
from apps.stable_diffusion.scripts.train_lora_word import lora_train
from apps.stable_diffusion.scripts.txt2img import txt2img_inf

View File

@@ -1,126 +0,0 @@
import sys
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
StencilPipeline,
resize_stencil,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
def main():
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
image = Image.open(args.img_path).convert("RGB")
# When the models get uploaded, it should be default to False.
args.import_mlir = True
use_stencil = args.use_stencil
if use_stencil:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, args.width, args.height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model
if use_stencil:
img2img_obj = StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
start_time = time.time()
generated_imgs = img2img_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.batch_size,
args.height,
args.width,
args.steps,
args.strength,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, strength={args.strength}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += img2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
extra_info = {"STRENGTH": args.strength}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -1,104 +0,0 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
def main():
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
if args.mask_path is None:
print("Flag --mask_path is required.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
model_id = (
args.hf_model_id
if "inpaint" in args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
image = Image.open(args.img_path)
mask_image = Image.open(args.mask_path)
inpaint_obj = InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = inpaint_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
mask_image,
args.batch_size,
args.height,
args.width,
args.inpaint_full_res,
args.inpaint_full_res_padding,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += inpaint_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -1,19 +0,0 @@
from apps.stable_diffusion.src import args
from apps.stable_diffusion.scripts import (
img2img,
txt2img,
# inpaint,
# outpaint,
)
if __name__ == "__main__":
if args.app == "txt2img":
txt2img.main()
elif args.app == "img2img":
img2img.main()
# elif args.app == "inpaint":
# inpaint.main()
# elif args.app == "outpaint":
# outpaint.main()
else:
print(f"args.app value is {args.app} but this isn't supported")

View File

@@ -1,119 +0,0 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
model_id = (
args.hf_model_id
if "inpaint" in args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
image = Image.open(args.img_path)
outpaint_obj = OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = outpaint_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.pixels,
args.mask_blur,
args.left,
args.right,
args.top,
args.bottom,
args.noise_q,
args.color_variation,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += outpaint_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
# save this information as metadata of output generated image.
directions = []
if args.left:
directions.append("left")
if args.right:
directions.append("right")
if args.top:
directions.append("up")
if args.bottom:
directions.append("down")
extra_info = {
"PIXELS": args.pixels,
"MASK_BLUR": args.mask_blur,
"DIRECTIONS": directions,
"NOISE_Q": args.noise_q,
"COLOR_VARIATION": args.color_variation,
}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -1,692 +0,0 @@
# Install the required libs
# pip install -U git+https://github.com/huggingface/diffusers.git
# pip install accelerate transformers ftfy
# HuggingFace Token
# YOUR_TOKEN = "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"
# Import required libraries
import itertools
import math
import os
from typing import List
import random
import torch_mlir
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
import logging
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from shark.shark_inference import SharkInference
torch._dynamo.config.verbose = True
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
)
from io import BytesIO
from dataclasses import dataclass
from apps.stable_diffusion.src import (
args,
get_schedulers,
set_init_device_flags,
clear_all,
)
from apps.stable_diffusion.src.utils import update_lora_weight
# Setup the dataset
class LoraDataset(Dataset):
def __init__(
self,
data_root,
tokenizer,
size=512,
repeats=100,
interpolation="bicubic",
set="train",
prompt="myloraprompt",
center_crop=False,
):
self.data_root = data_root
self.tokenizer = tokenizer
self.size = size
self.center_crop = center_crop
self.prompt = prompt
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
self.num_images = len(self.image_paths)
self._length = self.num_images
if set == "train":
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
example["input_ids"] = self.tokenizer(
self.prompt,
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids[0]
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = np.array(image).astype(np.uint8)
image = (image / 127.5 - 1.0).astype(np.float32)
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
return example
def torch_device(device):
device_tokens = device.split("=>")
if len(device_tokens) == 1:
device_str = device_tokens[0].strip()
else:
device_str = device_tokens[1].strip()
device_type_tokens = device_str.split("://")
if device_type_tokens[0] == "metal":
device_type_tokens[0] = "vulkan"
if len(device_type_tokens) > 1:
return device_type_tokens[0] + ":" + device_type_tokens[1]
else:
return device_type_tokens[0]
########## Setting up the model ##########
def lora_train(
prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
training_images_dir: str,
lora_save_dir: str,
use_lora: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
print(
"Note LoRA training is not compatible with the latest torch-mlir branch"
)
print(
"To run LoRA training you'll need this to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
)
torch.manual_seed(seed)
args.prompts = [prompt]
args.steps = steps
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = custom_model
else:
args.hf_model_id = custom_model
args.training_images_dir = training_images_dir
args.lora_save_dir = lora_save_dir
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = torch_device(device)
args.use_lora = use_lora
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(
args.hf_model_id, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(
args.hf_model_id, subfolder="unet"
)
def freeze_params(params):
for param in params:
param.requires_grad = False
# Freeze everything but LoRA
freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
# Move vae and unet to device
vae.to(args.device)
unet.to(args.device)
text_encoder.to(args.device)
if use_lora != "":
update_lora_weight(unet, args.use_lora, "unet")
else:
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[
block_id
]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = vae
def forward(self, input):
x = self.vae.encode(input, return_dict=False)[0]
return x
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = unet
def forward(self, x, y, z):
return self.unet.forward(x, y, z, return_dict=False)[0]
shark_vae = VaeModel()
shark_unet = UnetModel()
####### Creating our training data ########
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id,
subfolder="tokenizer",
)
# Let's create the Dataset and Dataloader
train_dataset = LoraDataset(
data_root=args.training_images_dir,
tokenizer=tokenizer,
size=vae.sample_size,
prompt=args.prompts[0],
repeats=100,
center_crop=False,
set="train",
)
def create_dataloader(train_batch_size=1):
return torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, shuffle=True
)
# Create noise_scheduler for training
noise_scheduler = DDPMScheduler.from_config(
args.hf_model_id, subfolder="scheduler"
)
######## Training ###########
# Define hyperparameters for our training. If you are not happy with your results,
# you can tune the `learning_rate` and the `max_train_steps`
# Setting up all training args
hyperparameters = {
"learning_rate": 5e-04,
"scale_lr": True,
"max_train_steps": steps,
"train_batch_size": batch_size,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"mixed_precision": "fp16",
"seed": 42,
"output_dir": "sd-concept-output",
}
# creating output directory
cwd = os.getcwd()
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
while not os.path.exists(str(out_dir)):
try:
os.mkdir(out_dir)
except OSError as error:
print("Output directory not created")
###### Torch-MLIR Compilation ######
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
@make_simple_dynamo_backend
def refbackend_torchdynamo_backend(
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
):
# handling usage of empty tensor without initializing
transform_fx(fx_graph)
fx_graph.recompile()
if _returns_nothing(fx_graph):
return fx_graph
removed_none_indexes = _remove_nones(fx_graph)
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
mlir_module = torch_mlir.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors"
)
bytecode_stream = BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
def compiled_callable(*inputs):
inputs = [x.numpy() for x in inputs]
result = shark_module("forward", inputs)
if was_unwrapped:
result = [
result,
]
if not isinstance(result, list):
result = torch.from_numpy(result)
else:
result = tuple(torch.from_numpy(x) for x in result)
result = list(result)
for removed_index in removed_none_indexes:
result.insert(removed_index, None)
result = tuple(result)
return result
return compiled_callable
def predictions(torch_func, jit_func, batchA, batchB):
res = jit_func(batchA.numpy(), batchB.numpy())
if res is not None:
# prediction = torch.from_numpy(res)
prediction = res
else:
prediction = None
return prediction
logger = logging.getLogger(__name__)
train_batch_size = hyperparameters["train_batch_size"]
gradient_accumulation_steps = hyperparameters[
"gradient_accumulation_steps"
]
learning_rate = hyperparameters["learning_rate"]
if hyperparameters["scale_lr"]:
learning_rate = (
learning_rate
* gradient_accumulation_steps
* train_batch_size
# * accelerator.num_processes
)
# Initialize the optimizer
optimizer = torch.optim.AdamW(
lora_layers.parameters(), # only optimize the embeddings
lr=learning_rate,
)
# Training function
def train_func(batch_pixel_values, batch_input_ids):
# Convert images to latent space
latents = shark_vae(batch_pixel_values).sample().detach()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.num_train_timesteps,
(bsz,),
device=latents.device,
).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch_input_ids)[0]
# Predict the noise residual
noise_pred = shark_unet(
noisy_latents,
timesteps,
encoder_hidden_states,
)
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
)
loss = (
F.mse_loss(noise_pred, target, reduction="none")
.mean([1, 2, 3])
.mean()
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return loss
def training_function():
max_train_steps = hyperparameters["max_train_steps"]
output_dir = hyperparameters["output_dir"]
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
train_dataloader = create_dataloader(train_batch_size)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
num_train_epochs = math.ceil(
max_train_steps / num_update_steps_per_epoch
)
# Train!
total_batch_size = (
train_batch_size
* gradient_accumulation_steps
# train_batch_size * accelerator.num_processes * gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(
f" Instantaneous batch size per device = {train_batch_size}"
)
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(
# range(max_train_steps), disable=not accelerator.is_local_main_process
range(max_train_steps)
)
progress_bar.set_description("Steps")
global_step = 0
params__ = [
i for i in text_encoder.get_input_embeddings().parameters()
]
for epoch in range(num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
dynamo_callable = dynamo.optimize(
refbackend_torchdynamo_backend
)(train_func)
lam_func = lambda x, y: dynamo_callable(
torch.from_numpy(x), torch.from_numpy(y)
)
loss = predictions(
train_func,
lam_func,
batch["pixel_values"],
batch["input_ids"],
)
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
if global_step >= max_train_steps:
break
training_function()
# Save the lora weights
unet.save_attn_procs(args.lora_save_dir)
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
if param.grad is not None:
del param.grad # free some memory
torch.cuda.empty_cache()
if __name__ == "__main__":
if args.clear_all:
clear_all()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
if len(args.prompts) != 1:
print("Need exactly one prompt for the LoRA word")
lora_train(
args.prompts[0],
args.height,
args.width,
args.training_steps,
args.guidance_scale,
args.seed,
args.batch_count,
args.batch_size,
args.scheduler,
"None",
args.hf_model_id,
args.precision,
args.device,
args.max_length,
args.training_images_dir,
args.lora_save_dir,
args.use_lora,
)

View File

@@ -1,126 +0,0 @@
import os
from pathlib import Path
from shark_tuner.codegen_tuner import SharkCodegenTuner
from shark_tuner.iree_utils import (
dump_dispatches,
create_context,
export_module_to_mlir_file,
)
from shark_tuner.model_annotation import model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.utils import set_init_device_flags
from apps.stable_diffusion.src.utils.sd_annotation import (
get_device_args,
load_winograd_configs,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
max_len=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,
)
if args.annotation_model == "unet":
mlir_module = sd_model.unet()
model_name = sd_model.model_name["unet"]
elif args.annotation_model == "vae":
mlir_module = sd_model.vae()
model_name = sd_model.model_name["vae"]
else:
raise ValueError(
f"{args.annotation_model} is not supported for tuning."
)
return mlir_module, model_name
def main():
args.use_tuned = False
set_init_device_flags()
mlir_module, model_name = load_mlir_module()
# Get device and device specific arguments
device, device_spec_args = get_device_args()
device_spec = ""
vulkan_target_triple = ""
if device_spec_args:
device_spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
vulkan_target_triple = device_spec
device_spec = device_spec.split("-")[0]
# Add winograd annotation for vulkan device
use_winograd = (
True
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
else False
)
winograd_config = (
load_winograd_configs()
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
else ""
)
with create_context() as ctx:
input_module = model_annotation(
ctx,
input_contents=mlir_module,
config_path=winograd_config,
search_op="conv",
winograd=use_winograd,
)
# Dump model dispatches
generates_dir = Path.home() / "tmp"
if not os.path.exists(generates_dir):
os.makedirs(generates_dir)
dump_mlir = generates_dir / "temp.mlir"
dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches"
export_module_to_mlir_file(input_module, dump_mlir)
dump_dispatches(
dump_mlir,
device,
dispatch_dir,
vulkan_target_triple,
use_winograd=use_winograd,
)
# Tune each dispatch
dtype = "f16" if args.precision == "fp16" else "f32"
config_filename = f"{model_name}_{device_spec}_configs.json"
for f_path in os.listdir(dispatch_dir):
if not f_path.endswith(".mlir"):
continue
model_dir = os.path.join(dispatch_dir, f_path)
tuner = SharkCodegenTuner(
model_dir,
device,
"random",
args.num_iters,
args.tuned_config_dir,
dtype,
args.search_op,
batch_size=1,
config_filename=config_filename,
use_dispatch=True,
vulkan_target_triple=vulkan_target_triple,
)
tuner.tune()
if __name__ == "__main__":
main()

View File

@@ -1,49 +1,299 @@
import os
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
import sys
import json
import torch
import transformers
import re
import time
from pathlib import Path
from PIL import PngImagePlugin
from datetime import datetime as dt
from dataclasses import dataclass
from csv import DictWriter
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
if args.clear_all:
clear_all()
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
# This has to come before importing cache objects
if args.clear_all:
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
if os.path.exists(yaml):
os.remove(yaml)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(output_path, "generated_imgs")
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = os.path.basename(args.ckpt_loc)
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
)
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SEED": img_seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
}
with open(csv_path, "a") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
dictwriter_obj.writerow(new_entry)
csv_obj.close()
if args.save_metadata_to_json:
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
txt2img_obj = None
config_obj = None
schedulers = None
# Exposed to UI.
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
global txt2img_obj
global config_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = custom_model
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
)
if config_obj != new_config_obj:
config_obj = new_config_obj
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.use_tuned = True
args.import_mlir = False
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
)
if not txt2img_obj:
sys.exit("text to image pipeline must not return a null value")
txt2img_obj.scheduler = schedulers[scheduler]
start_time = time.time()
txt2img_obj.log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = txt2img_obj.generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
txt2img_obj.log += "\n"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
return generated_imgs, text_output
if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
)
for current_batch in range(args.batch_count):
if current_batch > 0:
for run in range(args.runs):
if run > 0:
seed = -1
seed = utils.sanitize_seed(seed)
@@ -73,13 +323,9 @@ def main():
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
# TODO: if using --runs=x txt2img_obj.log will output on each display every iteration infos from the start
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -1,91 +0,0 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
if __name__ == "__main__":
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
# When the models get uploaded, it should be default to False.
args.import_mlir = True
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
image = (
Image.open(args.img_path)
.convert("RGB")
.resize((args.height, args.width))
)
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model
upscaler_obj = UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ddpm_scheduler=schedulers["DDPM"],
ondemand=args.ondemand,
)
start_time = time.time()
generated_imgs = upscaler_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.batch_size,
args.height,
args.width,
args.steps,
args.noise_level,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, noise_level={args.noise_level}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += upscaler_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
extra_info = {"NOISE LEVEL": args.noise_level}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)

View File

@@ -1,7 +1,6 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
@@ -16,44 +15,36 @@ datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += collect_data_files('sentencepiece')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
( 'web/ui/css/*', 'ui/css' ),
( 'web/ui/logos/*', 'logos' )
( 'web/css/*', 'css' ),
( 'web/logos/*', 'logos' )
]
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
a = Analysis(
['web/index.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
hookspath=[],
hooksconfig={},
runtime_hooks=[],

View File

@@ -1,6 +1,5 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
@@ -16,16 +15,13 @@ datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
@@ -40,15 +36,13 @@ binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
a = Analysis(
['scripts/main.py'],
['scripts/txt2img.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
hookspath=[],
hooksconfig={},
runtime_hooks=[],

View File

@@ -3,16 +3,6 @@ from apps.stable_diffusion.src.utils import (
set_init_device_flags,
prompt_examples,
get_available_devices,
clear_all,
save_output_img,
resize_stencil,
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,
StencilPipeline,
UpscalerPipeline,
)
from apps.stable_diffusion.src.pipelines import Text2ImagePipeline
from apps.stable_diffusion.src.schedulers import get_schedulers

View File

@@ -2,7 +2,6 @@ from apps.stable_diffusion.src.models.model_wrappers import (
SharkifyStableDiffusionModel,
)
from apps.stable_diffusion.src.models.opt_params import (
get_vae_encode,
get_vae,
get_unet,
get_clip,

View File

@@ -1,26 +1,19 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from collections import defaultdict
from pathlib import Path
import torch
import safetensors.torch
import traceback
import subprocess
import re
import sys
import os
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
base_models,
args,
fetch_or_delete_vmfbs,
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
get_stencil_model_id,
update_lora_weight,
)
@@ -35,169 +28,26 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
if "*" in shape[i]:
if "batch_size" in shape[i]:
mul_val = int(shape[i].split("*")[0])
if "batch_size" in shape[i]:
new_shape.append(batch_size * mul_val)
elif "height" in shape[i]:
new_shape.append(height * mul_val)
elif "width" in shape[i]:
new_shape.append(width * mul_val)
elif "/" in shape[i]:
import math
div_val = int(shape[i].split("/")[1])
if "batch_size" in shape[i]:
new_shape.append(math.ceil(batch_size / div_val))
elif "height" in shape[i]:
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
new_shape.append(batch_size * mul_val)
else:
new_shape.append(shape[i])
return new_shape
def check_compilation(model, model_name):
if not model:
raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
custom_vae: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
low_cpu_mem_usage: bool = False,
debug: bool = False,
sharktank_dir: str = "",
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
if not os.path.isfile(weights_path):
subprocess.run(["wget", custom_weights, "-O", weights_path])
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
self.model_id = "stabilityai/stable-diffusion-2-1-base"
self.custom_vae = custom_vae
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
"_"
+ str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
print(f'use_tuned? sharkify: {use_tuned}')
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
print(self.model_name)
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
self.sharktank_dir = sharktank_dir
self.generate_vmfb = generate_vmfb
self.inputs = dict()
self.model_to_run = ""
if self.custom_weights != "":
self.model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
else:
self.model_to_run = args.hf_model_id
self.custom_vae = self.process_custom_vae()
self.base_model_id = fetch_and_update_base_model_id(self.model_to_run)
if self.base_model_id != "" and args.ckpt_loc != "":
args.hf_model_id = self.base_model_id
self.return_mlir = return_mlir
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
index = 0
for model in sub_model_list:
sub_model = model
model_config = self.model_name
if "vae" == model:
if self.custom_vae != "":
model_config = model_config + get_path_stem(self.custom_vae)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 128):
sys.exit("width should be greater than 128 and multiple of 8")
if not (height % 8 == 0 and height >= 128):
sys.exit("height should be greater than 128 and multiple of 8")
# Get the input info for a model i.e. "unet", "clip", "vae", etc.
def get_input_info_for(self, model_info):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = []
for inp in model_info:
shape = model_info[inp]["shape"]
dtype = dtype_config[model_info[inp]["dtype"]]
# Get the input info for various models i.e. "unet", "clip", "vae".
def get_input_info(model_info, max_len, width, height, batch_size):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
for inp in model_info[k]:
shape = model_info[k][inp]["shape"]
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, self.max_len, self.width, self.height, self.batch_size
shape, max_len, width, height, batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
@@ -207,64 +57,79 @@ class SharkifyStableDiffusionModel:
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
input_map[k].append(tensor)
return input_map
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
# We need a better naming convention for the .vmfbs because despite
# using the custom model variant the .vmfb names remain the same and
# it'll always pick up the compiled .vmfb instead of compiling the
# custom model.
# So, currently, we add `self.model_id` in the `self.model_name` of
# .vmfb file.
# TODO: Have a better way of naming the vmfbs using self.model_name.
model_name = re.sub(r"\W+", "_", self.model_id)
if model_name[0] == "_":
model_name = model_name[1:]
self.model_name = self.model_name + "_" + model_name
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 384):
sys.exit("width should be greater than 384 and multiple of 8")
if not (height % 8 == 0 and height >= 384):
sys.exit("height should be greater than 384 and multiple of 8")
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
latents = self.vae.encode(input).latent_dist.sample()
return 0.18215 * latents
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
shark_vae_encode, vae_encode_mlir = compile_through_fx(
vae_encode,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae_encode"],
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae_encode",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae_encode, vae_encode_mlir
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
super().__init__()
self.vae = None
if custom_vae == "":
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae
def forward(self, input):
@@ -277,166 +142,33 @@ class SharkifyStableDiffusionModel:
x = x * 255.0
return x.round()
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
vae = VaeModel()
inputs = tuple(self.inputs["vae"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae, vae_mlir = compile_through_fx(
is_f16 = True if self.precision == "fp16" else False
vae_name = "base_vae" if self.base_vae else "vae"
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
model_name=vae_name + self.model_name,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae, vae_mlir
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
control2, control3, control4, control5, control6, control7,
control8, control9, control10, control11, control12, control13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
mb_res_samples = control13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
down_block_additional_residuals=db_res_samples,
mid_block_additional_residual=mb_res_samples,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = ControlledUnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["stencil_unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.cnet.in_channels
self.train(False)
def forward(
self,
latent,
timestep,
text_embedding,
stencil_image_input,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
# TODO: guidance NOT NEEDED change in `get_input_info` later
latents = torch.cat(
[latent] * 2
) # needs to be same as controlledUNET latents
stencil_image = torch.cat(
[stencil_image_input] * 2
) # needs to be same as controlledUNET latents
down_block_res_samples, mid_block_res_sample = self.cnet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
controlnet_cond=stencil_image,
return_dict=False,
)
return tuple(list(down_block_res_samples) + [mid_block_res_sample])
scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
input_mask = [True, True, True, True]
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name["stencil_adaptor"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_adaptor",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_cnet, cnet_mlir
return shark_vae
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
def __init__(self, model_id=self.model_id):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
if(args.attention_slicing.isdigit()):
self.unet.set_attention_slice(int(args.attention_slicing))
else:
self.unet.set_attention_slice(args.attention_slicing)
# TODO: Instead of flattening the `control` try to use the list.
def forward(
self, latent, timestep, text_embedding, guidance_scale,
self, latent, timestep, text_embedding, guidance_scale
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
@@ -449,238 +181,115 @@ class SharkifyStableDiffusionModel:
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
unet = UnetModel()
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet, unet_mlir = compile_through_fx(
shark_unet = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_unet_upscaler(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, latent, timestep, text_embedding, noise_level):
unet_out = self.unet.forward(
latent,
timestep,
text_embedding,
noise_level,
return_dict=False,
)[0]
return unet_out
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
model_name="unet" + self.model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
return shark_unet
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
def __init__(self, model_id=self.model_id):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip, clip_mlir = compile_through_fx(
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
extended_model_name=self.model_name["clip"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
model_name="clip" + self.model_name,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip, clip_mlir
return shark_clip
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
return self.custom_vae
try:
preprocessCKPT(self.custom_vae)
return get_path_to_diffusers_checkpoint(self.custom_vae)
except:
print("Processing standalone Vae checkpoint")
vae_checkpoint = None
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
if custom_vae.endswith(".ckpt"):
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
else:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
# configiration.
def compile_all(self, base_model_id):
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
return compiled_clip, compiled_unet, compiled_vae
try:
vae_checkpoint = convert_original_vae(vae_checkpoint)
finally:
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
def __call__(self):
# Step 1:
# -- Fetch all vmfbs for the model, if present, else delete the lot.
vmfbs = fetch_or_delete_vmfbs(
self.model_name, self.base_vae, self.precision
)
if vmfbs[0]:
# -- If all vmfbs are indeed present, we also try and fetch the base
# model configuration for running SD with custom checkpoints.
if self.custom_weights != "":
args.hf_model_id = fetch_and_update_base_model_id(self.custom_weights)
if args.hf_model_id == "":
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
return vmfbs
def compile_unet_variants(self, model):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
return get_unet()
else:
return self.get_unet()
# Step 2:
# -- If vmfbs weren't found, we try to see if the base model configuration
# for the required SD run is known to us and bypass the retry mechanism.
model_to_run = ""
if self.custom_weights != "":
model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights)
else:
return self.get_controlled_unet()
model_to_run = args.hf_model_id
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
if base_model_fetched != "":
print("Compiling all the models with the fetched base model configuration.")
if args.ckpt_loc != "":
args.hf_model_id = base_model_fetched
return self.compile_all(base_model_fetched)
def vae_encode(self):
try:
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
if self.return_mlir:
return vae_encode_mlir
return compiled_vae_encode
except Exception as e:
sys.exit(e)
def clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(base_models["clip"])
compiled_clip, clip_mlir = self.get_clip()
check_compilation(compiled_clip, "Clip")
if self.return_mlir:
return clip_mlir
return compiled_clip
except Exception as e:
sys.exit(e)
def unet(self):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(self.model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
break
check_compilation(compiled_unet, "Unet")
if self.return_mlir:
return unet_mlir
return compiled_unet
except Exception as e:
sys.exit(e)
def vae(self):
try:
vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
self.inputs["vae"] = self.get_input_info_for(vae_input)
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")
if self.return_mlir:
return vae_mlir
return compiled_vae
except Exception as e:
sys.exit(e)
def controlnet(self):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
check_compilation(compiled_stencil_adaptor, "Stencil")
if self.return_mlir:
return controlnet_mlir
return compiled_stencil_adaptor
except Exception as e:
sys.exit(e)
# Step 3:
# -- This is the retry mechanism where the base model's configuration is not
# known to us and figure that out by trial and error.
print("Inferring base model configuration.")
for model_id in base_models:
try:
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id)
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please re-run the command with `--enable_stack_trace` flag and create an issue with detailed log at https://github.com/nod-ai/SHARK/issues"
)

View File

@@ -9,26 +9,15 @@ from apps.stable_diffusion.src.utils import (
hf_model_variant_map = {
"Linaqruf/anything-v3.0": ["anythingv3", "v1_4"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"],
"prompthero/openjourney": ["openjourney", "v1_4"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"],
"Linaqruf/anything-v3.0": ["anythingv3", "v2_1base"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"],
"prompthero/openjourney": ["openjourney", "v2_1base"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
bucket_key = "gs://shark_tank/prashant_nod"
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 and args.width != 512 and args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
@@ -50,12 +39,6 @@ def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
# TODO: Get the quantize model from model_db.json
if args.use_quantize == "int8":
bk, mk, flags = get_quantize_model()
return get_shark_model(bk, mk, flags)
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
@@ -69,23 +52,6 @@ def get_unet():
return get_shark_model(bucket, model_name, iree_flags)
def get_vae_encode():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.

View File

@@ -1,18 +1,3 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
InpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import (
OutpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import (
StencilPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_upscaler import (
UpscalerPipeline,
)

View File

@@ -1,200 +0,0 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class Image2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_image_latents(
self,
image,
batch_size,
height,
width,
generator,
num_inference_steps,
strength,
dtype,
):
# Pre process image -> get image encoded -> process latents
# TODO: process with variable HxW combos
# Pre process image
image = image.resize((width, height))
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
image_arr = 2 * (image_arr - 0.5)
# set scheduler steps
self.scheduler.set_timesteps(num_inference_steps)
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps
)
t_start = max(num_inference_steps - init_timestep, 0)
# timesteps reduced as per strength
timesteps = self.scheduler.timesteps[t_start:]
# new number of steps to be used as per strength will be
# num_inference_steps = num_inference_steps - t_start
# image encode
latents = self.encode_image((image_arr,))
latents = torch.from_numpy(latents).to(dtype)
# add noise to data
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
latents = self.scheduler.add_noise(
latents, noise, timesteps[0].repeat(1)
)
return latents, timesteps
def encode_image(self, input_image):
self.load_vae_encode()
vae_encode_start = time.time()
latents = self.vae_encode("forward", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
if self.ondemand:
self.unload_vae_encode()
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
return latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
strength,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
use_stencil,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare input image latent
image_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=image_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -1,473 +0,0 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from PIL import Image, ImageOps
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class InpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
latents = latents * self.scheduler.init_noise_sigma
return latents
def get_crop_region(self, mask, pad=0):
h, w = mask.shape
crop_left = 0
for i in range(w):
if not (mask[:, i] == 0).all():
break
crop_left += 1
crop_right = 0
for i in reversed(range(w)):
if not (mask[:, i] == 0).all():
break
crop_right += 1
crop_top = 0
for i in range(h):
if not (mask[i] == 0).all():
break
crop_top += 1
crop_bottom = 0
for i in reversed(range(h)):
if not (mask[i] == 0).all():
break
crop_bottom += 1
return (
int(max(crop_left - pad, 0)),
int(max(crop_top - pad, 0)),
int(min(w - crop_right + pad, w)),
int(min(h - crop_bottom + pad, h)),
)
def expand_crop_region(
self,
crop_region,
processing_width,
processing_height,
image_width,
image_height,
):
x1, y1, x2, y2 = crop_region
ratio_crop_region = (x2 - x1) / (y2 - y1)
ratio_processing = processing_width / processing_height
if ratio_crop_region > ratio_processing:
desired_height = (x2 - x1) / ratio_processing
desired_height_diff = int(desired_height - (y2 - y1))
y1 -= desired_height_diff // 2
y2 += desired_height_diff - desired_height_diff // 2
if y2 >= image_height:
diff = y2 - image_height
y2 -= diff
y1 -= diff
if y1 < 0:
y2 -= y1
y1 -= y1
if y2 >= image_height:
y2 = image_height
else:
desired_width = (y2 - y1) * ratio_processing
desired_width_diff = int(desired_width - (x2 - x1))
x1 -= desired_width_diff // 2
x2 += desired_width_diff - desired_width_diff // 2
if x2 >= image_width:
diff = x2 - image_width
x2 -= diff
x1 -= diff
if x1 < 0:
x2 -= x1
x1 -= x1
if x2 >= image_width:
x2 = image_width
return x1, y1, x2, y2
def resize_image(self, resize_mode, im, width, height):
"""
resize_mode:
0: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
1: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
"""
if resize_mode == 0:
ratio = width / height
src_ratio = im.width / im.height
src_w = (
width if ratio > src_ratio else im.width * height // im.height
)
src_h = (
height if ratio <= src_ratio else im.height * width // im.width
)
resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (width, height))
res.paste(
resized,
box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
)
else:
ratio = width / height
src_ratio = im.width / im.height
src_w = (
width if ratio < src_ratio else im.width * height // im.height
)
src_h = (
height if ratio >= src_ratio else im.height * width // im.width
)
resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (width, height))
res.paste(
resized,
box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
)
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
res.paste(
resized.resize((width, fill_height), box=(0, 0, width, 0)),
box=(0, 0),
)
res.paste(
resized.resize(
(width, fill_height),
box=(0, resized.height, width, resized.height),
),
box=(0, fill_height + src_h),
)
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
res.paste(
resized.resize(
(fill_width, height), box=(0, 0, 0, height)
),
box=(0, 0),
)
res.paste(
resized.resize(
(fill_width, height),
box=(resized.width, 0, resized.width, height),
),
box=(fill_width + src_w, 0),
)
return res
def prepare_mask_and_masked_image(
self,
image,
mask,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
):
# preprocess image
image = image.resize((width, height))
mask = mask.resize((width, height))
paste_to = ()
overlay_image = None
if inpaint_full_res:
# prepare overlay image
overlay_image = Image.new("RGB", (image.width, image.height))
overlay_image.paste(
image.convert("RGB"),
mask=ImageOps.invert(mask.convert("L")),
)
# prepare mask
mask = mask.convert("L")
crop_region = self.get_crop_region(
np.array(mask), inpaint_full_res_padding
)
crop_region = self.expand_crop_region(
crop_region, width, height, mask.width, mask.height
)
x1, y1, x2, y2 = crop_region
mask = mask.crop(crop_region)
mask = self.resize_image(1, mask, width, height)
paste_to = (x1, y1, x2 - x1, y2 - y1)
# prepare image
image = image.crop(crop_region)
image = self.resize_image(1, image, width, height)
if isinstance(image, (Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image, paste_to, overlay_image
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
):
mask = torch.nn.functional.interpolate(
mask, size=(height // 8, width // 8)
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
return mask, masked_image_latents
def apply_overlay(self, image, paste_loc, overlay):
x, y, w, h = paste_loc
image = self.resize_image(0, image, w, h)
overlay.paste(image, (x, y))
return overlay
def generate_images(
self,
prompts,
neg_prompts,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Preprocess mask and image
(
mask,
masked_image,
paste_to,
overlay_image,
) = self.prepare_mask_and_masked_image(
image,
mask_image,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
)
# Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask=mask,
masked_image=masked_image,
batch_size=batch_size,
height=height,
width=width,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
mask=mask,
masked_image_latents=masked_image_latents,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
if inpaint_full_res:
output_image = self.apply_overlay(
all_imgs[0], paste_to, overlay_image
)
return [output_image]
return all_imgs

View File

@@ -1,567 +0,0 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from PIL import Image, ImageDraw, ImageFilter
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
import math
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class OutpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_mask_and_masked_image(
self, image, mask, mask_blur, width, height
):
if mask_blur > 0:
mask = mask.filter(ImageFilter.GaussianBlur(mask_blur))
image = image.resize((width, height))
mask = mask.resize((width, height))
# preprocess image
if isinstance(image, (Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
):
mask = torch.nn.functional.interpolate(
mask, size=(height // 8, width // 8)
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
return mask, masked_image_latents
def get_matched_noise(
self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
):
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data):
if data.ndim > 2: # has channels
out_fft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]),
dtype=np.complex128,
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_fft[:, :, c] = np.fft.fft2(
np.fft.fftshift(c_data), norm="ortho"
)
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
else: # one channel
out_fft = np.zeros(
(data.shape[0], data.shape[1]), dtype=np.complex128
)
out_fft[:, :] = np.fft.fft2(
np.fft.fftshift(data), norm="ortho"
)
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
return out_fft
def _ifft2(data):
if data.ndim > 2: # has channels
out_ifft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]),
dtype=np.complex128,
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_ifft[:, :, c] = np.fft.ifft2(
np.fft.fftshift(c_data), norm="ortho"
)
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
else: # one channel
out_ifft = np.zeros(
(data.shape[0], data.shape[1]), dtype=np.complex128
)
out_ifft[:, :] = np.fft.ifft2(
np.fft.fftshift(data), norm="ortho"
)
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
return out_ifft
def _get_gaussian_window(width, height, std=3.14, mode=0):
window_scale_x = float(width / min(width, height))
window_scale_y = float(height / min(width, height))
window = np.zeros((width, height))
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
for y in range(height):
fy = (y / height * 2.0 - 1.0) * window_scale_y
if mode == 0:
window[:, y] = np.exp(-(x**2 + fy**2) * std)
else:
window[:, y] = (
1 / ((x**2 + 1.0) * (fy**2 + 1.0))
) ** (std / 3.14)
return window
def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
np_mask_rgb = np.zeros(
(np_mask_grey.shape[0], np_mask_grey.shape[1], 3)
)
if hardness != 1.0:
hardened = np_mask_grey[:] ** hardness
else:
hardened = np_mask_grey[:]
for c in range(3):
np_mask_rgb[:, :, c] = hardened[:]
return np_mask_rgb
def _match_cumulative_cdf(source, template):
src_values, src_unique_indices, src_counts = np.unique(
source.ravel(), return_inverse=True, return_counts=True
)
tmpl_values, tmpl_counts = np.unique(
template.ravel(), return_counts=True
)
# calculate normalized quantiles for each array
src_quantiles = np.cumsum(src_counts) / source.size
tmpl_quantiles = np.cumsum(tmpl_counts) / template.size
interp_a_values = np.interp(
src_quantiles, tmpl_quantiles, tmpl_values
)
return interp_a_values[src_unique_indices].reshape(source.shape)
def _match_histograms(image, reference):
if image.ndim != reference.ndim:
raise ValueError(
"Image and reference must have the same number of channels."
)
if image.shape[-1] != reference.shape[-1]:
raise ValueError(
"Number of channels in the input image and reference image must match!"
)
matched = np.empty(image.shape, dtype=image.dtype)
for channel in range(image.shape[-1]):
matched_channel = _match_cumulative_cdf(
image[..., channel], reference[..., channel]
)
matched[..., channel] = matched_channel
matched = matched.astype(np.float64, copy=False)
return matched
width = _np_src_image.shape[0]
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
np_src_image = _np_src_image[:] * (1.0 - np_mask_rgb)
np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
# rather than leave the masked area black, we get better results from fft by filling the average unmasked color
windowed_image = _np_src_image * (
1.0 - _get_masked_window_rgb(np_mask_grey)
)
windowed_image /= np.max(windowed_image)
windowed_image += np.average(_np_src_image) * np_mask_rgb
src_fft = _fft2(
windowed_image
) # get feature statistics from masked src img
src_dist = np.absolute(src_fft)
src_phase = src_fft / src_dist
# create a generator with a static seed to make outpainting deterministic / only follow global seed
rng = np.random.default_rng(0)
noise_window = _get_gaussian_window(
width, height, mode=1
) # start with simple gaussian noise
noise_rgb = rng.random((width, height, num_channels))
noise_grey = np.sum(noise_rgb, axis=2) / 3.0
# the colorfulness of the starting noise is blended to greyscale with a parameter
noise_rgb *= color_variation
for c in range(num_channels):
noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey
noise_fft = _fft2(noise_rgb)
for c in range(num_channels):
noise_fft[:, :, c] *= noise_window
noise_rgb = np.real(_ifft2(noise_fft))
shaped_noise_fft = _fft2(noise_rgb)
shaped_noise_fft[:, :, :] = (
np.absolute(shaped_noise_fft[:, :, :]) ** 2
* (src_dist**noise_q)
* src_phase
) # perform the actual shaping
# color_variation
brightness_variation = 0.0
contrast_adjusted_np_src = (
_np_src_image[:] * (brightness_variation + 1.0)
- brightness_variation * 2.0
)
shaped_noise = np.real(_ifft2(shaped_noise_fft))
shaped_noise -= np.min(shaped_noise)
shaped_noise /= np.max(shaped_noise)
shaped_noise[img_mask, :] = _match_histograms(
shaped_noise[img_mask, :] ** 1.0,
contrast_adjusted_np_src[ref_mask, :],
)
shaped_noise = (
_np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
)
matched_noise = shaped_noise[:]
return np.clip(matched_noise, 0.0, 1.0)
def generate_images(
self,
prompts,
neg_prompts,
image,
pixels,
mask_blur,
is_left,
is_right,
is_top,
is_bottom,
noise_q,
color_variation,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
process_width = width
process_height = height
left = pixels if is_left else 0
right = pixels if is_right else 0
up = pixels if is_top else 0
down = pixels if is_bottom else 0
target_w = math.ceil((image.width + left + right) / 64) * 64
target_h = math.ceil((image.height + up + down) / 64) * 64
if left > 0:
left = left * (target_w - image.width) // (left + right)
if right > 0:
right = target_w - image.width - left
if up > 0:
up = up * (target_h - image.height) // (up + down)
if down > 0:
down = target_h - image.height - up
def expand(
init_img,
expand_pixels,
is_left=False,
is_right=False,
is_top=False,
is_bottom=False,
):
is_horiz = is_left or is_right
is_vert = is_top or is_bottom
pixels_horiz = expand_pixels if is_horiz else 0
pixels_vert = expand_pixels if is_vert else 0
res_w = init_img.width + pixels_horiz
res_h = init_img.height + pixels_vert
process_res_w = math.ceil(res_w / 64) * 64
process_res_h = math.ceil(res_h / 64) * 64
img = Image.new("RGB", (process_res_w, process_res_h))
img.paste(
init_img,
(pixels_horiz if is_left else 0, pixels_vert if is_top else 0),
)
msk = Image.new("RGB", (process_res_w, process_res_h), "white")
draw = ImageDraw.Draw(msk)
draw.rectangle(
(
expand_pixels + mask_blur if is_left else 0,
expand_pixels + mask_blur if is_top else 0,
msk.width - expand_pixels - mask_blur
if is_right
else res_w,
msk.height - expand_pixels - mask_blur
if is_bottom
else res_h,
),
fill="black",
)
np_image = (np.asarray(img) / 255.0).astype(np.float64)
np_mask = (np.asarray(msk) / 255.0).astype(np.float64)
noised = self.get_matched_noise(
np_image, np_mask, noise_q, color_variation
)
output_image = Image.fromarray(
np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8),
mode="RGB",
)
target_width = (
min(width, init_img.width + pixels_horiz)
if is_horiz
else img.width
)
target_height = (
min(height, init_img.height + pixels_vert)
if is_vert
else img.height
)
crop_region = (
0 if is_left else output_image.width - target_width,
0 if is_top else output_image.height - target_height,
target_width if is_left else output_image.width,
target_height if is_top else output_image.height,
)
mask_to_process = msk.crop(crop_region)
image_to_process = output_image.crop(crop_region)
# Preprocess mask and image
mask, masked_image = self.prepare_mask_and_masked_image(
image_to_process, mask_to_process, mask_blur, width, height
)
# Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask=mask,
masked_image=masked_image,
batch_size=batch_size,
height=height,
width=width,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
mask=mask,
masked_image_latents=masked_image_latents,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
res_img = all_imgs[0].resize(
(image_to_process.width, image_to_process.height)
)
output_image.paste(
res_img,
(
0 if is_left else output_image.width - res_img.width,
0 if is_top else output_image.height - res_img.height,
),
)
output_image = output_image.crop((0, 0, res_w, res_h))
return output_image
img = image.resize((width, height))
if left > 0:
img = expand(img, left, is_left=True)
if right > 0:
img = expand(img, right, is_right=True)
if up > 0:
img = expand(img, up, is_top=True)
if down > 0:
img = expand(img, down, is_bottom=True)
return [img]

View File

@@ -1,274 +0,0 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class StencilPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
def load_controlnet(self):
if self.controlnet is not None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_stencil_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
self.load_controlnet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.ondemand:
self.unload_unet()
self.unload_controlnet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
strength,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
use_stencil,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
)
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare initial latent.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
final_timesteps = self.scheduler.timesteps
# Get Image latents
latents = self.produce_stencil_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -1,4 +1,5 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
@@ -8,39 +9,34 @@ from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
def prepare_latents(
self,
@@ -110,10 +106,8 @@ class Text2ImagePipeline(StableDiffusionPipeline):
dtype=dtype,
)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -130,15 +124,12 @@ class Text2ImagePipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in range(0, latents.shape[0], batch_size):
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -1,319 +0,0 @@
import inspect
import torch
import time
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from PIL import Image
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, Image.Image):
image = [image]
if isinstance(image[0], Image.Image):
w, h = image[0].size
w, h = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
class UpscalerPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
low_res_scheduler: Union[
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.low_res_scheduler = low_res_scheduler
def prepare_extra_step_kwargs(self, generator, eta):
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
latents = 1 / 0.08333 * (latents.float())
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height,
width,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_img_latents(
self,
latents,
image,
text_embeddings,
guidance_scale,
noise_level,
total_timesteps,
dtype,
cpu_scheduling,
extra_step_kwargs,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = torch.cat([latent_model_input, image], dim=1)
timestep = torch.tensor([t]).to(dtype).detach().numpy()
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if cpu_scheduling:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
else:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
noise_level,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# 4. Preprocess image
image = preprocess(image).to(dtype)
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long)
noise = torch.randn(
image.shape,
generator=generator,
).to(dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
image = torch.cat([image] * 2)
noise_level = torch.cat([noise_level] * image.shape[0])
height, width = image.shape[2:]
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
eta = 0.0
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# guidance scale as a float32 tensor.
# guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
image=image,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
noise_level=noise_level,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
extra_step_kwargs=extra_step_kwargs,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -1,5 +1,4 @@
import torch
import numpy as np
from transformers import CLIPTokenizer
from PIL import Image
from tqdm.auto import tqdm
@@ -7,14 +6,11 @@ import time
from typing import Union
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
@@ -29,105 +25,32 @@ from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
import sys
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
class StableDiffusionPipeline:
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
self.vae = None
self.text_encoder = None
self.unet = None
self.model_max_length = 77
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
try:
self.tokenizer = get_tokenizer()
except:
self.load_unet()
self.unload_unet()
self.tokenizer = get_tokenizer()
def load_clip(self):
if self.text_encoder is not None:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
try:
self.text_encoder = get_clip()
except:
print("download pipeline failed, falling back to import_mlir")
self.text_encoder = self.sd_model.clip()
def unload_clip(self):
del self.text_encoder
self.text_encoder = None
def load_unet(self):
if self.unet is not None:
return
if self.import_mlir or self.use_lora:
self.unet = self.sd_model.unet()
else:
try:
self.unet = get_unet()
except:
print("download pipeline failed, falling back to import_mlir")
self.unet = self.sd_model.unet()
def unload_unet(self):
del self.unet
self.unet = None
def load_vae(self):
if self.vae is not None:
return
if self.import_mlir or self.use_lora:
self.vae = self.sd_model.vae()
else:
try:
self.vae = get_vae()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae = self.sd_model.vae()
def unload_vae(self):
del self.vae
self.vae = None
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -147,14 +70,12 @@ class StableDiffusionPipeline:
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
self.load_clip()
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
@@ -191,29 +112,16 @@ class StableDiffusionPipeline:
total_timesteps,
dtype,
cpu_scheduling,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
@@ -246,11 +154,6 @@ class StableDiffusionPipeline:
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -266,17 +169,14 @@ class StableDiffusionPipeline:
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
import_mlir: bool,
model_id: str,
ckpt_loc: str,
custom_vae: str,
precision: str,
max_length: int,
batch_size: int,
@@ -284,556 +184,23 @@ class StableDiffusionPipeline:
width: int,
use_base_vae: bool,
use_tuned: bool,
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
if (
not import_mlir
and not use_lora
and cls.__name__ == "StencilPipeline"
):
sys.exit("StencilPipeline not supported with SharkTank currently.")
is_inpaint = cls.__name__ in [
"InpaintPipeline",
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
sd_model = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
use_quantize=use_quantize,
)
if cls.__name__ in ["UpscalerPipeline"]:
return cls(
scheduler,
ddpm_scheduler,
sd_model,
import_mlir,
use_lora,
ondemand,
if import_mlir:
# TODO: Delet this when on-the-fly tuning of models work.
use_tuned = False
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
# #####################################################
# Implements text embeddings with weights from prompts
# https://huggingface.co/AlanB/lpw_stable_diffusion_mod
# #####################################################
def encode_prompts_weight(
self,
prompt,
negative_prompt,
model_max_length,
do_classifier_free_guidance=True,
max_embeddings_multiples=1,
num_images_per_prompt=1,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
"""
# SHARK: Save model_max_length, load the clip and init inference time
self.model_max_length = model_max_length
self.load_clip()
clip_inf_start = time.time()
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt
if do_classifier_free_guidance
else None,
max_embeddings_multiples=max_embeddings_multiples,
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
return cls(
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()
from typing import List, Optional, Union
import re
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(
pipe: StableDiffusionPipeline, prompt: List[str], max_length: int
):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens,
weights,
max_length,
bos,
eos,
no_boseos_middle=True,
chunk_length=77,
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
max_length
if no_boseos_middle
else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
tokens[i] = (
[bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
)
if no_boseos_middle:
weights[i] = (
[1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
)
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j
* (chunk_length - 2) : min(
len(weights[i]), (j + 1) * (chunk_length - 2)
)
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
# text_embedding = pipe.text_encoder(text_input_chunk)[0]
# SHARK: deplicate the text_input as Shark runner expects tokens and neg tokens
formatted_text_input_chunk = torch.cat(
[text_input_chunk, text_input_chunk]
)
text_embedding = pipe.text_encoder(
"forward", (formatted_text_input_chunk,)
)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
else:
# SHARK: deplicate the text_input as Shark runner expects tokens and neg tokens
# Convert the result to tensor
# text_embeddings = pipe.text_encoder(text_input)[0]
formatted_text_input = torch.cat([text_input, text_input])
text_embeddings = pipe.text_encoder(
"forward", (formatted_text_input,)
)[0]
text_embeddings = torch.from_numpy(text_embeddings)[None, :]
return text_embeddings
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
Args:
pipe (`StableDiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`):
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
ending token in each of the chunk in the middle.
skip_parsing (`bool`, *optional*, defaults to `False`):
Skip the parsing of brackets.
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True.
"""
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(
max_length, max([len(token) for token in uncond_tokens])
)
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(
uncond_tokens, dtype=torch.long, device="cpu"
)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(
prompt_weights, dtype=torch.float, device="cpu"
)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(
uncond_weights, dtype=torch.float, device="cpu"
)
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None

View File

@@ -1,13 +1,10 @@
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
@@ -20,14 +17,6 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
@@ -52,10 +41,6 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerDiscrete"
] = SharkEulerDiscreteScheduler.from_pretrained(

View File

@@ -40,7 +40,6 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
device = args.device.split(":", 1)[0].strip()
model_input = {
"euler": {
@@ -88,46 +87,33 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
def _import(self):
if args.import_mlir:
scaling_model = ScalingModel()
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
self.scaling_model = compile_through_fx(
scaling_model,
(example_latent, example_sigma),
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
self.step_model, _ = compile_through_fx(
self.step_model = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
if args.import_mlir:
_import(self)
else:
try:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_step_" + args.precision,
iree_flags,
)
except:
print(
"failed to download model, falling back and using import_mlir"
)
args.import_mlir = True
_import(self)
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()

View File

@@ -11,10 +11,6 @@ from apps.stable_diffusion.src.utils.resources import (
)
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
controlnet_hint_conversion,
get_stencil_model_id,
)
from apps.stable_diffusion.src.utils.utils import (
get_shark_model,
compile_through_fx,
@@ -24,15 +20,8 @@ from apps.stable_diffusion.src.utils.utils import (
get_available_devices,
get_opt_flags,
preprocessCKPT,
convert_original_vae,
fetch_or_delete_vmfbs,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
get_path_stem,
get_extended_name,
clear_all,
save_output_img,
get_generation_text_info,
update_lora_weight,
resize_stencil,
)

View File

@@ -1,22 +1,34 @@
{
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"stabilityai/stable-diffusion-2-1": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
@@ -25,272 +37,62 @@
"dtype":"f32"
}
},
"vae_upscaler": {
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"CompVis/stable-diffusion-v1-4": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"8*height","8*width"
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
}
},
"unet": {
"stabilityai/stable-diffusion-2-1": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len",
768
"max_len"
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-2-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"runwayml/stable-diffusion-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-x4-upscaler": {
"latents": {
"shape": [
"2*batch_size",
7,
"8*height",
"8*width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"noise_level": {
"shape": [2],
"dtype": "i64"
}
}
},
"stencil_adaptor": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
"stencil_unet": {
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
},
"control1": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control2": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control3": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control4": {
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"control5": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control6": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control7": {
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"control8": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control9": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control10": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control11": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control12": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
"dtype":"i64"
}
}
}
}
}

View File

@@ -3,8 +3,6 @@
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting",
"stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",

View File

@@ -1,19 +1,82 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/nightly"
"stablediffusion/untuned":"gs://shark_tank/sd_untuned",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
},
{
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan"
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
}
]

View File

@@ -45,12 +45,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}

View File

@@ -20,22 +20,6 @@ def get_device():
return device
def get_device_args():
device = get_device()
device_spec_args = []
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
gpu_flags = get_iree_gpu_args()
for flag in gpu_flags:
device_spec_args.append(flag)
elif device == "vulkan":
device_spec_args.append(
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
return device, device_spec_args
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
from apps.stable_diffusion.src.models import (
@@ -70,66 +54,28 @@ def load_winograd_configs():
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
if not os.path.exists(WORKDIR):
os.mkdir(WORKDIR)
winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
winograd_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
def load_lower_configs(base_model_id=None):
def load_lower_configs():
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
if not base_model_id:
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
if version == "inpaint_v1":
version = "v1_4"
elif version == "inpaint_v2":
version = "v2_1base"
config_bucket = "gs://shark_tank/sd_tuned_configs/"
device, device_spec_args = get_device_args()
spec = ""
if device_spec_args:
spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
spec = spec.split("-")[0]
variant, version = get_variant_version(args.hf_model_id)
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_version = version
if variant in ["anythingv3", "analogdiffusion"]:
args.max_length = 77
config_version = "v1_4"
if args.annotation_model == "vae":
if not spec or spec in ["rdna3", "sm_80"]:
config_name = (
f"{args.annotation_model}_{args.precision}_{device}.json"
)
else:
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
and args.width == 768
):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
args.max_length = 77
device = get_device()
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
lowering_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading lowering config file from ", lowering_config_dir)
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
@@ -137,6 +83,13 @@ def load_lower_configs(base_model_id=None):
# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with create_context() as ctx:
winograd_model = model_annotation(
ctx,
@@ -150,41 +103,59 @@ def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
winograd_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
if args.save_annotation:
if model_name.split("_")[-1] != "tuned":
out_file_path = os.path.join(
args.annotation_output, model_name + "_tuned_torch.mlir"
)
else:
out_file_path = os.path.join(
args.annotation_output, model_name + "_torch.mlir"
)
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return bytecode
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return bytecode, out_file_path
def dump_after_mlir(input_mlir, use_winograd):
import iree.compiler as ireec
device, device_spec_args = get_device_args()
def dump_after_mlir(input_mlir, model_name, use_winograd):
if use_winograd:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline='builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))' "
)
else:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
dump_after = "iree-preprocessing-pad-linalg-ops"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline='builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))' "
)
dump_module = ireec.compile_str(
input_mlir,
target_backends=[iree_target_map(device)],
extra_args=device_spec_args
+ [
preprocess_flag,
"--compile-to=preprocessing",
],
device_spec_args = ""
device = get_device()
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
gpu_flags = get_iree_gpu_args()
for flag in gpu_flags:
device_spec_args += flag + " "
elif device == "vulkan":
device_spec_args = (
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
print("Applying tuned configs on", model_name)
run_cmd(
f"iree-compile {input_mlir} "
"--iree-input-type=tm_tensor "
f"--iree-hal-target-backends={iree_target_map(device)} "
f"{device_spec_args}"
f"{preprocess_flag}"
"--iree-stream-resource-index-bits=64 "
"--iree-vm-target-index-bits=64 "
f"--mlir-print-ir-after={dump_after} "
"--compile-to=flow "
f"2>{args.annotation_output}/dump_after_winograd.mlir "
)
return dump_module
# For Unet annotate the model with tuned lowering configs
@@ -192,66 +163,72 @@ def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
# Dump IR after padding/img2col/winograd passes
dump_module = dump_after_mlir(input_mlir, use_winograd)
print("Applying tuned configs on", model_name)
dump_after_mlir(input_mlir, model_name, use_winograd)
# Annotate the model with lowering configs in the config file
with create_context() as ctx:
tuned_model = model_annotation(
ctx,
input_contents=dump_module,
input_contents=f"{args.annotation_output}/dump_after_winograd.mlir",
config_path=lowering_config_dir,
search_op="all",
)
# Remove the intermediate mlir and save the final annotated model
os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
bytecode_stream = io.BytesIO()
tuned_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
if args.save_annotation:
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return bytecode
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return bytecode, out_file_path
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model = annotate_with_winograd(
winograd_model, model_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
model_path, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
tuned_model = mlir_model
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model, output_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
use_winograd = False
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
if model_from_tank:
mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
else:
# Just use this function to convert bytecode to string
orig_model, model_path = annotate_with_winograd(
mlir_model, "", model_name
)
mlir_model = model_path
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
print(f"Saved the annotated mlir in {output_path}.")
return tuned_model
if __name__ == "__main__":
mlir_model, model_name = load_model_from_tank()
sd_model_annotation(mlir_model, model_name)
sd_model_annotation(mlir_model, model_name, model_from_tank=True)

View File

@@ -1,5 +1,4 @@
import argparse
import os
from pathlib import Path
@@ -7,13 +6,6 @@ def path_expand(s):
return Path(s).expanduser().resolve()
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
@@ -22,33 +14,21 @@ p = argparse.ArgumentParser(
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
)
p.add_argument(
"-p",
"--prompts",
nargs="+",
default=["cyberpunk forest by Salvador Dali"],
action="append",
default=[],
help="text of which images to be generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=["trees, green"],
default=[""],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--img_path",
type=str,
help="Path to the image input for img2img/inpainting",
)
p.add_argument(
"--steps",
type=int,
@@ -59,8 +39,8 @@ p.add_argument(
p.add_argument(
"--seed",
type=int,
default=-1,
help="the seed to use. -1 for a random one.",
default=42,
help="the seed to use.",
)
p.add_argument(
@@ -68,14 +48,13 @@ p.add_argument(
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `batch_count`.",
help="the number of inferences to be made in a single `run`.",
)
p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 769, 8),
help="the height of the output image.",
)
@@ -83,7 +62,6 @@ p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 769, 8),
help="the width of the output image.",
)
@@ -94,13 +72,6 @@ p.add_argument(
help="the value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="the value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
@@ -108,121 +79,6 @@ p.add_argument(
help="max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="the strength of change applied on the given input image for img2img",
)
##############################################################################
### Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The no. of steps to train",
)
##############################################################################
### Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture",
)
p.add_argument(
"--inpaint_full_res_padding",
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding",
)
p.add_argument(
"--pixels",
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting",
)
p.add_argument(
"--mask_blur",
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend left for outpainting",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend right for outpainting",
)
p.add_argument(
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend top for outpainting",
)
p.add_argument(
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend bottom for outpainting",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0)",
)
##############################################################################
### Model Config and Usage Params
##############################################################################
@@ -292,10 +148,10 @@ p.add_argument(
)
p.add_argument(
"--batch_count",
"--runs",
type=int,
default=1,
help="number of batch to be generated with random seeds in single execution",
help="number of images to be generated with random seeds in single execution",
)
p.add_argument(
@@ -305,13 +161,6 @@ p.add_argument(
help="Path to SD's .ckpt file.",
)
p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
)
p.add_argument(
"--hf_model_id",
type=str,
@@ -320,45 +169,10 @@ p.add_argument(
)
p.add_argument(
"--low_cpu_mem_usage",
"--enable_stack_trace",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
)
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble"],
help="Enable the stencil feature.",
)
p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
)
p.add_argument(
"--ondemand",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM",
help="Enable showing the stack trace when retrying the base model configuration",
)
##############################################################################
@@ -366,7 +180,7 @@ p.add_argument(
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
"--iree-vulkan-target-triple",
type=str,
default="",
help="Specify target triple for vulkan",
@@ -381,7 +195,7 @@ p.add_argument(
p.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
default="4147483648",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
@@ -465,17 +279,11 @@ p.add_argument(
p.add_argument(
"--write_metadata_to_png",
default=True,
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
)
##############################################################################
### Web UI flags
##############################################################################
@@ -484,7 +292,7 @@ p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the progress bar animation during image generation",
help="flag for removing the pregress bar animation during image generation",
)
p.add_argument(
@@ -493,13 +301,7 @@ p.add_argument(
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="one of: [api, app, web]",
)
p.add_argument(
"--share",
@@ -515,12 +317,6 @@ p.add_argument(
help="flag for setting server port",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for enabling rest API",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################
@@ -540,39 +336,10 @@ p.add_argument(
)
p.add_argument(
"--save_annotation",
"--use_winograd",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file",
help="Apply Winograd on selected conv ops.",
)
##############################################################################
### SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all",
)
args, unknown = p.parse_known_args()
if args.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
os.getcwd(), args.hf_model_id.replace("/", "_")
)

View File

@@ -1,2 +0,0 @@
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector

View File

@@ -1,6 +0,0 @@
import cv2
class CannyDetector:
def __call__(self, img, low_threshold, high_threshold):
return cv2.Canny(img, low_threshold, high_threshold)

View File

@@ -1,62 +0,0 @@
import requests
from pathlib import Path
import torch
import numpy as np
# from annotator.util import annotator_ckpts_path
from apps.stable_diffusion.src.utils.stencils.openpose.body import Body
from apps.stable_diffusion.src.utils.stencils.openpose.hand import Hand
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
draw_bodypose,
draw_handpose,
handDetect,
)
body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"
class OpenposeDetector:
def __init__(self):
cwd = Path.cwd()
ckpt_path = Path(cwd, "stencil_annotator")
ckpt_path.mkdir(parents=True, exist_ok=True)
body_modelpath = ckpt_path / "body_pose_model.pth"
hand_modelpath = ckpt_path / "hand_pose_model.pth"
if not body_modelpath.is_file():
r = requests.get(body_model_path, allow_redirects=True)
open(body_modelpath, "wb").write(r.content)
if not hand_modelpath.is_file():
r = requests.get(hand_model_path, allow_redirects=True)
open(hand_modelpath, "wb").write(r.content)
self.body_estimation = Body(body_modelpath)
self.hand_estimation = Hand(hand_modelpath)
def __call__(self, oriImg, hand=False):
oriImg = oriImg[:, :, ::-1].copy()
with torch.no_grad():
candidate, subset = self.body_estimation(oriImg)
canvas = np.zeros_like(oriImg)
canvas = draw_bodypose(canvas, candidate, subset)
if hand:
hands_list = handDetect(candidate, subset, oriImg)
all_hand_peaks = []
for x, y, w, is_left in hands_list:
peaks = self.hand_estimation(
oriImg[y : y + w, x : x + w, :]
)
peaks[:, 0] = np.where(
peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x
)
peaks[:, 1] = np.where(
peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y
)
all_hand_peaks.append(peaks)
canvas = draw_handpose(canvas, all_hand_peaks)
return canvas, dict(
candidate=candidate.tolist(), subset=subset.tolist()
)

View File

@@ -1,499 +0,0 @@
import cv2
import numpy as np
import math
from scipy.ndimage.filters import gaussian_filter
import torch
import torch.nn as nn
from collections import OrderedDict
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
make_layers,
transfer,
padRightDownCorner,
)
class BodyPoseModel(nn.Module):
def __init__(self):
super(BodyPoseModel, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv5_5_CPM_L1",
"conv5_5_CPM_L2",
"Mconv7_stage2_L1",
"Mconv7_stage2_L2",
"Mconv7_stage3_L1",
"Mconv7_stage3_L2",
"Mconv7_stage4_L1",
"Mconv7_stage4_L2",
"Mconv7_stage5_L1",
"Mconv7_stage5_L2",
"Mconv7_stage6_L1",
"Mconv7_stage6_L1",
]
blocks = {}
block0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3_CPM", [512, 256, 3, 1, 1]),
("conv4_4_CPM", [256, 128, 3, 1, 1]),
]
)
# Stage 1
block1_1 = OrderedDict(
[
("conv5_1_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L1", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L1", [512, 38, 1, 1, 0]),
]
)
block1_2 = OrderedDict(
[
("conv5_1_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L2", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L2", [512, 19, 1, 1, 0]),
]
)
blocks["block1_1"] = block1_1
blocks["block1_2"] = block1_2
self.model0 = make_layers(block0, no_relu_layers)
# Stages 2 - 6
for i in range(2, 7):
blocks["block%d_1" % i] = OrderedDict(
[
("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]),
]
)
blocks["block%d_2" % i] = OrderedDict(
[
("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_1 = blocks["block1_1"]
self.model2_1 = blocks["block2_1"]
self.model3_1 = blocks["block3_1"]
self.model4_1 = blocks["block4_1"]
self.model5_1 = blocks["block5_1"]
self.model6_1 = blocks["block6_1"]
self.model1_2 = blocks["block1_2"]
self.model2_2 = blocks["block2_2"]
self.model3_2 = blocks["block3_2"]
self.model4_2 = blocks["block4_2"]
self.model5_2 = blocks["block5_2"]
self.model6_2 = blocks["block6_2"]
def forward(self, x):
out1 = self.model0(x)
out1_1 = self.model1_1(out1)
out1_2 = self.model1_2(out1)
out2 = torch.cat([out1_1, out1_2, out1], 1)
out2_1 = self.model2_1(out2)
out2_2 = self.model2_2(out2)
out3 = torch.cat([out2_1, out2_2, out1], 1)
out3_1 = self.model3_1(out3)
out3_2 = self.model3_2(out3)
out4 = torch.cat([out3_1, out3_2, out1], 1)
out4_1 = self.model4_1(out4)
out4_2 = self.model4_2(out4)
out5 = torch.cat([out4_1, out4_2, out1], 1)
out5_1 = self.model5_1(out5)
out5_2 = self.model5_2(out5)
out6 = torch.cat([out5_1, out5_2, out1], 1)
out6_1 = self.model6_1(out6)
out6_2 = self.model6_2(out6)
return out6_1, out6_2
class Body(object):
def __init__(self, model_path):
self.model = BodyPoseModel()
if torch.cuda.is_available():
self.model = self.model.cuda()
model_dict = transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def __call__(self, oriImg):
scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre1 = 0.1
thre2 = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = cv2.resize(
oriImg,
(0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_CUBIC,
)
imageToTest_padded, pad = padRightDownCorner(
imageToTest, stride, padValue
)
im = (
np.transpose(
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
(3, 2, 0, 1),
)
/ 256
- 0.5
)
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
if torch.cuda.is_available():
data = data.cuda()
with torch.no_grad():
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
# extract outputs, resize, and remove padding
heatmap = np.transpose(
np.squeeze(Mconv7_stage6_L2), (1, 2, 0)
) # output 1 is heatmaps
heatmap = cv2.resize(
heatmap,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
heatmap = heatmap[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
heatmap = cv2.resize(
heatmap,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
paf = np.transpose(
np.squeeze(Mconv7_stage6_L1), (1, 2, 0)
) # output 0 is PAFs
paf = cv2.resize(
paf,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
paf = paf[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
paf = cv2.resize(
paf,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
paf_avg += +paf / len(multiplier)
all_peaks = []
peak_counter = 0
for part in range(18):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
map_left = np.zeros(one_heatmap.shape)
map_left[1:, :] = one_heatmap[:-1, :]
map_right = np.zeros(one_heatmap.shape)
map_right[:-1, :] = one_heatmap[1:, :]
map_up = np.zeros(one_heatmap.shape)
map_up[:, 1:] = one_heatmap[:, :-1]
map_down = np.zeros(one_heatmap.shape)
map_down[:, :-1] = one_heatmap[:, 1:]
peaks_binary = np.logical_and.reduce(
(
one_heatmap >= map_left,
one_heatmap >= map_right,
one_heatmap >= map_up,
one_heatmap >= map_down,
one_heatmap > thre1,
)
)
peaks = list(
zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])
) # note reverse
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
peak_id = range(peak_counter, peak_counter + len(peaks))
peaks_with_score_and_id = [
peaks_with_score[i] + (peak_id[i],)
for i in range(len(peak_id))
]
all_peaks.append(peaks_with_score_and_id)
peak_counter += len(peaks)
# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
[3, 17],
[6, 18],
]
# the middle joints heatmap correpondence
mapIdx = [
[31, 32],
[39, 40],
[33, 34],
[35, 36],
[41, 42],
[43, 44],
[19, 20],
[21, 22],
[23, 24],
[25, 26],
[27, 28],
[29, 30],
[47, 48],
[49, 50],
[53, 54],
[51, 52],
[55, 56],
[37, 38],
[45, 46],
]
connection_all = []
special_k = []
mid_num = 10
for k in range(len(mapIdx)):
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
candA = all_peaks[limbSeq[k][0] - 1]
candB = all_peaks[limbSeq[k][1] - 1]
nA = len(candA)
nB = len(candB)
indexA, indexB = limbSeq[k]
if nA != 0 and nB != 0:
connection_candidate = []
for i in range(nA):
for j in range(nB):
vec = np.subtract(candB[j][:2], candA[i][:2])
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
norm = max(0.001, norm)
vec = np.divide(vec, norm)
startend = list(
zip(
np.linspace(
candA[i][0], candB[j][0], num=mid_num
),
np.linspace(
candA[i][1], candB[j][1], num=mid_num
),
)
)
vec_x = np.array(
[
score_mid[
int(round(startend[I][1])),
int(round(startend[I][0])),
0,
]
for I in range(len(startend))
]
)
vec_y = np.array(
[
score_mid[
int(round(startend[I][1])),
int(round(startend[I][0])),
1,
]
for I in range(len(startend))
]
)
score_midpts = np.multiply(
vec_x, vec[0]
) + np.multiply(vec_y, vec[1])
score_with_dist_prior = sum(score_midpts) / len(
score_midpts
) + min(0.5 * oriImg.shape[0] / norm - 1, 0)
criterion1 = len(
np.nonzero(score_midpts > thre2)[0]
) > 0.8 * len(score_midpts)
criterion2 = score_with_dist_prior > 0
if criterion1 and criterion2:
connection_candidate.append(
[
i,
j,
score_with_dist_prior,
score_with_dist_prior
+ candA[i][2]
+ candB[j][2],
]
)
connection_candidate = sorted(
connection_candidate, key=lambda x: x[2], reverse=True
)
connection = np.zeros((0, 5))
for c in range(len(connection_candidate)):
i, j, s = connection_candidate[c][0:3]
if i not in connection[:, 3] and j not in connection[:, 4]:
connection = np.vstack(
[connection, [candA[i][3], candB[j][3], s, i, j]]
)
if len(connection) >= min(nA, nB):
break
connection_all.append(connection)
else:
special_k.append(k)
connection_all.append([])
# last number in each row is the total parts number of that person
# the second last number in each row is the score of the overall configuration
subset = -1 * np.ones((0, 20))
candidate = np.array(
[item for sublist in all_peaks for item in sublist]
)
for k in range(len(mapIdx)):
if k not in special_k:
partAs = connection_all[k][:, 0]
partBs = connection_all[k][:, 1]
indexA, indexB = np.array(limbSeq[k]) - 1
for i in range(len(connection_all[k])): # = 1:size(temp,1)
found = 0
subset_idx = [-1, -1]
for j in range(len(subset)): # 1:size(subset,1):
if (
subset[j][indexA] == partAs[i]
or subset[j][indexB] == partBs[i]
):
subset_idx[found] = j
found += 1
if found == 1:
j = subset_idx[0]
if subset[j][indexB] != partBs[i]:
subset[j][indexB] = partBs[i]
subset[j][-1] += 1
subset[j][-2] += (
candidate[partBs[i].astype(int), 2]
+ connection_all[k][i][2]
)
elif found == 2: # if found 2 and disjoint, merge them
j1, j2 = subset_idx
membership = (
(subset[j1] >= 0).astype(int)
+ (subset[j2] >= 0).astype(int)
)[:-2]
if len(np.nonzero(membership == 2)[0]) == 0: # merge
subset[j1][:-2] += subset[j2][:-2] + 1
subset[j1][-2:] += subset[j2][-2:]
subset[j1][-2] += connection_all[k][i][2]
subset = np.delete(subset, j2, 0)
else: # as like found == 1
subset[j1][indexB] = partBs[i]
subset[j1][-1] += 1
subset[j1][-2] += (
candidate[partBs[i].astype(int), 2]
+ connection_all[k][i][2]
)
# if find no partA in the subset, create a new subset
elif not found and k < 17:
row = -1 * np.ones(20)
row[indexA] = partAs[i]
row[indexB] = partBs[i]
row[-1] = 2
row[-2] = (
sum(
candidate[
connection_all[k][i, :2].astype(int), 2
]
)
+ connection_all[k][i][2]
)
subset = np.vstack([subset, row])
# delete some rows of subset which has few parts occur
deleteIdx = []
for i in range(len(subset)):
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
deleteIdx.append(i)
subset = np.delete(subset, deleteIdx, axis=0)
# candidate: x, y, score, id
return candidate, subset

View File

@@ -1,205 +0,0 @@
import cv2
import numpy as np
from scipy.ndimage.filters import gaussian_filter
import torch
import torch.nn as nn
from skimage.measure import label
from collections import OrderedDict
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
make_layers,
transfer,
padRightDownCorner,
npmax,
)
class HandPoseModel(nn.Module):
def __init__(self):
super(HandPoseModel, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv6_2_CPM",
"Mconv7_stage2",
"Mconv7_stage3",
"Mconv7_stage4",
"Mconv7_stage5",
"Mconv7_stage6",
]
# stage 1
block1_0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3", [512, 512, 3, 1, 1]),
("conv4_4", [512, 512, 3, 1, 1]),
("conv5_1", [512, 512, 3, 1, 1]),
("conv5_2", [512, 512, 3, 1, 1]),
("conv5_3_CPM", [512, 128, 3, 1, 1]),
]
)
block1_1 = OrderedDict(
[
("conv6_1_CPM", [128, 512, 1, 1, 0]),
("conv6_2_CPM", [512, 22, 1, 1, 0]),
]
)
blocks = {}
blocks["block1_0"] = block1_0
blocks["block1_1"] = block1_1
# stage 2-6
for i in range(2, 7):
blocks["block%d" % i] = OrderedDict(
[
("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]),
("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_0 = blocks["block1_0"]
self.model1_1 = blocks["block1_1"]
self.model2 = blocks["block2"]
self.model3 = blocks["block3"]
self.model4 = blocks["block4"]
self.model5 = blocks["block5"]
self.model6 = blocks["block6"]
def forward(self, x):
out1_0 = self.model1_0(x)
out1_1 = self.model1_1(out1_0)
concat_stage2 = torch.cat([out1_1, out1_0], 1)
out_stage2 = self.model2(concat_stage2)
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
out_stage3 = self.model3(concat_stage3)
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
out_stage4 = self.model4(concat_stage4)
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
out_stage5 = self.model5(concat_stage5)
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
out_stage6 = self.model6(concat_stage6)
return out_stage6
class Hand(object):
def __init__(self, model_path):
self.model = HandPoseModel()
if torch.cuda.is_available():
self.model = self.model.cuda()
model_dict = transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def __call__(self, oriImg):
scale_search = [0.5, 1.0, 1.5, 2.0]
# scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
# paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = cv2.resize(
oriImg,
(0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_CUBIC,
)
imageToTest_padded, pad = padRightDownCorner(
imageToTest, stride, padValue
)
im = (
np.transpose(
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
(3, 2, 0, 1),
)
/ 256
- 0.5
)
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
if torch.cuda.is_available():
data = data.cuda()
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
with torch.no_grad():
output = self.model(data).cpu().numpy()
# output = self.model(data).numpy()q
# extract outputs, resize, and remove padding
heatmap = np.transpose(
np.squeeze(output), (1, 2, 0)
) # output 1 is heatmaps
heatmap = cv2.resize(
heatmap,
(0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC,
)
heatmap = heatmap[
: imageToTest_padded.shape[0] - pad[2],
: imageToTest_padded.shape[1] - pad[3],
:,
]
heatmap = cv2.resize(
heatmap,
(oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC,
)
heatmap_avg += heatmap / len(multiplier)
all_peaks = []
for part in range(21):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
# 全部小于阈值
if np.sum(binary) == 0:
all_peaks.append([0, 0])
continue
label_img, label_numbers = label(
binary, return_num=True, connectivity=binary.ndim
)
max_index = (
np.argmax(
[
np.sum(map_ori[label_img == i])
for i in range(1, label_numbers + 1)
]
)
+ 1
)
label_img[label_img != max_index] = 0
map_ori[label_img == 0] = 0
y, x = npmax(map_ori)
all_peaks.append([x, y])
return np.array(all_peaks)

View File

@@ -1,272 +0,0 @@
import math
import numpy as np
import matplotlib
import cv2
from collections import OrderedDict
import torch.nn as nn
def make_layers(block, no_relu_layers):
layers = []
for layer_name, v in block.items():
if "pool" in layer_name:
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
layers.append((layer_name, layer))
else:
conv2d = nn.Conv2d(
in_channels=v[0],
out_channels=v[1],
kernel_size=v[2],
stride=v[3],
padding=v[4],
)
layers.append((layer_name, conv2d))
if layer_name not in no_relu_layers:
layers.append(("relu_" + layer_name, nn.ReLU(inplace=True)))
return nn.Sequential(OrderedDict(layers))
def padRightDownCorner(img, stride, padValue):
h = img.shape[0]
w = img.shape[1]
pad = 4 * [None]
pad[0] = 0 # up
pad[1] = 0 # left
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
img_padded = img
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
img_padded = np.concatenate((pad_up, img_padded), axis=0)
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
img_padded = np.concatenate((pad_left, img_padded), axis=1)
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
img_padded = np.concatenate((img_padded, pad_down), axis=0)
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
img_padded = np.concatenate((img_padded, pad_right), axis=1)
return img_padded, pad
# transfer caffe model to pytorch which will match the layer name
def transfer(model, model_weights):
transfered_model_weights = {}
for weights_name in model.state_dict().keys():
transfered_model_weights[weights_name] = model_weights[
".".join(weights_name.split(".")[1:])
]
return transfered_model_weights
# draw the body keypoint and lims
def draw_bodypose(canvas, candidate, subset):
stickwidth = 4
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
[3, 17],
[6, 18],
]
colors = [
[255, 0, 0],
[255, 85, 0],
[255, 170, 0],
[255, 255, 0],
[170, 255, 0],
[85, 255, 0],
[0, 255, 0],
[0, 255, 85],
[0, 255, 170],
[0, 255, 255],
[0, 170, 255],
[0, 85, 255],
[0, 0, 255],
[85, 0, 255],
[170, 0, 255],
[255, 0, 255],
[255, 0, 170],
[255, 0, 85],
]
for i in range(18):
for n in range(len(subset)):
index = int(subset[n][i])
if index == -1:
continue
x, y = candidate[index][0:2]
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
for i in range(17):
for n in range(len(subset)):
index = subset[n][np.array(limbSeq[i]) - 1]
if -1 in index:
continue
cur_canvas = canvas.copy()
Y = candidate[index.astype(int), 0]
X = candidate[index.astype(int), 1]
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly(
(int(mY), int(mX)),
(int(length / 2), stickwidth),
int(angle),
0,
360,
1,
)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
return canvas
# image drawed by opencv is not good.
def draw_handpose(canvas, all_hand_peaks, show_number=False):
edges = [
[0, 1],
[1, 2],
[2, 3],
[3, 4],
[0, 5],
[5, 6],
[6, 7],
[7, 8],
[0, 9],
[9, 10],
[10, 11],
[11, 12],
[0, 13],
[13, 14],
[14, 15],
[15, 16],
[0, 17],
[17, 18],
[18, 19],
[19, 20],
]
for peaks in all_hand_peaks:
for ie, e in enumerate(edges):
if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
x1, y1 = peaks[e[0]]
x2, y2 = peaks[e[1]]
cv2.line(
canvas,
(x1, y1),
(x2, y2),
matplotlib.colors.hsv_to_rgb(
[ie / float(len(edges)), 1.0, 1.0]
)
* 255,
thickness=2,
)
for i, keyponit in enumerate(peaks):
x, y = keyponit
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
if show_number:
cv2.putText(
canvas,
str(i),
(x, y),
cv2.FONT_HERSHEY_SIMPLEX,
0.3,
(0, 0, 0),
lineType=cv2.LINE_AA,
)
return canvas
# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
# right hand: wrist 4, elbow 3, shoulder 2
# left hand: wrist 7, elbow 6, shoulder 5
ratioWristElbow = 0.33
detect_result = []
image_height, image_width = oriImg.shape[0:2]
for person in subset.astype(int):
# if any of three not detected
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
if not (has_left or has_right):
continue
hands = []
# left hand
if has_left:
left_shoulder_index, left_elbow_index, left_wrist_index = person[
[5, 6, 7]
]
x1, y1 = candidate[left_shoulder_index][:2]
x2, y2 = candidate[left_elbow_index][:2]
x3, y3 = candidate[left_wrist_index][:2]
hands.append([x1, y1, x2, y2, x3, y3, True])
# right hand
if has_right:
(
right_shoulder_index,
right_elbow_index,
right_wrist_index,
) = person[[2, 3, 4]]
x1, y1 = candidate[right_shoulder_index][:2]
x2, y2 = candidate[right_elbow_index][:2]
x3, y3 = candidate[right_wrist_index][:2]
hands.append([x1, y1, x2, y2, x3, y3, False])
for x1, y1, x2, y2, x3, y3, is_left in hands:
x = x3 + ratioWristElbow * (x3 - x2)
y = y3 + ratioWristElbow * (y3 - y2)
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
# x-y refers to the center --> offset to topLeft point
x -= width / 2
y -= width / 2 # width = height
# overflow the image
if x < 0:
x = 0
if y < 0:
y = 0
width1 = width
width2 = width
if x + width > image_width:
width1 = image_width - x
if y + width > image_height:
width2 = image_height - y
width = min(width1, width2)
# the max hand box value is 20 pixels
if width >= 20:
detect_result.append([int(x), int(y), int(width), is_left])
"""
return value: [[x, y, w, True if left hand else False]].
width=height since the network require squared input.
x, y is the coordinate of top left
"""
return detect_result
# get max index of 2d array
def npmax(array):
arrayindex = array.argmax(1)
arrayvalue = array.max(1)
i = arrayvalue.argmax()
j = arrayindex[i]
return (i,)

View File

@@ -1,186 +0,0 @@
import numpy as np
from PIL import Image
import torch
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
)
stencil = {}
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y
def controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt=1
):
channels = 3
if isinstance(controlnet_hint, torch.Tensor):
# torch.Tensor: acceptble shape are any of chw, bchw(b==1) or bchw(b==num_images_per_prompt)
shape_chw = (channels, height, width)
shape_bchw = (1, channels, height, width)
shape_nchw = (num_images_per_prompt, channels, height, width)
if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]:
controlnet_hint = controlnet_hint.to(
dtype=dtype, device=torch.device("cpu")
)
if controlnet_hint.shape != shape_nchw:
controlnet_hint = controlnet_hint.repeat(
num_images_per_prompt, 1, 1, 1
)
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
)
elif isinstance(controlnet_hint, np.ndarray):
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot)
# hwc is opencv compatible image format. Color channel must be BGR Format.
if controlnet_hint.shape == (height, width):
controlnet_hint = np.repeat(
controlnet_hint[:, :, np.newaxis], channels, axis=2
) # hw -> hwc(c==3)
shape_hwc = (height, width, channels)
shape_bhwc = (1, height, width, channels)
shape_nhwc = (num_images_per_prompt, height, width, channels)
if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
controlnet_hint = torch.from_numpy(controlnet_hint.copy())
controlnet_hint = controlnet_hint.to(
dtype=dtype, device=torch.device("cpu")
)
controlnet_hint /= 255.0
if controlnet_hint.shape != shape_nhwc:
controlnet_hint = controlnet_hint.repeat(
num_images_per_prompt, 1, 1, 1
)
controlnet_hint = controlnet_hint.permute(
0, 3, 1, 2
) # b h w c -> b c h w
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
+ f"({height}, {width}, {channels}), "
+ f"(1, {height}, {width}, {channels}) or "
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
)
elif isinstance(controlnet_hint, Image.Image):
if controlnet_hint.size == (width, height):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
controlnet_hint = np.array(controlnet_hint) # to numpy
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
return controlnet_hint_shaping(
controlnet_hint, height, width, num_images_per_prompt
)
else:
raise ValueError(
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
)
else:
raise ValueError(
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
)
def controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
):
controlnet_hint = None
match use_stencil:
case "canny":
print("Detecting edge with canny")
controlnet_hint = hint_canny(image)
case "openpose":
print("Detecting human pose")
controlnet_hint = hint_openpose(image)
case "scribble":
print("Working with scribble")
controlnet_hint = hint_scribble(image)
case _:
return None
controlnet_hint = controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt
)
return controlnet_hint
stencil_to_model_id_map = {
"canny": "lllyasviel/control_v11p_sd15_canny",
"depth": "lllyasviel/control_v11p_sd15_depth",
"hed": "lllyasviel/sd-controlnet-hed",
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
"normal": "lllyasviel/control_v11p_sd15_normalbae",
"openpose": "lllyasviel/control_v11p_sd15_openpose",
"scribble": "lllyasviel/control_v11p_sd15_scribble",
"seg": "lllyasviel/control_v11p_sd15_seg",
}
def get_stencil_model_id(use_stencil):
if use_stencil in stencil_to_model_id_map:
return stencil_to_model_id_map[use_stencil]
return None
# Stencil 1. Canny
def hint_canny(
image: Image.Image,
low_threshold=100,
high_threshold=200,
):
with torch.no_grad():
input_image = np.array(image)
if not "canny" in stencil:
stencil["canny"] = CannyDetector()
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
detected_map = HWC3(detected_map)
return detected_map
# Stencil 2. OpenPose.
def hint_openpose(
image: Image.Image,
):
with torch.no_grad():
input_image = np.array(image)
if not "openpose" in stencil:
stencil["openpose"] = OpenposeDetector()
detected_map, _ = stencil["openpose"](input_image)
detected_map = HWC3(detected_map)
return detected_map
# Stencil 3. Scribble.
def hint_scribble(image: Image.Image):
with torch.no_grad():
input_image = np.array(image)
detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
return detected_map

View File

@@ -1,17 +1,9 @@
import os
import gc
import json
import re
from PIL import PngImagePlugin
from PIL import Image
from datetime import datetime as dt
from csv import DictWriter
from pathlib import Path
import numpy as np
from random import randint
import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
@@ -22,40 +14,26 @@ from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
import sys, functools, operator
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
create_vae_diffusers_config,
convert_ldm_vae_checkpoint,
load_pipeline_from_original_stable_diffusion_ckpt,
)
import requests
from io import BytesIO
from omegaconf import OmegaConf
def get_extended_name(model_name):
device = args.device.split("://", 1)[0]
extended_name = "{}_{}".format(model_name, device)
return extended_name
def get_vmfb_path_name(model_name):
vmfb_path = os.path.join(os.getcwd(), model_name + ".vmfb")
return vmfb_path
def _load_vmfb(shark_module, vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
return [vmfb_path, extended_name]
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
vmfb_path = get_vmfb_path_name(model_name)
[vmfb_path, extended_name] = get_vmfb_path_name(model_name)
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
@@ -69,7 +47,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
)
)
path = shark_module.save_module(
os.getcwd(), model_name, extra_args
os.getcwd(), extended_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -95,7 +73,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="tm_tensor"
mlir_model, device=args.device, mlir_dialect="linalg"
)
return _compile_module(shark_module, model_name, extra_args)
@@ -104,72 +82,42 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
def compile_through_fx(
model,
inputs,
extended_model_name,
model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
debug=False,
generate_vmfb=True,
extra_args=[],
base_model_id=None,
model_name=None,
precision=None,
return_mlir=False,
):
if not return_mlir and model_name is not None:
vmfb_path = get_vmfb_path_name(extended_model_name)
if os.path.isfile(vmfb_path):
shark_module = SharkInference(mlir_module=None, device=args.device)
return (
_load_vmfb(shark_module, vmfb_path, model_name, precision),
None,
)
from shark.parser import shark_args
if "cuda" in args.device:
shark_args.enable_tf32 = True
(
mlir_module,
func_name,
) = import_with_fx(
model=model,
inputs=inputs,
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
)
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
if "unet" in model_name.split("_")[0]:
args.annotation_model = "unet"
mlir_module = sd_model_annotation(
mlir_module, extended_model_name, base_model_id
)
mlir_module = sd_model_annotation(mlir_module, model_name)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="tm_tensor",
mlir_dialect="linalg",
)
if generate_vmfb:
return (
_compile_module(shark_module, extended_model_name, extra_args),
mlir_module,
)
del mlir_module
gc.collect()
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--device_allocator=caching",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
@@ -284,43 +232,24 @@ def set_init_device_flags():
args.max_length = 64
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
if (
args.precision != "fp16"
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
args.hf_model_id == "prompthero/openjourney"
or args.ckpt_loc != ""
or args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif base_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
]:
args.use_tuned = False
elif "vulkan" in args.device and not any(
x in args.iree_vulkan_target_triple for x in ["rdna2", "rdna3"]
elif (
"vulkan" in args.device
and "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
@@ -329,22 +258,8 @@ def set_init_device_flags():
]:
args.use_tuned = False
elif (
args.height == 768
and args.width == 768
and (
base_model_id
not in [
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
]
or "rdna3" not in args.iree_vulkan_target_triple
)
):
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
print(f"Using tuned models for {args.hf_model_id}/fp16/{args.device}.")
else:
print("Tuned models are currently not supported for this setting.")
@@ -366,27 +281,6 @@ def set_init_device_flags():
elif args.height != 512 or args.width != 512 or args.batch_size != 1:
args.import_mlir = True
elif args.use_tuned and args.hf_model_id in [
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"stabilityai/stable-diffusion-2-1",
]:
args.import_mlir = True
elif (
args.use_tuned
and "vulkan" in args.device
and "rdna2" in args.iree_vulkan_target_triple
):
args.import_mlir = True
elif (
args.use_tuned
and "cuda" in args.device
and get_cuda_sm_cc() == "sm_89"
):
args.import_mlir = True
# Utility to get list of devices available.
def get_available_devices():
@@ -412,7 +306,7 @@ def get_available_devices():
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("device => cpu")
available_devices.append("cpu")
return available_devices
@@ -461,22 +355,17 @@ def get_opt_flags(model, precision="fp16"):
return iree_flags
def get_path_stem(path):
path = Path(path)
return path.stem
def get_path_to_diffusers_checkpoint(custom_weights):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = os.path.join("diffusers", path.stem)
diffusers_directory_name = path.stem
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers
def preprocessCKPT(custom_weights, is_inpaint=False):
def preprocessCKPT(custom_weights):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
@@ -497,140 +386,46 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
)
pipe.save_pretrained(path_to_diffusers)
print("Loading complete")
def convert_original_vae(vae_checkpoint):
vae_state_dict = {}
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
vae_state_dict, vae_config
)
return converted_vae_checkpoint
def load_vmfb(vmfb_path, model, precision):
model = "vae" if "base_vae" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module = SharkInference(mlir_module=None, device=args.device)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
def processLoRA(model, use_lora, splitting_prefix):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
# This utility returns vmfbs of Clip, Unet and Vae, in case all three of them
# are present; deletes them otherwise.
def fetch_or_delete_vmfbs(basic_model_name, use_base_vae, precision="fp32"):
model_name = ["clip", "unet", "base_vae" if use_base_vae else "vae"]
vmfb_path = [
get_vmfb_path_name(model + basic_model_name)[0] for model in model_name
]
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = functools.reduce(operator.__and__, vmfb_present)
compiled_models = [None] * 3
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(len(vmfb_path)):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
else:
state_dict = torch.load(use_lora)
alpha = 0.75
visited = []
# directly update weight in model
process_unet = "te" not in splitting_prefix
for key in state_dict:
if ".alpha" in key or key in visited:
continue
curr_layer = model
if ("text" not in key and process_unet) or (
"text" in key and not process_unet
):
layer_infos = (
key.split(".")[0].split(splitting_prefix)[-1].split("_")
for i in range(len(vmfb_path)):
compiled_models[i] = load_vmfb(
vmfb_path[i], model_name[i], precision
)
else:
continue
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = (
state_dict[pair_keys[0]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
weight_down = (
state_dict[pair_keys[1]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
curr_layer.weight.data += alpha * torch.mm(
weight_up, weight_down
).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
return model
def update_lora_weight_for_unet(unet, use_lora):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume if it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
return unet
main_file_name = get_path_stem(use_lora)
if ".bin" in use_lora:
main_file_name += ".bin"
elif ".safetensors" in use_lora:
main_file_name += ".safetensors"
elif ".pt" in use_lora:
main_file_name += ".pt"
else:
sys.exit("Only .bin and .safetensors format for LoRA is supported")
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_")
def update_lora_weight(model, use_lora, model_name):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora)
try:
return processLoRA(model, use_lora, "lora_te_")
except:
return None
return compiled_models
# `fetch_and_update_base_model_id` is a resource utility function which
@@ -664,153 +459,3 @@ def sanitize_seed(seed):
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# clear all the cached objects to recompile cleanly.
def clear_all():
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
if os.path.exists(yaml):
os.remove(yaml)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(
os.path.join(home, ".local/shark_tank"), ignore_errors=True
)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info={}):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
)
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = Path(os.path.basename(args.ckpt_loc)).stem
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
)
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SEED": img_seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
}
new_entry.update(extra_info)
with open(csv_path, "a", encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
dictwriter_obj.writerow(new_entry)
csv_obj.close()
if args.save_metadata_to_json:
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
return text_output
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
width, height = image.size
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 128:
n_size = 128
if width == min_size:
width = n_size
height = n_size / aspect_ratio
else:
height = n_size
width = n_size * aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
min_size = min(width, height)
if min_size > 768:
n_size = 768
if width == min_size:
height = n_size
width = n_size * aspect_ratio
else:
width = n_size
height = n_size / aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
new_image = image.resize((n_width, n_height))
return new_image, n_width, n_height

View File

@@ -0,0 +1,70 @@
# Stable Diffusion optimized for AMD RDNA2/RDNA3 GPUs
Before you start, please be aware that this is beta software that relies on a special AMD driver. Like all StableDiffusion GUIs published so far, you need some technical expertise to set it up. We apologize in advance if you bump into issues. If that happens, please don't hesitate to ask our Discord community for help! Please be assured that we (Nod and AMD) are working hard to improve the user experience in coming months.
If it works well for you, please "star" the following GitHub projects... this is one of the best ways to help and spread the word!
* https://github.com/nod-ai/SHARK
* https://github.com/iree-org/iree
## Install this specific AMD Drivers (AMD latest may not have all the fixes).
### AMD KB Drivers for RDNA2 and RDNA3:
*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*
First, for RDNA2 users, download this special driver in a folder of your choice. We recommend you keep the installation files around, since you may need to re-install it later, if Windows Update decides to overwrite it:
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree
For RDNA3, the latest driver 23.1.2 supports MLIR/IREE as well: https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-1-2-kb
KNOWN ISSUES with this special AMD driver:
* `Windows Update` may (depending how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work, then a few days later, it slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix this problem, please check the installed driver version, and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
* Some people using this special driver experience mouse pointer accuracy issues, especially if using a larger-than-default mouse pointer. The clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".
## Installation
Download the latest Windows SHARK SD binary [492 here](https://github.com/nod-ai/SHARK/releases/download/20230203.492/shark_sd_20230203_492.exe) in a folder of your choice. If you want nighly builds, you can look for them on the GitHub releases page.
Notes:
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR which can be outdated if you run a new EXE from the same folder. You can use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you:
* clear all the local artifacts with `--clear_all` OR
* clear the Vulkan shader cache: For Windows users this can be done by clearing the contents of `C:\Users\%username%\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
* clear the `huggingface` cache. In Windows, this is `C:\Users\%username%\.cache\huggingface`.
## Running
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE to start the web browser)
* The first run may take about 10-15 minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* If successful, you will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/?__theme=dark.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment. The application should stop.
* Please make sure to do the above step before you attempt to update the EXE to a new version.
# Results
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
Here are some samples generated:
![tajmahal, snow, sunflowers, oil on canvas_0](https://user-images.githubusercontent.com/74956/204934186-141f7e43-6eb2-4e89-a99c-4704d20444b3.jpg)
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
The output on a 7900XTX would like:
```shell
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
```
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.

View File

@@ -0,0 +1,209 @@
/* Overwrite the Gradio default theme with their .dark theme declarations */
:root {
--color-focus-primary: var(--color-grey-700);
--color-focus-secondary: var(--color-grey-600);
--color-focus-ring: rgb(55 65 81);
--color-background-primary: var(--color-grey-950);
--color-background-secondary: var(--color-grey-900);
--color-background-tertiary: var(--color-grey-800);
--color-text-body: var(--color-grey-100);
--color-text-label: var(--color-grey-200);
--color-text-placeholder: var(--color-grey);
--color-text-subdued: var(--color-grey-400);
--color-text-link-base: var(--color-blue-500);
--color-text-link-hover: var(--color-blue-400);
--color-text-link-visited: var(--color-blue-600);
--color-text-link-active: var(--color-blue-500);
--color-text-code-background: var(--color-grey-800);
--color-text-code-border: color.border-primary;
--color-border-primary: var(--color-grey-700);
--color-border-secondary: var(--color-grey-600);
--color-border-highlight: var(--color-accent-base);
--color-accent-base: var(--color-orange-500);
--color-accent-light: var(--color-orange-300);
--color-accent-dark: var(--color-orange-700);
--color-functional-error-base: var(--color-red-400);
--color-functional-error-subdued: var(--color-red-300);
--color-functional-error-background: var(--color-background-primary);
--color-functional-info-base: var(--color-yellow);
--color-functional-info-subdued: var(--color-yellow-300);
--color-functional-success-base: var(--color-green);
--color-functional-success-subdued: var(--color-green-300);
--shadow-spread: 2px;
--api-background: linear-gradient(to bottom, rgba(255, 216, 180, .05), transparent);
--api-pill-background: var(--color-orange-400);
--api-pill-border: var(--color-orange-600);
--api-pill-text: var(--color-orange-900);
--block-border-color: var(--color-border-primary);
--block-background: var(--color-background-tertiary);
--uploadable-border-color-hover: var(--color-border-primary);
--uploadable-border-color-loaded: var(--color-functional-success);
--uploadable-text-color: var(--color-text-subdued);
--block_label-border-color: var(--color-border-primary);
--block_label-icon-color: var(--color-text-label);
--block_label-shadow: var(--shadow-drop);
--block_label-background: var(--color-background-secondary);
--icon_button-icon-color-base: var(--color-text-label);
--icon_button-icon-color-hover: var(--color-text-label);
--icon_button-background-base: var(--color-background-primary);
--icon_button-background-hover: var(--color-background-primary);
--icon_button-border-color-base: var(--color-background-primary);
--icon_button-border-color-hover: var(--color-border-secondary);
--input-text-color: var(--color-text-body);
--input-border-color-base: var(--color-border-primary);
--input-border-color-hover: var(--color-border-primary);
--input-border-color-focus: var(--color-border-primary);
--input-background-base: var(--color-background-tertiary);
--input-background-hover: var(--color-background-tertiary);
--input-background-focus: var(--color-background-tertiary);
--input-shadow: var(--shadow-inset);
--checkbox-border-color-base: var(--color-border-primary);
--checkbox-border-color-hover: var(--color-focus-primary);
--checkbox-border-color-focus: var(--color-blue-500);
--checkbox-background-base: var(--color-background-primary);
--checkbox-background-hover: var(--color-background-primary);
--checkbox-background-focus: var(--color-background-primary);
--checkbox-background-selected: var(--color-blue-600);
--checkbox-label-border-color-base: var(--color-border-primary);
--checkbox-label-border-color-hover: var(--color-border-primary);
--checkbox-label-border-color-focus: var(--color-border-secondary);
--checkbox-label-background-base: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--checkbox-label-background-hover: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--checkbox-label-background-focus: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--form-seperator-color: var(--color-border-primary);
--button-primary-border-color-base: var(--color-orange-600);
--button-primary-border-color-hover: var(--color-orange-600);
--button-primary-border-color-focus: var(--color-orange-600);
--button-primary-text-color-base: white;
--button-primary-text-color-hover: white;
--button-primary-text-color-focus: white;
--button-primary-background-base: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-700));
--button-primary-background-hover: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
--button-primary-background-focus: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
--button-secondary-border-color-base: var(--color-grey-600);
--button-secondary-border-color-hover: var(--color-grey-600);
--button-secondary-border-color-focus: var(--color-grey-600);
--button-secondary-text-color-base: white;
--button-secondary-text-color-hover: white;
--button-secondary-text-color-focus: white;
--button-secondary-background-base: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-700));
--button-secondary-background-hover: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
--button-secondary-background-focus: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
--button-cancel-border-color-base: var(--color-red-600);
--button-cancel-border-color-hover: var(--color-red-600);
--button-cancel-border-color-focus: var(--color-red-600);
--button-cancel-text-color-base: white;
--button-cancel-text-color-hover: white;
--button-cancel-text-color-focus: white;
--button-cancel-background-base: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-700));
--button-cancel-background-focus: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
--button-cancel-background-hover: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
--button-plain-border-color-base: var(--color-grey-600);
--button-plain-border-color-hover: var(--color-grey-500);
--button-plain-border-color-focus: var(--color-grey-500);
--button-plain-text-color-base: var(--color-text-body);
--button-plain-text-color-hover: var(--color-text-body);
--button-plain-text-color-focus: var(--color-text-body);
--button-plain-background-base: var(--color-grey-700);
--button-plain-background-hover: var(--color-grey-700);
--button-plain-background-focus: var(--color-grey-700);
--gallery-label-background-base: var(--color-grey-50);
--gallery-label-background-hover: var(--color-grey-50);
--gallery-label-border-color-base: var(--color-border-primary);
--gallery-label-border-color-hover: var(--color-border-primary);
--gallery-thumb-background-base: var(--color-grey-900);
--gallery-thumb-background-hover: var(--color-grey-900);
--gallery-thumb-border-color-base: var(--color-border-primary);
--gallery-thumb-border-color-hover: var(--color-accent-base);
--gallery-thumb-border-color-focus: var(--color-blue-500);
--gallery-thumb-border-color-selected: var(--color-accent-base);
--chatbot-border-border-color-base: transparent;
--chatbot-border-border-color-latest: transparent;
--chatbot-user-background-base: ;
--chatbot-user-background-latest: ;
--chatbot-user-text-color-base: white;
--chatbot-user-text-color-latest: white;
--chatbot-bot-background-base: ;
--chatbot-bot-background-latest: ;
--chatbot-bot-text-color-base: white;
--chatbot-bot-text-color-latest: white;
--label-gradient-from: var(--color-orange-400);
--label-gradient-to: var(--color-orange-600);
--table-odd-background: var(--color-grey-900);
--table-even-background: var(--color-grey-950);
--table-background-edit: transparent;
--dataset-gallery-background-base: var(--color-background-primary);
--dataset-gallery-background-hover: var(--color-grey-800);
--dataset-dataframe-border-base: var(--color-border-primary);
--dataset-dataframe-border-hover: var(--color-border-secondary);
--dataset-table-background-base: transparent;
--dataset-table-background-hover: var(--color-grey-700);
--dataset-table-border-base: var(--color-grey-800);
--dataset-table-border-hover: var(--color-grey-800);
}
/* SHARK theme customization */
.gradio-container {
background-color: var(--color-background-primary);
}
.container {
background-color: black !important;
padding-top: 20px !important;
}
#ui_title {
padding: 10px !important;
}
#top_logo {
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#demo_title {
background-color: var(--color-background-primary);
border-radius: 0 !important;
border: 0;
padding-top: 15px;
padding-bottom: 0px;
width: 350px !important;
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea {
background-color: var(--color-background-primary) !important;
}
#prompt_examples {
margin: 0 !important;
}
#prompt_examples svg {
display: none !important;
}
#ui_body {
background-color: var(--color-background-secondary) !important;
padding: 10px !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}

View File

@@ -1,313 +1,264 @@
from multiprocessing import Process, freeze_support
import os
import sys
import transformers
from apps.stable_diffusion.src import args, clear_all
import apps.stable_diffusion.web.utils.global_obj as global_obj
from pathlib import Path
import glob
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
if args.clear_all:
clear_all()
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# getting screen width and height of display
width = window.winfo_screenwidth()
height = window.winfo_screenheight()
webview.create_window(
"SHARK AI Studio", url=address, width=width, height=height
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
webview.start(private_mode=False)
return os.path.join(base_path, relative_path)
if __name__ == "__main__":
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.ui.split(","):
from apps.stable_diffusion.web.ui import (
txt2img_api,
img2img_api,
upscaler_api,
inpaint_api,
)
from fastapi import FastAPI, APIRouter
import uvicorn
import gradio as gr
from PIL import Image
from apps.stable_diffusion.src import (
prompt_examples,
args,
get_available_devices,
)
from apps.stable_diffusion.scripts import txt2img_inf
# init global sd pipeline and config
global_obj._init()
nodlogo_loc = resource_path("logos/nod-logo.png")
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
app = FastAPI()
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route(
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
# )
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
app.include_router(APIRouter())
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
sys.exit(0)
import gradio as gr
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
)
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
demo_css = resource_path("css/sd_dark_theme.css")
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create custom models folders if they don't exist
create_custom_models_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
logo2 = Image.open(sdlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
with gr.Column(scale=5, elem_id="demo_title_outer"):
gr.Image(
value=logo2,
show_label=False,
interactive=False,
elem_id="demo_title",
).style(width=150, height=100)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
ckpt_path = (
Path(args.ckpt_dir)
if args.ckpt_dir
else Path(Path.cwd(), "models")
)
ckpt_path.mkdir(parents=True, exist_ok=True)
types = (
"*.ckpt",
"*.safetensors",
) # the tuple of file types
ckpt_files = ["None"]
for extn in types:
files = glob.glob(os.path.join(ckpt_path, extn))
ckpt_files.extend(files)
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {ckpt_path})",
value="None",
choices=ckpt_files
+ [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
],
)
hf_model_id = gr.Textbox(
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
)
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
model_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
modelmanager_sendto_inpaint,
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
stablelm_chat,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value="cyberpunk forest by Salvador Dali",
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="trees, green",
lines=1,
elem_id="prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
label="Scheduler",
value="SharkEulerDiscrete",
choices=[
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
],
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=True,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=False,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384, 786, value=512, step=8, label="Height"
)
width = gr.Slider(
384, 786, value=512, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=64,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=50, step=1, label="Steps"
)
guidance_scale = gr.Slider(
0,
50,
value=7.5,
step=0.1,
label="CFG Scale",
)
with gr.Row():
batch_count = gr.Slider(
1,
10,
value=1,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=1,
step=1,
label="Batch Size",
interactive=True,
)
with gr.Row():
seed = gr.Number(value=-1, precision=0, label="Seed")
available_devices = get_available_devices()
device = gr.Dropdown(
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => Math.floor(Math.random() * 4294967295)",
)
stable_diffusion = gr.Button("Generate Image")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
# init global sd pipeline and config
global_obj._init()
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
with gr.Column(scale=1, min_width=600):
with gr.Group():
gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2], height="auto")
std_output = gr.Textbox(
value="Nothing to show.",
lines=4,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
kwargs = dict(
fn=txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
prompt.submit(**kwargs)
stable_diffusion.click(**kwargs)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
img2img_web.render()
with gr.TabItem(label="Inpainting", id=2):
inpaint_web.render()
with gr.TabItem(label="Outpainting", id=3):
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="Model Manager", id=5):
model_web.render()
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
stablelm_chat.render()
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
lora_train_web.render()
register_button_click(
txt2img_sendto_img2img,
1,
[txt2img_gallery],
[img2img_init_image, tabs],
)
register_button_click(
txt2img_sendto_inpaint,
2,
[txt2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_outpaint,
3,
[txt2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_upscaler,
4,
[txt2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
img2img_sendto_inpaint,
2,
[img2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_outpaint,
3,
[img2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_upscaler,
4,
[img2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
inpaint_sendto_img2img,
1,
[inpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
inpaint_sendto_outpaint,
3,
[inpaint_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
inpaint_sendto_upscaler,
4,
[inpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
outpaint_sendto_img2img,
1,
[outpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
outpaint_sendto_inpaint,
2,
[outpaint_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
outpaint_sendto_upscaler,
4,
[outpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
upscaler_sendto_img2img,
1,
[upscaler_gallery],
[img2img_init_image, tabs],
)
register_button_click(
upscaler_sendto_inpaint,
2,
[upscaler_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
upscaler_sendto_outpaint,
3,
[upscaler_gallery],
[outpaint_init_image, tabs],
)
register_modelmanager_button(
modelmanager_sendto_txt2img,
0,
[hf_models],
[txt2img_custom_model, txt2img_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_img2img,
1,
[hf_models],
[img2img_custom_model, img2img_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_inpaint,
2,
[hf_models],
[inpaint_custom_model, inpaint_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_outpaint,
3,
[hf_models],
[outpaint_custom_model, outpaint_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_upscaler,
4,
[hf_models],
[upscaler_custom_model, upscaler_hf_model_id, tabs],
)
sd_web.queue()
if args.ui == "app":
t = Process(
target=launch_app, args=[f"http://localhost:{args.server_port}"]
)
t.start()
sd_web.launch(
share=args.share,
inbrowser=args.ui == "web",
server_name="0.0.0.0",
server_port=args.server_port,
)
shark_web.queue()
shark_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

View File

Before

Width:  |  Height:  |  Size: 10 KiB

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

View File

@@ -1,71 +0,0 @@
from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_inf,
txt2img_api,
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_api,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_inf,
inpaint_api,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_inf,
outpaint_api,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_inf,
upscaler_api,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
)
from apps.stable_diffusion.web.ui.model_manager import (
model_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
modelmanager_sendto_inpaint,
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat

View File

@@ -1,232 +0,0 @@
/*
Apply Gradio dark theme to the default Gradio theme.
Procedure to upgrade the dark theme:
- Using your browser, visit http://localhost:8080/?__theme=dark
- Open your browser inspector, search for the .dark css class
- Copy .dark class declarations, apply them here into :root
*/
:root {
--body-background-fill: var(--background-fill-primary);
--body-text-color: var(--neutral-100);
--color-accent-soft: var(--neutral-700);
--background-fill-primary: var(--neutral-950);
--background-fill-secondary: var(--neutral-900);
--border-color-accent: var(--neutral-600);
--border-color-primary: var(--neutral-700);
--link-text-color-active: var(--secondary-500);
--link-text-color: var(--secondary-500);
--link-text-color-hover: var(--secondary-400);
--link-text-color-visited: var(--secondary-600);
--body-text-color-subdued: var(--neutral-400);
--shadow-spread: 1px;
--block-background-fill: var(--neutral-800);
--block-border-color: var(--border-color-primary);
--block_border_width: None;
--block-info-text-color: var(--body-text-color-subdued);
--block-label-background-fill: var(--background-fill-secondary);
--block-label-border-color: var(--border-color-primary);
--block_label_border_width: None;
--block-label-text-color: var(--neutral-200);
--block_shadow: None;
--block_title_background_fill: None;
--block_title_border_color: None;
--block_title_border_width: None;
--block-title-text-color: var(--neutral-200);
--panel-background-fill: var(--background-fill-secondary);
--panel-border-color: var(--border-color-primary);
--panel_border_width: None;
--checkbox-background-color: var(--neutral-800);
--checkbox-background-color-focus: var(--checkbox-background-color);
--checkbox-background-color-hover: var(--checkbox-background-color);
--checkbox-background-color-selected: var(--secondary-600);
--checkbox-border-color: var(--neutral-700);
--checkbox-border-color-focus: var(--secondary-500);
--checkbox-border-color-hover: var(--neutral-600);
--checkbox-border-color-selected: var(--secondary-600);
--checkbox-border-width: var(--input-border-width);
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
--checkbox-label-border-color: var(--border-color-primary);
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
--checkbox-label-border-width: var(--input-border-width);
--checkbox-label-text-color: var(--body-text-color);
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
--error-background-fill: var(--background-fill-primary);
--error-border-color: var(--border-color-primary);
--error_border_width: None;
--error-text-color: #ef4444;
--input-background-fill: var(--neutral-800);
--input-background-fill-focus: var(--secondary-600);
--input-background-fill-hover: var(--input-background-fill);
--input-border-color: var(--border-color-primary);
--input-border-color-focus: var(--neutral-700);
--input-border-color-hover: var(--input-border-color);
--input_border_width: None;
--input-placeholder-color: var(--neutral-500);
--input_shadow: None;
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
--loader_color: None;
--slider_color: None;
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
--table-border-color: var(--neutral-700);
--table-even-background-fill: var(--neutral-950);
--table-odd-background-fill: var(--neutral-900);
--table-row-focus: var(--color-accent-soft);
--button-border-width: var(--input-border-width);
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-cancel-border-color: #dc2626;
--button-cancel-border-color-hover: var(--button-cancel-border-color);
--button-cancel-text-color: white;
--button-cancel-text-color-hover: var(--button-cancel-text-color);
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
--button-primary-border-color: var(--primary-500);
--button-primary-border-color-hover: var(--button-primary-border-color);
--button-primary-text-color: white;
--button-primary-text-color-hover: var(--button-primary-text-color);
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-border-color: var(--neutral-600);
--button-secondary-border-color-hover: var(--button-secondary-border-color);
--button-secondary-text-color: white;
--button-secondary-text-color-hover: var(--button-secondary-text-color);
--block-border-width: 1px;
--block-label-border-width: 1px;
--form-gap-width: 1px;
--error-border-width: 1px;
--input-border-width: 1px;
}
/* SHARK theme */
body {
background-color: var(--background-fill-primary);
}
/* display in full width for desktop devices */
@media (min-width: 1536px)
{
.gradio-container {
max-width: var(--size-full) !important;
}
}
.gradio-container .contain {
padding: 0 var(--size-4) !important;
}
.container {
background-color: black !important;
padding-top: var(--size-5) !important;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#top_logo {
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea, #negative_prompt_box textarea {
background-color: var(--background-fill-primary) !important;
}
#prompt_examples {
margin: 0 !important;
}
#prompt_examples svg {
display: none !important;
}
#ui_body {
padding: var(--size-2) !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}
#gallery + div {
border-radius: 0 !important;
}
/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */
#gallery .thumbnail-item.thumbnail-lg {
aspect-ratio: unset;
max-height: calc(55vh - (2 * var(--spacing-lg)));
}
@media (min-width: 1921px) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
#gallery .grid-wrap, #gallery .preview{
min-height: calc(768px + 4px + var(--size-14));
max-height: calc(768px + 4px + var(--size-14));
}
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
#gallery .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
}
/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
object-fit: scale-down;
}
/* Navbar images in cover mode*/
#gallery .preview .thumbnail-item img {
object-fit: cover;
}
/* Limit the stable diffusion text output height */
#std_output textarea {
max-height: 215px;
}
/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */
#gallery .wrap.default {
pointer-events: none;
}
/* Import Png info box */
#txt2img_prompt_image {
height: var(--size-32) !important;
}
/* Hide "remove buttons" from ui dropdowns */
#custom_model .token-remove.remove-all,
#lora_weights .token-remove.remove-all,
#scheduler .token-remove.remove-all,
#device .token-remove.remove-all,
#stencil_model .token-remove.remove-all {
display: none;
}
/* Hide selected items from ui dropdowns */
#custom_model .options .item .inner-item,
#scheduler .options .item .inner-item,
#device .options .item .inner-item,
#stencil_model .options .item .inner-item {
display:none;
}
/* Hide the download icon from the nod logo */
#top_logo .download {
display: none;
}

View File

@@ -1,653 +0,0 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
import PIL
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_cpu_only,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
StencilPipeline,
resize_stencil,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
import numpy as np
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def img2img_inf(
prompt: str,
negative_prompt: str,
image_dict,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.strength = strength
args.scheduler = scheduler
args.img_path = "not none"
args.ondemand = ondemand
if image_dict is None:
return None, "An Initial Image is required"
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
elif isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
else:
image = image_dict["image"].convert("RGB")
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
args.precision = precision
dtype = torch.float32 if precision == "fp32" else torch.half
new_config_obj = Config(
"img2img",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=use_stencil,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
args.use_tuned = False
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
else:
global_obj.set_sd_obj(
Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(args.scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"STRENGTH": strength}
text_output = ""
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
batch_size,
height,
width,
steps,
strength,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
# yield generated_imgs, text_output
return generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Img2Img Rest API.
def img2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = img2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["denoising_strength"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
use_stencil=InputData["use_stencil"]
if "use_stencil" in InputData.keys()
else "None",
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
img2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
img2img_init_image = gr.Image(
label="Input Image",
source="upload",
tool="sketch",
type="pil",
).style(height=300)
with gr.Accordion(label="Stencil Options", open=False):
with gr.Row():
use_stencil = gr.Dropdown(
elem_id="stencil_model",
label="Stencil model",
value="None",
choices=["None", "canny", "openpose", "scribble"],
)
def show_canvas(choice):
if choice == "scribble":
return (
gr.Slider.update(visible=True),
gr.Slider.update(visible=True),
gr.Button.update(visible=True),
)
else:
return (
gr.Slider.update(visible=False),
gr.Slider.update(visible=False),
gr.Button.update(visible=False),
)
def create_canvas(w, h):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
create_button = gr.Button(
label="Start",
value="Open drawing canvas!",
visible=False,
)
create_button.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[img2img_init_image],
)
use_stencil.change(
fn=show_canvas,
inputs=use_stencil,
outputs=[canvas_width, canvas_height, create_button],
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
strength = gr.Slider(
0,
1,
value=args.strength,
step=0.01,
label="Denoising Strength",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
img2img_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
output_dir = (
args.output_dir if args.output_dir else Path.cwd()
)
output_dir = Path(output_dir, "generated_imgs")
std_output = gr.Textbox(
value=f"Images will be saved at {output_dir}",
lines=1,
elem_id="std_output",
show_label=False,
)
with gr.Row():
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
img2img_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
img2img_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=img2img_inf,
inputs=[
prompt,
negative_prompt,
img2img_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
img2img_custom_model,
img2img_hf_model_id,
custom_vae,
precision,
device,
max_length,
use_stencil,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[img2img_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -1,554 +0,0 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_cpu_only,
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def inpaint_inf(
prompt: str,
negative_prompt: str,
image_dict,
height: int,
width: int,
inpaint_full_res: bool,
inpaint_full_res_padding: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.mask_path = "not none"
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"inpaint",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Inpaint Rest API.
def inpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["image"])
mask = decode_base64_to_image(InputData["mask"])
res = inpaint_inf(
InputData["prompt"],
InputData["negative_prompt"],
{"image": init_image, "mask": mask},
InputData["height"],
InputData["width"],
InputData["is_full_res"],
InputData["full_res_padding"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
inpaint_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
inpaint_init_image = gr.Image(
label="Masked Image",
source="upload",
tool="sketch",
type="pil",
).style(height=350)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
inpaint_full_res = gr.Radio(
choices=["Whole picture", "Only masked"],
type="index",
value="Whole picture",
label="Inpaint area",
)
inpaint_full_res_padding = gr.Slider(
minimum=0,
maximum=256,
step=4,
value=32,
label="Only masked padding, pixels",
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
inpaint_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
output_dir = (
args.output_dir if args.output_dir else Path.cwd()
)
output_dir = Path(output_dir, "generated_imgs")
std_output = gr.Textbox(
value=f"Images will be saved at {output_dir}",
lines=1,
elem_id="std_output",
show_label=False,
)
with gr.Row():
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
inpaint_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
inpaint_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=inpaint_inf,
inputs=[
prompt,
negative_prompt,
inpaint_init_image,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
inpaint_custom_model,
inpaint_hf_model_id,
custom_vae,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[inpaint_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -1,223 +0,0 @@
from pathlib import Path
import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import lora_train
from apps.stable_diffusion.src import prompt_examples, args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
get_custom_vae_or_lora_weights,
scheduler_list,
predefined_models,
)
with gr.Blocks(title="Lora Training") as lora_train_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights to initialize weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID to initialize weights",
lines=3,
)
with gr.Group(elem_id="image_dir_box_outer"):
training_images_dir = gr.Textbox(
label="ImageDirectory",
value=args.training_images_dir,
lines=1,
elem_id="prompt_box",
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=scheduler_list,
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1,
2000,
value=args.training_steps,
step=1,
label="Training Steps",
)
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
train_lora = gr.Button("Train LoRA")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
std_output = gr.Textbox(
value="Nothing to show.",
lines=1,
show_label=False,
)
lora_save_dir = (
args.lora_save_dir if args.lora_save_dir else Path.cwd()
)
lora_save_dir = Path(lora_save_dir, "lora")
output_loc = gr.Textbox(
label="Saving Lora at",
value=lora_save_dir,
)
kwargs = dict(
fn=lora_train,
inputs=[
prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
training_images_dir,
output_loc,
get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
),
],
outputs=[std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
train_click = train_lora.click(**kwargs)
stop_batch.click(fn=None, cancels=[prompt_submit, train_click])

View File

@@ -1,157 +0,0 @@
import os
import gradio as gr
import requests
from io import BytesIO
from PIL import Image
def get_hf_list(num_of_models=20):
path = "https://huggingface.co/api/models"
params = {
"search": "stable-diffusion",
"sort": "downloads",
"direction": "-1",
"limit": {num_of_models},
"full": "true",
}
response = requests.get(path, params=params)
return response.json()
def get_civit_list(num_of_models=50):
path = f"https://civitai.com/api/v1/models?limit={num_of_models}&types=Checkpoint"
headers = {"Content-Type": "application/json"}
raw_json = requests.get(path, headers=headers).json()
models = list(raw_json.items())[0][1]
safe_models = [
safe_model for safe_model in models if not safe_model["nsfw"]
]
version_id = 0 # Currently just using the first version.
safe_models = [
safe_model
for safe_model in safe_models
if safe_model["modelVersions"][version_id]["files"][0]["metadata"][
"format"
]
== "SafeTensor"
]
first_version_models = []
for model_iter in safe_models:
# The modelVersion would only keep the version name.
if (
model_iter["modelVersions"][version_id]["images"][0]["nsfw"]
!= "None"
):
continue
model_iter["modelVersions"][version_id]["modelName"] = model_iter[
"name"
]
model_iter["modelVersions"][version_id]["rating"] = model_iter[
"stats"
]["rating"]
model_iter["modelVersions"][version_id]["favoriteCount"] = model_iter[
"stats"
]["favoriteCount"]
model_iter["modelVersions"][version_id]["downloadCount"] = model_iter[
"stats"
]["downloadCount"]
first_version_models.append(model_iter["modelVersions"][version_id])
return first_version_models
def get_image_from_model(model_json):
model_id = model_json["modelId"]
image = None
for img_info in model_json["images"]:
if img_info["nsfw"] == "None":
image_url = model_json["images"][0]["url"]
response = requests.get(image_url)
image = BytesIO(response.content)
break
return image
with gr.Blocks() as model_web:
with gr.Row():
model_source = gr.Radio(
value=None,
choices=["Hugging Face", "Civitai"],
type="value",
label="Model Source",
)
model_numebr = gr.Slider(
1,
100,
value=10,
step=1,
label="Number of models",
interactive=True,
)
# TODO: add more filters
get_model_btn = gr.Button(value="Get Models")
hf_models = gr.Dropdown(
label="Hugging Face Model List",
choices=None,
value=None,
visible=False,
)
# TODO: select and SendTo
civit_models = gr.Gallery(
label="Civitai Model Gallery",
value=None,
interactive=True,
visible=False,
)
with gr.Row(visible=False) as sendto_btns:
modelmanager_sendto_txt2img = gr.Button(value="SendTo Txt2Img")
modelmanager_sendto_img2img = gr.Button(value="SendTo Img2Img")
modelmanager_sendto_inpaint = gr.Button(value="SendTo Inpaint")
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
def get_model_list(model_source, model_numebr):
if model_source == "Hugging Face":
hf_model_list = get_hf_list(model_numebr)
models = []
for model in hf_model_list:
# TODO: add model info
models.append(f'{model["modelId"]}')
return (
gr.Dropdown.update(choices=models, visible=True),
gr.Gallery.update(value=None, visible=False),
gr.Row.update(visible=True),
)
elif model_source == "Civitai":
civit_model_list = get_civit_list(model_numebr)
models = []
for model in civit_model_list:
image = get_image_from_model(model)
if image is None:
continue
# TODO: add model info
models.append(
(Image.open(image), f'{model["files"][0]["downloadUrl"]}')
)
return (
gr.Dropdown.update(value=None, choices=None, visible=False),
gr.Gallery.update(value=models, visible=True),
gr.Row.update(visible=False),
)
else:
return (
gr.Dropdown.update(value=None, choices=None, visible=False),
gr.Gallery.update(value=None, visible=False),
gr.Row.update(visible=False),
)
get_model_btn.click(
fn=get_model_list,
inputs=[model_source, model_numebr],
outputs=[
hf_models,
civit_models,
sendto_btns,
],
)

View File

@@ -1,585 +0,0 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_cpu_only,
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def outpaint_inf(
prompt: str,
negative_prompt: str,
init_image,
pixels: int,
mask_blur: int,
directions: list,
noise_q: float,
color_variation: float,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"outpaint",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
left = True if "left" in directions else False
right = True if "right" in directions else False
top = True if "up" in directions else False
bottom = True if "down" in directions else False
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
init_image,
pixels,
mask_blur,
left,
right,
top,
bottom,
noise_q,
color_variation,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Inpaint Rest API.
def outpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = outpaint_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["pixels"],
InputData["mask_blur"],
InputData["directions"],
InputData["noise_q"],
InputData["color_variation"],
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
outpaint_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
outpaint_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
pixels = gr.Slider(
8,
256,
value=args.pixels,
step=8,
label="Pixels to expand",
)
mask_blur = gr.Slider(
0,
64,
value=args.mask_blur,
step=1,
label="Mask blur",
)
with gr.Row():
directions = gr.CheckboxGroup(
label="Outpainting direction",
choices=["left", "right", "up", "down"],
value=["left", "right", "up", "down"],
)
with gr.Row():
noise_q = gr.Slider(
0.0,
4.0,
value=1.0,
step=0.01,
label="Fall-off exponent (lower=higher detail)",
)
color_variation = gr.Slider(
0.0,
1.0,
value=0.05,
step=0.01,
label="Color variation",
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=20, step=1, label="Steps"
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
outpaint_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
output_dir = (
args.output_dir if args.output_dir else Path.cwd()
)
output_dir = Path(output_dir, "generated_imgs")
std_output = gr.Textbox(
value=f"Images will be saved at {output_dir}",
lines=1,
elem_id="std_output",
show_label=False,
)
with gr.Row():
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
outpaint_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=outpaint_inf,
inputs=[
prompt,
negative_prompt,
outpaint_init_image,
pixels,
mask_blur,
directions,
noise_q,
color_variation,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
outpaint_custom_model,
outpaint_hf_model_id,
custom_vae,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[outpaint_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -1,217 +0,0 @@
import gradio as gr
import torch
import os
from apps.language_models.scripts.stablelm import (
compile_stableLM,
StopOnTokens,
generate,
get_tokenizer,
StableLMModel,
)
from transformers import (
AutoModelForCausalLM,
TextIteratorStreamer,
StoppingCriteriaList,
)
from apps.stable_diffusion.web.ui.utils import available_devices
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
input_ids = torch.randint(3, (1, 256))
attention_mask = torch.randint(3, (1, 256))
sharkModel = 0
sharded_model = 0
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
past_key_values = None
def chat(curr_system_message, history, model):
global sharded_model
global past_key_values
if "vicuna" in model:
from apps.language_models.scripts.sharded_vicuna_fp32 import (
tokenizer,
get_sharded_model,
)
SAMPLE_INPUT_LEN = 137
curr_system_message = start_message_vicuna
if sharded_model == 0:
sharded_model = get_sharded_model()
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
prompt = messages.strip()
print("prompt = ", prompt)
input_ids = tokenizer(prompt).input_ids
new_sentence = ""
for _ in range(200):
original_input_ids = input_ids
input_id_len = len(input_ids)
pad_len = SAMPLE_INPUT_LEN - input_id_len
attention_mask = torch.ones([1, input_id_len], dtype=torch.int64)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
attention_mask = torch.nn.functional.pad(
torch.tensor(attention_mask),
(0, pad_len),
mode="constant",
value=0,
)
if _ == 0:
output = sharded_model.forward(input_ids, is_first=True)
else:
output = sharded_model.forward(
input_ids, past_key_values=past_key_values, is_first=False
)
logits = output["logits"]
past_key_values = output["past_key_values"]
new_word = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=1))
if new_word == "</s>":
break
new_sentence += " " + new_word
history[-1][1] = new_sentence
yield history
next_token = torch.argmax(logits[:, input_id_len - 1, :], dim=1)
original_input_ids.append(next_token)
input_ids = [next_token]
print(new_sentence)
return history
global sharkModel
print("In chat")
if sharkModel == 0:
tok = get_tokenizer()
# sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
m = AutoModelForCausalLM.from_pretrained(
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
)
stableLMModel = StableLMModel(m)
sharkModel = compile_stableLM(
stableLMModel,
tuple([input_ids, attention_mask]),
"stableLM_linalg_f32_seqLen256",
os.getcwd(),
)
# Initialize a StopOnTokens object
stop = StopOnTokens()
# Construct the input message string for the model by concatenating the current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
curr_system_message = start_message
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
# print(messages)
# Tokenize the messages string
streamer = TextIteratorStreamer(
tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
generate_kwargs = dict(
new_text=messages,
streamer=streamer,
max_new_tokens=512,
do_sample=True,
top_p=0.95,
top_k=1000,
temperature=1.0,
num_beams=1,
stopping_criteria=StoppingCriteriaList([stop]),
sharkStableLM=sharkModel,
)
words_list = generate(**generate_kwargs)
partial_text = ""
for new_text in words_list:
# print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to cleanup the message textbox and the updated conversation history
yield history
return words_list
with gr.Blocks(title="Chatbot") as stablelm_chat:
with gr.Row():
model = gr.Dropdown(
label="Select Model",
value="TheBloke/vicuna-7B-1.1-HF",
choices=[
"stabilityai/stablelm-tuned-alpha-3b",
"TheBloke/vicuna-7B-1.1-HF",
],
)
device_value = None
for d in available_devices:
if "vulkan" in d:
device_value = d
break
device = gr.Dropdown(
label="Device",
value=device_value if device_value else available_devices[0],
interactive=False,
choices=available_devices,
)
chatbot = gr.Chatbot().style(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
).style(container=False)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit")
stop = gr.Button("Stop")
clear = gr.Button("Clear")
system_msg = gr.Textbox(
start_message, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model],
outputs=[chatbot],
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model],
outputs=[chatbot],
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)

View File

@@ -1,557 +0,0 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.utils.png_metadata import import_png_metadata
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
prompt_examples,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Text2Img Rest API.
def txt2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
res = txt2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
txt2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"]
+ get_custom_model_files("vae"),
)
with gr.Column(scale=1, min_width=170):
png_info_img = gr.Image(
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
tool="None",
visible=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=scheduler_list,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384,
768,
value=args.height,
step=8,
label="Height",
)
width = gr.Slider(
384,
768,
value=args.width,
step=8,
label="Width",
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
txt2img_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
output_dir = (
args.output_dir if args.output_dir else Path.cwd()
)
output_dir = Path(output_dir, "generated_imgs")
std_output = gr.Textbox(
value=f"Images will be saved at {output_dir}",
lines=1,
elem_id="std_output",
show_label=False,
)
with gr.Row():
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
txt2img_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
txt2img_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
txt2img_custom_model,
txt2img_hf_model_id,
custom_vae,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[txt2img_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
png_info_img.change(
fn=import_png_metadata,
inputs=[
png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
txt2img_custom_model,
txt2img_hf_model_id,
],
outputs=[
png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
txt2img_custom_model,
txt2img_hf_model_id,
],
)

View File

@@ -1,551 +0,0 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_cpu_only,
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def upscaler_inf(
prompt: str,
negative_prompt: str,
init_image,
height: int,
width: int,
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB").resize((height, width))
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
args.height = 128
args.width = 128
new_config_obj = Config(
"upscaler",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
args.height,
args.width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_size = batch_size
args.max_length = max_length
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
"DDPM"
)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"NOISE LEVEL": noise_level}
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
low_res_img = image
high_res_img = Image.new("RGB", (height * 4, width * 4))
for i in range(0, width, 128):
for j in range(0, height, 128):
box = (j, i, j + 128, i + 128)
upscaled_image = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
low_res_img.crop(box),
batch_size,
args.height,
args.width,
steps,
noise_level,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
save_output_img(high_res_img, img_seed, extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Upscaler Rest API.
def upscaler_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = upscaler_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["noise_level"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
upscaler_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-x4-upscaler",
choices=["None"]
+ get_custom_model_files(
custom_checkpoint_type="upscaler"
)
+ predefined_upscaler_models,
)
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
upscaler_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="DDIM",
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
128,
512,
value=args.height,
step=128,
label="Height",
)
width = gr.Slider(
128,
512,
value=args.width,
step=128,
label="Width",
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
noise_level = gr.Slider(
0,
100,
value=args.noise_level,
step=1,
label="Noise Level",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
upscaler_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
output_dir = (
args.output_dir if args.output_dir else Path.cwd()
)
output_dir = Path(output_dir, "generated_imgs")
std_output = gr.Textbox(
value=f"Images will be saved at {output_dir}",
lines=1,
elem_id="std_output",
show_label=False,
)
with gr.Row():
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
upscaler_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
kwargs = dict(
fn=upscaler_inf,
inputs=[
prompt,
negative_prompt,
upscaler_init_image,
height,
width,
steps,
noise_level,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
upscaler_custom_model,
upscaler_hf_model_id,
custom_vae,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[upscaler_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
)

View File

@@ -1,162 +0,0 @@
import os
import sys
from apps.stable_diffusion.src import get_available_devices
import glob
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
@dataclass
class Config:
mode: str
model_id: str
ckpt_loc: str
custom_vae: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
use_lora: str
use_stencil: str
ondemand: str
custom_model_filetypes = (
"*.ckpt",
"*.safetensors",
) # the tuple of file types
scheduler_list_cpu_only = [
"DDIM",
"PNDM",
"LMSDiscrete",
"KDPM2Discrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
]
scheduler_list = scheduler_list_cpu_only + [
"SharkEulerDiscrete",
]
predefined_models = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]
predefined_paint_models = [
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
]
predefined_upscaler_models = [
"stabilityai/stable-diffusion-x4-upscaler",
]
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def create_custom_models_folders():
dir = ["vae", "lora"]
if not args.ckpt_dir:
dir.insert(0, "models")
else:
if not os.path.isdir(args.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exists."
)
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
def get_custom_model_path(model="models"):
# structure in WebUI :-
# models or args.ckpt_dir
# |___lora
# |___vae
sub_folder = "" if model == "models" else model
if args.ckpt_dir:
return Path(Path(args.ckpt_dir), sub_folder)
else:
return Path(Path.cwd(), "models/" + sub_folder)
def get_custom_model_pathfile(custom_model_name, model="models"):
return os.path.join(get_custom_model_path(model), custom_model_name)
def get_custom_model_files(model="models", custom_checkpoint_type=""):
ckpt_files = []
file_types = custom_model_filetypes
if model == "lora":
file_types = custom_model_filetypes + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(
os.path.join(get_custom_model_path(model), extn)
)
]
match custom_checkpoint_type:
case "inpainting":
files = [
val
for val in files
if val.endswith("inpainting" + extn.removeprefix("*"))
]
case "upscaler":
files = [
val
for val in files
if val.endswith("upscaler" + extn.removeprefix("*"))
]
case _:
files = [
val
for val in files
if not (
val.endswith("inpainting" + extn.removeprefix("*"))
or val.endswith("upscaler" + extn.removeprefix("*"))
)
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_custom_vae_or_lora_weights(weights, hf_id, model):
use_weight = ""
if weights == "None" and not hf_id:
use_weight = ""
elif not hf_id:
use_weight = get_custom_model_pathfile(weights, model)
else:
use_weight = hf_id
return use_weight
def cancel_sd():
# Try catch it, as gc can delete global_obj.sd_obj while switching model
try:
global_obj.set_sd_status(SD_STATE_CANCEL)
except Exception:
pass
nodlogo_loc = resource_path("logos/nod-logo.png")
available_devices = get_available_devices()

View File

@@ -1,75 +0,0 @@
import gc
"""
The global objects include SD pipeline and config.
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
Also we could avoid memory leak when switching models by clearing the cache.
"""
def _init():
global _sd_obj
global _config_obj
global _schedulers
_sd_obj = None
_config_obj = None
_schedulers = None
def set_sd_obj(value):
global _sd_obj
_sd_obj = value
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
global _sd_obj
_sd_obj.status = value
def set_cfg_obj(value):
global _config_obj
_config_obj = value
def set_schedulers(value):
global _schedulers
_schedulers = value
def get_sd_obj():
global _sd_obj
return _sd_obj
def get_sd_status():
global _sd_obj
return _sd_obj.status
def get_cfg_obj():
global _config_obj
return _config_obj
def get_scheduler(key):
global _schedulers
return _schedulers[key]
def clear_cache():
global _sd_obj
global _config_obj
global _schedulers
del _sd_obj
del _config_obj
del _schedulers
gc.collect()
_sd_obj = None
_config_obj = None
_schedulers = None

View File

@@ -1,31 +0,0 @@
import os
import tempfile
import gradio
from os import listdir
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
# Clear all gradio tmp images
def clear_gradio_tmp_imgs_folder():
if not os.path.exists(gradio_tmp_imgs_folder):
return
for fileName in listdir(gradio_tmp_imgs_folder):
# Delete tmp png files
if fileName.startswith("tmp") and fileName.endswith(".png"):
os.remove(gradio_tmp_imgs_folder + fileName)
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
def save_pil_to_file(pil_image, dir=None):
if not os.path.exists(gradio_tmp_imgs_folder):
os.mkdir(gradio_tmp_imgs_folder)
file_obj = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
)
pil_image.save(file_obj)
return file_obj
# Register save_pil_to_file override
gradio.processing_utils.save_pil_to_file = save_pil_to_file

View File

@@ -1,152 +0,0 @@
import re
from pathlib import Path
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
scheduler_list,
predefined_models,
)
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
def parse_generation_parameters(x: str):
res = {}
prompt = ""
negative_prompt = ""
done_with_prompt = False
*lines, lastline = x.strip().split("\n")
if len(re_param.findall(lastline)) < 3:
lines.append(lastline)
lastline = ""
for i, line in enumerate(lines):
line = line.strip()
if line.startswith("Negative prompt:"):
done_with_prompt = True
line = line[16:].strip()
if done_with_prompt:
negative_prompt += ("" if negative_prompt == "" else "\n") + line
else:
prompt += ("" if prompt == "" else "\n") + line
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v)
if m is not None:
res[k + "-1"] = m.group(1)
res[k + "-2"] = m.group(2)
else:
res[k] = v
# Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res:
res["Clip skip"] = "1"
hypernet = res.get("Hypernet", None)
if hypernet is not None:
res[
"Prompt"
] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
res["Hires resize-2"] = 0
return res
def import_png_metadata(
pil_data,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
hf_model_id,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
png_hf_model_id = ""
png_custom_model = ""
if "Model" in metadata:
# Remove extension from model info
if metadata["Model"].endswith(".safetensors") or metadata[
"Model"
].endswith(".ckpt"):
metadata["Model"] = Path(metadata["Model"]).stem
# Check for the model name match with one of the local ckpt or safetensors files
if Path(
get_custom_model_pathfile(metadata["Model"] + ".ckpt")
).is_file():
png_custom_model = metadata["Model"] + ".ckpt"
if Path(
get_custom_model_pathfile(metadata["Model"] + ".safetensors")
).is_file():
png_custom_model = metadata["Model"] + ".safetensors"
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if metadata["Model"] in predefined_models:
png_custom_model = metadata["Model"]
# If nothing had matched, check vendor/hf_model_id
if not png_custom_model and metadata["Model"].count("/"):
png_hf_model_id = metadata["Model"]
# No matching model was found
if not png_custom_model and not png_hf_model_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% metadata["Model"]
)
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
cfg_scale = float(metadata["CFG scale"])
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
custom_model = "None"
hf_model_id = png_hf_model_id
if "Prompt" in metadata:
prompt = metadata["Prompt"]
if "Sampler" in metadata:
if metadata["Sampler"] in scheduler_list:
sampler = metadata["Sampler"]
else:
print(
"Import PNG info: Unable to find a scheduler for %s"
% metadata["Sampler"]
)
except Exception as ex:
if pil_data and pil_data.info.get("parameters"):
print("import_png_metadata failed with %s" % ex)
pass
return (
None,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
hf_model_id,
)

View File

@@ -30,15 +30,9 @@ def compare_images(new_filename, golden_filename):
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.1:
if os.name != "nt":
subprocess.run(
[
"gsutil",
"cp",
new_filename,
"gs://shark_tank/testdata/builder/",
]
)
subprocess.run(
["gsutil", "cp", new_filename, "gs://shark_tank/testdata/builder/"]
)
raise SystemExit("new and golden not close")
else:
print("SUCCESS")

View File

@@ -2,4 +2,4 @@
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python tank/generate_sharktank.py
python generate_sharktank.py

View File

@@ -1,16 +1,13 @@
import os
from sys import executable
import subprocess
from apps.stable_diffusion.src.utils.resources import (
get_json_file,
)
from datetime import datetime as dt
from shark.shark_downloader import download_public_file
from image_comparison import compare_images
import argparse
from glob import glob
import shutil
import requests
model_config_dicts = get_json_file(
os.path.join(
@@ -20,202 +17,51 @@ model_config_dicts = get_json_file(
)
def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir):
with open(filename, "r+") as f:
lines = f.readlines()
metrics = {}
vals_to_read = [
"Clip Inference time",
"Average step",
"VAE Inference time",
"Total image generation",
]
for line in lines:
for val in vals_to_read:
if val in line:
metrics[val] = line.split(" ")[-1].strip("\n")
metrics["Average step"] = metrics["Average step"].strip("ms/it")
metrics["Total image generation"] = metrics[
"Total image generation"
].strip("sec")
metrics["device"] = device
metrics["use_tune"] = use_tune
metrics["model_name"] = model_name
metrics["import_mlir"] = import_mlir
metrics["command"] = command
return metrics
def get_inpaint_inputs():
os.mkdir("./test_images/inputs")
img_url = (
"https://huggingface.co/datasets/diffusers/test-arrays/resolve"
"/main/stable_diffusion_inpaint/input_bench_image.png"
)
mask_url = (
"https://huggingface.co/datasets/diffusers/test-arrays/resolve"
"/main/stable_diffusion_inpaint/input_bench_mask.png"
)
img = requests.get(img_url)
mask = requests.get(mask_url)
open("./test_images/inputs/image.png", "wb").write(img.content)
open("./test_images/inputs/mask.png", "wb").write(mask.content)
def test_loop(device="vulkan", beta=False, extra_flags=[]):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
model_metrics = []
os.mkdir("./test_images")
os.mkdir("./test_images/golden")
get_inpaint_inputs()
hf_model_names = model_config_dicts[0].values()
tuned_options = ["--no-use_tuned", "--use_tuned"]
import_options = ["--import_mlir", "--no-import_mlir"]
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
if os.name == "nt":
prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
tuned_options = ["--no-use_tuned", "use_tuned"]
if beta:
extra_flags.append("--beta_models=True")
extra_flags.append("--no-progress_bar")
to_skip = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]
counter = 0
for import_opt in import_options:
for model_name in hf_model_names:
if model_name in to_skip:
continue
for use_tune in tuned_options:
if (
model_name == "stabilityai/stable-diffusion-2-1"
and use_tune == tuned_options[0]
):
continue
elif (
model_name == "stabilityai/stable-diffusion-2-1-base"
and use_tune == tuned_options[1]
):
continue
command = (
[
executable, # executable is the python from the venv used to run this
"apps/stable_diffusion/scripts/txt2img.py",
"--device=" + device,
prompt_text,
"--negative_prompts=" + '""',
"--seed=42",
import_opt,
"--output_dir="
+ os.path.join(os.getcwd(), "test_images", model_name),
"--hf_model_id=" + model_name,
use_tune,
]
if "inpainting" not in model_name
else [
executable,
"apps/stable_diffusion/scripts/inpaint.py",
"--device=" + device,
inpaint_prompt_text,
"--negative_prompts=" + '""',
"--img_path=./test_images/inputs/image.png",
"--mask_path=./test_images/inputs/mask.png",
"--seed=42",
"--import_mlir",
"--output_dir="
+ os.path.join(os.getcwd(), "test_images", model_name),
"--hf_model_id=" + model_name,
use_tune,
]
)
command += extra_flags
if os.name == "nt":
command = " ".join(command)
dumpfile_name = "_".join(model_name.split("/")) + ".txt"
dumpfile_name = os.path.join(os.getcwd(), dumpfile_name)
with open(dumpfile_name, "w+") as f:
generated_image = not subprocess.call(
command,
stdout=f,
stderr=f,
)
if os.name != "nt":
command = " ".join(command)
if generated_image:
model_metrics.append(
parse_sd_out(
dumpfile_name,
command,
device,
use_tune,
model_name,
import_opt,
)
)
print(command)
print("Successfully generated image")
os.makedirs(
"./test_images/golden/" + model_name, exist_ok=True
)
download_public_file(
"gs://shark_tank/testdata/golden/" + model_name,
"./test_images/golden/" + model_name,
)
test_file_path = os.path.join(
os.getcwd(),
"test_images",
model_name,
"generated_imgs",
dt.now().strftime("%Y%m%d"),
"*.png",
)
test_file = glob(test_file_path)[0]
golden_path = (
"./test_images/golden/" + model_name + "/*.png"
)
golden_file = glob(golden_path)[0]
compare_images(test_file, golden_file)
else:
print(command)
print("failed to generate image for this configuration")
with open(dumpfile_name, "r+") as f:
output = f.readlines()
print("\n".join(output))
exit(1)
if os.name == "nt":
counter += 1
if counter % 2 == 0:
extra_flags.append(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
else:
if counter != 1:
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)
for metric in model_metrics:
output = [
metric["model_name"],
metric["device"],
metric["use_tune"],
metric["import_mlir"],
metric["Clip Inference time"],
metric["Average step"],
metric["VAE Inference time"],
metric["Total image generation"],
metric["command"],
for model_name in hf_model_names:
for use_tune in tuned_options:
command = [
"python",
"apps/stable_diffusion/scripts/txt2img.py",
"--device=" + device,
"--prompt=cyberpunk forest by Salvador Dali",
"--output_dir="
+ os.path.join(os.getcwd(), "test_images", model_name),
"--hf_model_id=" + model_name,
use_tune,
]
f.write(";".join(output) + "\n")
command += extra_flags
generated_image = not subprocess.call(
command, stdout=subprocess.DEVNULL
)
if generated_image:
print(" ".join(command))
print("Successfully generated image")
os.makedirs(
"./test_images/golden/" + model_name, exist_ok=True
)
download_public_file(
"gs://shark_tank/testdata/golden/" + model_name,
"./test_images/golden/" + model_name,
)
test_file_path = os.path.join(
os.getcwd(), "test_images", model_name, "generated_imgs"
)
test_file = glob(test_file_path + "/*.png")[0]
golden_path = "./test_images/golden/" + model_name + "/*.png"
golden_file = glob(golden_path)[0]
compare_images(test_file, golden_file)
else:
print(" ".join(command))
print("failed to generate image for this configuration")
parser = argparse.ArgumentParser()

View File

@@ -2,11 +2,9 @@ def pytest_addoption(parser):
# Attaches SHARK command-line arguments to the pytest machinery.
parser.addoption(
"--benchmark",
action="store",
type=str,
default=None,
choices=("baseline", "native", "all"),
help="Benchmarks specified engine(s) and writes bench_results.csv.",
action="store_true",
default="False",
help="Pass option to benchmark and write results.csv",
)
parser.addoption(
"--onnx_bench",
@@ -42,13 +40,7 @@ def pytest_addoption(parser):
"--update_tank",
action="store_true",
default="False",
help="Update local shark tank with latest artifacts if model artifact hash mismatched.",
)
parser.addoption(
"--force_update_tank",
action="store_true",
default="False",
help="Force-update local shark tank with artifacts from specified shark_tank URL (defaults to nightly).",
help="Update local shark tank with latest artifacts.",
)
parser.addoption(
"--ci_sha",
@@ -59,34 +51,12 @@ def pytest_addoption(parser):
parser.addoption(
"--local_tank_cache",
action="store",
default=None,
default="",
help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
)
parser.addoption(
"--tank_url",
type=str,
default="gs://shark_tank/nightly",
default="gs://shark_tank/latest",
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
)
parser.addoption(
"--tank_prefix",
type=str,
default=None,
help="Prefix to gs://shark_tank/ model directories from which to download SHARK tank artifacts. Default is nightly.",
)
parser.addoption(
"--benchmark_dispatches",
default=None,
help="Benchmark individual dispatch kernels produced by IREE compiler. Use 'All' for all, or specific dispatches e.g. '0 1 2 10'",
)
parser.addoption(
"--dispatch_benchmarks_dir",
default="./temp_dispatch_benchmarks",
help="Directory in which dispatch benchmarks are saved.",
)
parser.addoption(
"--batchsize",
default=1,
type=int,
help="Batch size for the tested model.",
)

View File

@@ -40,7 +40,7 @@ cmake --build build/
*Prepare the model*
```bash
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvm-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
```
*Prepare the input*
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
```bash
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
```
VAE and Autoencoder are also available
```bash
# VAE
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
# CLIP Autoencoder
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
```

View File

@@ -1,118 +0,0 @@
# Overview
This document is intended to provide a starting point for profiling with SHARK/IREE. At it's core
[SHARK](https://github.com/nod-ai/SHARK/tree/main/tank) is a python API that links the MLIR lowerings from various
frameworks + frontends (e.g. PyTorch -> Torch-MLIR) with the compiler + runtime offered by IREE. More information
on model coverage and framework support can be found [here](https://github.com/nod-ai/SHARK/tree/main/tank). The intended
use case for SHARK is for compilation and deployment of performant state of the art AI models.
![image](https://user-images.githubusercontent.com/22101546/217151219-9bb184a3-cfb9-4788-bb7e-5b502953525c.png)
## Benchmarking with SHARK
TODO: Expand this section.
SHARK offers native benchmarking support, although because it is model focused, fine grain profiling is
hidden when compared against the common "model benchmarking suite" use case SHARK is good at.
### SharkBenchmarkRunner
SharkBenchmarkRunner is a class designed for benchmarking models against other runtimes.
TODO: List supported runtimes for comparison + example on how to benchmark with it.
## Directly profiling IREE
A number of excellent developer resources on profiling with IREE can be
found [here](https://github.com/iree-org/iree/tree/main/docs/developers/developing_iree). As a result this section will
focus on the bridging the gap between the two.
- https://github.com/iree-org/iree/blob/main/docs/developers/developing_iree/profiling.md
- https://github.com/iree-org/iree/blob/main/docs/developers/developing_iree/profiling_with_tracy.md
- https://github.com/iree-org/iree/blob/main/docs/developers/developing_iree/profiling_vulkan_gpu.md
- https://github.com/iree-org/iree/blob/main/docs/developers/developing_iree/profiling_cpu_events.md
Internally, SHARK builds a pair of IREE commands to compile + run a model. At a high level the flow starts with the
model represented with a high level dialect (commonly Linalg) and is compiled to a flatbuffer (.vmfb) that
the runtime is capable of ingesting. At this point (with potentially a few runtime flags) the compiled model is then run
through the IREE runtime. This is all facilitated with the IREE python bindings, which offers a convenient method
to capture the compile command SHARK comes up with. This is done by setting the environment variable
`IREE_SAVE_TEMPS` to point to a directory of choice, e.g. for stable diffusion
```
# Linux
$ export IREE_SAVE_TEMPS=/path/to/some/directory
# Windows
$ $env:IREE_SAVE_TEMPS="C:\path\to\some\directory"
$ python apps/stable_diffusion/scripts/txt2img.py -p "a photograph of an astronaut riding a horse" --save_vmfb
```
NOTE: Currently this will only save the compile command + input MLIR for a single model if run in a pipeline.
In the case of stable diffusion this (should) be UNet so to get examples for other models in the pipeline they
need to be extracted and tested individually.
The save temps directory should contain three files: `core-command-line.txt`, `core-input.mlir`, and `core-output.bin`.
The command line for compilation will start something like this, where the `-` needs to be replaced with the path to `core-input.mlir`.
```
/home/quinn/nod/iree-build/compiler/bindings/python/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=none ...
```
The `-o output_filename.vmfb` flag can be used to specify the location to save the compiled vmfb. Note that a dump of the
dispatches that can be compiled + run in isolation can be generated by adding `--iree-hal-dump-executable-benchmarks-to=/some/directory`. Say, if they are in the `benchmarks` directory, the following compile/run commands would work for Vulkan on RDNA3.
```
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-benchmark-module --module=benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb --function=forward --device=vulkan
```
Where `${NUM}` is the dispatch number that you want to benchmark/profile in isolation.
### Enabling Tracy for Vulkan profiling
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SHARK-Runtime/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
```
$env:IREE_ENABLE_RUNTIME_TRACING="ON"
```
Getting a trace can then be done by setting environment variable `TRACY_NO_EXIT=1` and running the program that is to be
traced. Then, to actually capture the trace, use the `iree-tracy-capture` tool in a different terminal. Note that to get
the capture and profiler tools the `IREE_BUILD_TRACY=ON` CMake flag needs to be set.
```
TRACY_NO_EXIT=1 python apps/stable_diffusion/scripts/txt2img.py -p "a photograph of an astronaut riding a horse"
# (in another terminal, either on the same machine or through ssh with a tunnel through port 8086)
iree-tracy-capture -o trace_filename.tracy
```
To do it over ssh, the flow looks like this
```
# From terminal 1 on local machine
ssh -L 8086:localhost:8086 <remote_server_name>
TRACY_NO_EXIT=1 python apps/stable_diffusion/scripts/txt2img.py -p "a photograph of an astronaut riding a horse"
# From terminal 2 on local machine. Requires having built IREE with the CMake flag `IREE_BUILD_TRACY=ON` to build the required tooling.
iree-tracy-capture -o /path/to/trace.tracy
```
The trace can then be viewed with
```
iree-tracy-profiler /path/to/trace.tracy
```
Capturing a runtime trace will work with any IREE tooling that uses the runtime. For example, `iree-benchmark-module`
can be used for benchmarking an individual module. Importantly this means that any SHARK script can be profiled with tracy.
NOTE: Not all backends have the same tracy support. This writeup is focused on CPU/Vulkan backends but there is recently added support for tracing on CUDA (requires the `--cuda_tracing` flag).
## Experimental RGP support
TODO: This section is temporary until proper RGP support is added.
Currently, for stable diffusion there is a flag for enabling UNet to be visible to RGP with `--enable_rgp`. To get a proper capture though, the `DevModeSqttPrepareFrameCount=1` flag needs to be set for the driver (done with `VkPanel` on Windows).
With these two settings, a single iteration of UNet can be captured.
(AMD only) To get a dump of the pipelines (result of compiled SPIR-V) the `EnablePipelineDump=1` driver flag can be set. The
files will typically be dumped to a directory called `spvPipeline` (on Linux `/var/tmp/spvPipeline`. The dumped files will
include header information that can be used to map back to the source dispatch/SPIR-V, e.g.
```
[Version]
version = 57
[CsSpvFile]
fileName = Shader_0x946C08DFD0C10D9A.spv
[CsInfo]
entryPoint = forward_dispatch_193_matmul_256x65536x2304
```

View File

@@ -1,75 +0,0 @@
# Overview
This document is intended to provide a starting point for using SHARK stable diffusion with Blender.
We currently make use of the [AI-Render Plugin](https://github.com/benrugg/AI-Render) to integrate with Blender.
## Setup SHARK and prerequisites:
* Download the latest SHARK SD webui .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow instructions on the [README](https://github.com/nod-ai/SHARK#readme)
* Once you have the .exe where you would like SHARK to install, run the .exe from terminal/PowerShell with the `--api` flag:
```
## Run the .exe in API mode:
.\shark_sd_<date>_<ver>.exe --api
## For example:
.\shark_sd_20230411_671.exe --api --server_port=8082
## From a the base directory of a source clone of SHARK:
./setup_venv.ps1
python apps\stable_diffusion\web\index.py --api
```
Your local SD server should start and look something like this:
![image](https://user-images.githubusercontent.com/87458719/231369758-e2c3c45a-eccc-4fe5-a788-4a3bf1ace1d1.png)
* Note: When running in api mode with `--api`, the .exe will not function as a webUI. Thus, the address in the terminal output will only be useful for API requests.
### Install AI Render
- Get AI Render on [Blender Market](https://blendermarket.com/products/ai-render) or [Gumroad](https://airender.gumroad.com/l/ai-render)
- Open Blender, then go to Edit > Preferences > Add-ons > Install and then find the zip file
- We will be using the Automatic1111 SD backend for the AI-Render plugin. Follow instructions [here](https://github.com/benrugg/AI-Render/wiki/Local-Installation) to setup local SD backend.
Your AI-Render preferences should be configured as shown; the highlighted part should match your terminal output:
![image](https://user-images.githubusercontent.com/87458719/231390322-59a54a09-520a-4a08-b658-6e37bd63e932.png)
The [AI-Render README](https://github.com/benrugg/AI-Render/blob/main/README.md) has more details on installation and usage, as well as video tutorials.
## Using AI-Render + SHARK in your Blender project
- In the Render Properties tab, in the AI-Render dropdown, enable AI-Render.
![image](https://user-images.githubusercontent.com/87458719/231392843-9bd51744-3ce2-464e-843a-0c4d4c96df0c.png)
- Select an image size (it's usually better to upscale later than go high on the img2img resolution here.)
![image](https://user-images.githubusercontent.com/87458719/231394288-0c4ab8c5-dc30-4dbe-8bc1-7520ded5efe8.png)
- From here, you can enter a prompt and configure img2img Stable Diffusion parameters, and AI-Render will run SHARK SD img2img on the rendered scene.
- AI-Render has useful presets for aesthetic styles, so you should be able to keep your subject prompt simple and focus on creating a decent Blender scene to start from.
![image](https://user-images.githubusercontent.com/87458719/231440729-2fe69586-41cb-4274-9ce7-f6c08def600b.png)
## Examples:
Scene (Input image):
![blender-sample-2](https://user-images.githubusercontent.com/87458719/231450408-0e680086-3e52-4962-a5c1-c703a94d1583.png)
Prompt:
"A bowl of tangerines in front of rocks, masterpiece, oil on canvas, by Georgia O'Keefe, trending on artstation, landscape painting by Caspar David Friedrich"
Negative Prompt (default):
"ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
Example output:
![blender-sample-2_out](https://user-images.githubusercontent.com/87458719/231451145-a0b56897-a7d0-4add-bbed-7e8af21a65df.png)

View File

@@ -26,17 +26,16 @@ from apps.stable_diffusion.src.utils.stable_args import (
def create_hash(file_name):
with open(file_name, "rb") as f:
file_hash = hashlib.blake2b(digest_size=64)
while chunk := f.read(2**10):
file_hash = hashlib.blake2b()
while chunk := f.read(2**20):
file_hash.update(chunk)
return file_hash.hexdigest()
def save_torch_model(torch_model_list, local_tank_cache, import_args):
def save_torch_model(torch_model_list):
from tank.model_utils import (
get_hf_model,
get_hf_seq2seq_model,
get_vision_model,
get_hf_img_cls_model,
get_fp16_model,
@@ -50,17 +49,17 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
tracing_required = row[1]
model_type = row[2]
is_dynamic = row[3]
mlir_type = row[4]
tracing_required = False if tracing_required == "False" else True
is_dynamic = False if is_dynamic == "False" else True
print("generating artifacts for: " + torch_model_name)
model = None
input = None
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.local_tank_cache = local_tank_cache
args.use_tuned = False
args.local_tank_cache = WORKDIR
precision_values = ["fp16"]
seq_lengths = [64, 77]
@@ -75,41 +74,24 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
width=512,
height=512,
use_base_vae=False,
custom_vae="",
debug=True,
sharktank_dir=local_tank_cache,
sharktank_dir=WORKDIR,
generate_vmfb=False,
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(
torch_model_name, import_args
)
model, input, _ = get_vision_model(torch_model_name)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name, import_args)
elif model_type == "hf_seq2seq":
model, input, _ = get_hf_seq2seq_model(
torch_model_name, import_args
)
model, input, _ = get_hf_model(torch_model_name)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(
torch_model_name, import_args
)
model, input, _ = get_hf_img_cls_model(torch_model_name)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name, import_args)
model, input, _ = get_fp16_model(torch_model_name)
torch_model_name = torch_model_name.replace("/", "_")
if import_args["batch_size"] != 1:
torch_model_dir = os.path.join(
local_tank_cache,
str(torch_model_name)
+ "_torch"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
torch_model_dir = os.path.join(
local_tank_cache, str(torch_model_name) + "_torch"
)
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
@@ -122,8 +104,13 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name,
mlir_type=mlir_type,
)
mlir_hash = create_hash(
os.path.join(
torch_model_dir, torch_model_name + "_torch" + ".mlir"
)
)
np.save(os.path.join(torch_model_dir, "hash"), np.array(mlir_hash))
# Generate torch dynamic models.
if is_dynamic:
mlir_importer.import_debug(
@@ -131,22 +118,16 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name + "_dynamic",
mlir_type=mlir_type,
)
def save_tf_model(tf_model_list, local_tank_cache, import_args):
def save_tf_model(tf_model_list):
from tank.model_utils_tf import (
get_causal_image_model,
get_masked_lm_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
get_tfhf_seq2seq_model,
)
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
@@ -170,52 +151,34 @@ def save_tf_model(tf_model_list, local_tank_cache, import_args):
input = None
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_masked_lm_model(
tf_model_name, import_args
)
elif model_type == "img":
model, input, _ = get_causal_image_model(
tf_model_name, import_args
)
elif model_type == "keras":
model, input, _ = get_keras_model(tf_model_name, import_args)
elif model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name, import_args)
elif model_type == "tfhf_seq2seq":
model, input, _ = get_tfhf_seq2seq_model(
tf_model_name, import_args
)
elif model_type == "hf_causallm":
model, input, _ = get_causal_lm_model(
tf_model_name, import_args
)
model, input, _ = get_causal_lm_model(tf_model_name)
if model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
if model_type == "keras":
model, input, _ = get_keras_model(tf_model_name)
if model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name)
tf_model_name = tf_model_name.replace("/", "_")
if import_args["batch_size"] != 1:
tf_model_dir = os.path.join(
local_tank_cache,
str(tf_model_name)
+ "_tf"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
tf_model_dir = os.path.join(
local_tank_cache, str(tf_model_name) + "_tf"
)
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
os.makedirs(tf_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
inputs=input,
input,
frontend="tf",
)
mlir_importer.import_debug(
is_dynamic=False,
dir=tf_model_dir,
model_name=tf_model_name,
)
mlir_hash = create_hash(
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
)
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
def save_tflite_model(tflite_model_list):
from shark.tflite_utils import TFLitePreprocessor
with open(tflite_model_list) as csvfile:
@@ -227,18 +190,18 @@ def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
print("tflite_model_name", tflite_model_name)
print("tflite_model_link", tflite_model_link)
tflite_model_name_dir = os.path.join(
local_tank_cache, str(tflite_model_name) + "_tflite"
WORKDIR, str(tflite_model_name) + "_tflite"
)
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
# Preprocess to get SharkImporter input import_args
# Preprocess to get SharkImporter input args
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()
# Use SharkImporter to get SharkInference input import_args
# Use SharkImporter to get SharkInference input args
my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
@@ -262,71 +225,6 @@ def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
)
def check_requirements(frontend):
import importlib
has_pkgs = False
if frontend == "torch":
tv_spec = importlib.util.find_spec("torchvision")
has_pkgs = tv_spec is not None
elif frontend in ["tensorflow", "tf"]:
tf_spec = importlib.util.find_spec("tensorflow")
has_pkgs = tf_spec is not None
return has_pkgs
class NoImportException(Exception):
"Raised when requirements are not met for OTF model artifact generation."
pass
def gen_shark_files(modelname, frontend, tank_dir, importer_args):
# If a model's artifacts are requested by shark_downloader but they don't exist in the cloud, we call this function to generate the artifacts on-the-fly.
# TODO: Add TFlite support.
import tempfile
import_args = importer_args
if check_requirements(frontend):
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tf_model_list.csv"
)
custom_model_csv = tempfile.NamedTemporaryFile(
dir=os.path.dirname(__file__),
delete=True,
)
# Create a temporary .csv with only the desired entry.
if frontend == "tf":
with open(tf_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_tf_model(custom_model_csv.name, tank_dir, import_args)
elif frontend == "torch":
with open(torch_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_torch_model(custom_model_csv.name, tank_dir, import_args)
else:
raise NoImportException
# Validates whether the file is present or not.
def is_valid_file(arg):
if not os.path.exists(arg):
@@ -336,7 +234,7 @@ def is_valid_file(arg):
if __name__ == "__main__":
# Note, all of these flags are overridden by the import of import_args from stable_args.py, flags are duplicated temporarily to preserve functionality
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--torch_model_csv",
@@ -364,26 +262,20 @@ if __name__ == "__main__":
# )
# parser.add_argument("--upload", type=bool, default=False)
# old_import_args = parser.parse_import_args()
import_args = {
"batch_size": "1",
}
print(import_args)
# old_args = parser.parse_args()
home = str(Path.home())
WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_shark_tank")
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
os.path.dirname(__file__), "tank", "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tf_model_list.csv"
)
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
)
save_torch_model(
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
WORKDIR,
import_args,
)
save_torch_model(torch_model_csv, WORKDIR, import_args)
save_tf_model(tf_model_csv, WORKDIR, import_args)
save_tflite_model(tflite_model_csv, WORKDIR, import_args)
save_torch_model(torch_model_csv)
save_tf_model(tf_model_csv)
save_tflite_model(tflite_model_csv)

View File

@@ -1,58 +0,0 @@
# This script will toggle the comment/uncommenting aspect for dealing
# with __file__ AttributeError arising in case of a few modules in
# `torch/_dynamo/skipfiles.py` (within shark.venv)
from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Temorary workaround for transformers/__init__.py.
path_to_tranformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
if path_to_tranformers_hook.is_file():
pass
else:
with open(path_to_tranformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
modules_to_comment = ["abc,", "os,", "posixpath,", "_collections_abc,"]
startMonitoring = 0
for line in fileinput.input(path_to_skipfiles, inplace=True):
if "SKIP_DIRS = " in line:
startMonitoring = 1
print(line, end="")
elif startMonitoring in [1, 2]:
if "]" in line:
startMonitoring += 1
print(line, end="")
else:
flag = True
for module in modules_to_comment:
if module in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
flag = False
break
if flag:
print(line, end="")
else:
print(line, end="")
# For getting around scikit-image's packaging, laze_loader has had a patch merged but yet to be released.
# Refer: https://github.com/scientific-python/lazy_loader
path_to_lazy_loader = Path(get_python_lib() + "/lazy_loader/__init__.py")
for line in fileinput.input(path_to_lazy_loader, inplace=True):
if 'stubfile = filename if filename.endswith("i")' in line:
print(
' stubfile = (filename if filename.endswith("i") else f"{os.path.splitext(filename)[0]}.pyi")',
end="",
)
else:
print(line, end="")

View File

@@ -10,8 +10,3 @@ requires = [
"iree-runtime>=20221022.190",
]
build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79
include = '\.pyi?$'

View File

@@ -1,3 +1,3 @@
[pytest]
addopts = --verbose -s -p no:warnings
addopts = --verbose -p no:warnings
norecursedirs = inference tank/tflite examples benchmarks shark

View File

@@ -1,9 +1,9 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
numpy>1.22.4
numpy==1.22.4
torchvision
pytorch-triton
torchvision==0.16.0.dev20230322
tabulate
tqdm
@@ -15,8 +15,8 @@ iree-tools-tf
# TensorFlow and JAX.
gin-config
tensorflow>2.11
keras
tensorflow==2.10.1
keras==2.10
#tf-models-nightly
#tensorflow-text-nightly
transformers
@@ -33,7 +33,6 @@ lit
pyyaml
python-dateutil
sacremoses
sentencepiece
# web dependecies.
gradio

View File

@@ -16,19 +16,13 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
diffusers
scipy
ftfy
gradio==3.22.0
gradio
altair
omegaconf
safetensors
opencv-python
scikit-image
pytorch_lightning # for runwayml models
tk
pywebview
sentencepiece
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile

View File

@@ -1,54 +1,19 @@
<#
.SYNOPSIS
A script to update and install the SHARK runtime and its dependencies.
.DESCRIPTION
This script updates and installs the SHARK runtime and its dependencies.
It checks the Python version installed and installs any required build
dependencies into a Python virtual environment.
If that environment does not exist, it creates it.
.PARAMETER update-src
git pulls latest version
.PARAMETER force
removes and recreates venv to force update of all dependencies
.EXAMPLE
.\setup_venv.ps1 --force
.EXAMPLE
.\setup_venv.ps1 --update-src
.INPUTS
None
.OUTPUTS
None
#>
param([string]$arguments)
if ($arguments -eq "--update-src"){
git pull
}
if ($arguments -eq "--force"){
if (Test-Path env:VIRTUAL_ENV) {
Write-Host "deactivating..."
Deactivate
}
if (Test-Path .\shark.venv\) {
Write-Host "removing and recreating venv..."
Remove-Item .\shark.venv -Force -Recurse
if (Test-Path .\shark.venv\) {
Write-Host 'could not remove .\shark-venv - please try running ".\setup_venv.ps1 --force" again!'
exit 1
}
}
}
#Write-Host "Installing python"
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow
#Write-Host "python installation completed successfully"
#Write-Host "Reload environment variables"
#$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
#Write-Host "Reloaded environment variables"
# redirect stderr into stdout
$p = &{python -V} 2>&1
@@ -60,36 +25,19 @@ $version = if($p -is [System.Management.Automation.ErrorRecord])
}
else
{
# otherwise return complete Python list
$ErrorActionPreference = 'SilentlyContinue'
$PyVer = py --list
# otherwise return as is
$p
}
# deactivate any activated venvs
if ($PyVer -like "*venv*")
{
deactivate # make sure we don't update the wrong venv
$PyVer = py --list # update list
}
Write-Host "Python version found is"
Write-Host $p
Write-Host "Python versions found are"
Write-Host ($PyVer | Out-String) # formatted output with line breaks
if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is unavailable
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
{
Write-Host "Please install Python 3.11 and try again"
exit 34
}
Write-Host "Installing Build Dependencies"
# make sure we really use 3.11 from list, even if it's not the default.
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
else {python -m venv .\shark.venv\}
python -m venv .\shark.venv\
.\shark.venv\Scripts\activate
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --pre torch-mlir torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html

View File

@@ -2,10 +2,9 @@
# Sets up a venv suitable for running samples.
# e.g:
# ./setup_venv.sh #setup a default $PYTHON3 shark.venv
# Environment variables used by the script.
# Environment Variables by the script.
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
# VENV_DIR=myshark.venv #create a venv called myshark.venv
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
# IMPORTER=1 #Install importer deps
# BENCHMARK=1 #Install benchmark deps
@@ -27,17 +26,15 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
echo "Python: $PYTHON"
echo "Python version: $PYTHON_VERSION_X_Y"
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
fi
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
fi
Red=`tput setaf 1`
@@ -45,7 +42,7 @@ Green=`tput setaf 2`
Yellow=`tput setaf 3`
# Assume no binary torch-mlir.
# Currently available for macOS m1&intel (3.11) and Linux(3.8,3.10,3.11)
# Currently available for macOS m1&intel (3.10) and Linux(3.7,3.8,3.9,3.10)
torch_mlir_bin=false
if [[ $(uname -s) = 'Darwin' ]]; then
echo "${Yellow}Apple macOS detected"
@@ -63,12 +60,12 @@ if [[ $(uname -s) = 'Darwin' ]]; then
fi
echo "${Yellow}Run the following commands to setup your SSL certs for your Python version if you see SSL errors with tests"
echo "${Yellow}/Applications/Python\ 3.XX/Install\ Certificates.command"
if [ "$PYTHON_VERSION_X_Y" == "3.11" ]; then
if [ "$PYTHON_VERSION_X_Y" == "3.10" ]; then
torch_mlir_bin=true
fi
elif [[ $(uname -s) = 'Linux' ]]; then
echo "${Yellow}Linux detected"
if [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] || [ "$PYTHON_VERSION_X_Y" == "3.11" ] ; then
if [ "$PYTHON_VERSION_X_Y" == "3.7" ] || [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.9" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] ; then
torch_mlir_bin=true
fi
else
@@ -81,7 +78,7 @@ $PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
if [ "$torch_mlir_bin" = true ]; then
if [[ $(uname -s) = 'Darwin' ]]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
if [ $? -eq 0 ];then
@@ -92,7 +89,7 @@ if [ "$torch_mlir_bin" = true ]; then
fi
else
echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
echo "${Yello}Python 3.11 supported on macOS and 3.8,3.10 and 3.11 on Linux"
echo "${Yello}Python 3.10 supported on macOS and 3.7,3.8,3.9 and 3.10 on Linux"
echo "${Red}Please build torch-mlir from source in your environment"
exit 1
fi
@@ -101,11 +98,11 @@ if [[ -z "${USE_IREE}" ]]; then
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
else
touch ./.use-iree
RUNTIME="https://openxla.github.io/iree/pip-release-links.html"
RUNTIME="https://iree-org.github.io/iree/pip-release-links.html"
fi
if [[ -z "${NO_BACKEND}" ]]; then
echo "Installing ${RUNTIME}..."
$PYTHON -m pip install --pre --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
$PYTHON -m pip install --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
else
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
fi
@@ -115,7 +112,7 @@ if [[ ! -z "${IMPORTER}" ]]; then
if [[ $(uname -s) = 'Linux' ]]; then
echo "${Yellow}Linux detected.. installing Linux importer tools"
#Always get the importer tools from upstream IREE
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://openxla.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://iree-org.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
elif [[ $(uname -s) = 'Darwin' ]]; then
echo "${Yellow}macOS detected.. installing macOS importer tools"
#Conda seems to have some problems installing these packages and hope they get resolved upstream.
@@ -132,11 +129,11 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu118."
echo "Successfully Installed torch + cu117."
else
echo "Could not install torch + cu118." >&2
echo "Could not install torch + cu117." >&2
fi
fi
@@ -150,7 +147,8 @@ if [[ ! -z "${ONNX}" ]]; then
fi
fi
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
echo "${Green}Before running examples activate venv with:"
echo " ${Green}source $VENV_DIR/bin/activate"
fi

View File

@@ -1,18 +0,0 @@
# SHARK LLaMA
## TORCH-MLIR Version
```
https://github.com/nod-ai/torch-mlir.git
```
Then check out the `complex` branch and `git submodule update --init` and then build with `.\build_tools\python_deploy\build_windows.ps1`
### Setup & Run
```
git clone https://github.com/nod-ai/llama.git
```
Then in this repository
```
pip install -e .
python llama/shark_model.py
```

View File

@@ -1,73 +0,0 @@
from transformers import AutoTokenizer, FlaxAutoModel
import torch
import jax
from typing import Union, Dict, List, Any
import numpy as np
from shark.shark_inference import SharkInference
import io
NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
def convert_torch_tensor_tree_to_numpy(
tree: Union[torch.tensor, Dict[str, torch.tensor], List[torch.tensor]]
) -> NumpyTree:
return jax.tree_util.tree_map(
lambda torch_tensor: torch_tensor.cpu().detach().numpy(), tree
)
def convert_int64_to_int32(tree: NumpyTree) -> NumpyTree:
return jax.tree_util.tree_map(
lambda tensor: np.array(tensor, dtype=np.int32)
if tensor.dtype == np.int64
else tensor,
tree,
)
def get_sample_input():
tokenizer = AutoTokenizer.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased"
)
inputs_torch = tokenizer("Hello, World!", return_tensors="pt")
return convert_int64_to_int32(
convert_torch_tensor_tree_to_numpy(inputs_torch.data)
)
def get_jax_model():
return FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree):
model_mlir = jax.jit(jax_model).lower(**sample_input).compiler_ir()
byte_stream = io.BytesIO()
model_mlir.operation.write_bytecode(file=byte_stream)
return byte_stream.getvalue()
def assert_array_list_allclose(x, y, *args, **kwargs):
assert len(x) == len(y)
for a, b in zip(x, y):
np.testing.assert_allclose(
np.asarray(a), np.asarray(b), *args, **kwargs
)
sample_input = get_sample_input()
jax_model = get_jax_model()
mlir = export_jax_to_mlir(jax_model, sample_input)
# Compile and load module.
shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
shark_inference.compile()
# Run main function.
result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
# Run JAX model.
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]
# Verify result.
assert_array_list_allclose(result, reference_result, atol=1e-5)

View File

@@ -1,6 +0,0 @@
flax
jax[cpu]
nodai-SHARK
orbax
transformers
torch

View File

@@ -1,842 +0,0 @@
####################################################################################
# Please make sure you have transformers 4.21.2 installed before running this demo
#
# -p --model_path: the directory in which you want to store the bloom files.
# -dl --device_list: the list of device indices you want to use. if you want to only use the first device, or you are running on cpu leave this blank.
# Otherwise, please give this argument in this format: "[0, 1, 2]"
# -de --device: the device you want to run bloom on. E.G. cpu, cuda
# -c, --recompile: set to true if you want to recompile to vmfb.
# -d, --download: set to true if you want to redownload the mlir files
# -cm, --create_mlirs: set to true if you want to create the mlir files from scratch. please make sure you have transformers 4.21.2 before using this option
# -t --token_count: the number of tokens you want to generate
# -pr --prompt: the prompt you want to feed to the model
# -m --model_name: the name of the model, e.g. bloom-560m
#
# If you don't specify a prompt when you run this example, you will be able to give prompts through the terminal. Run the
# example in this way if you want to run multiple examples without reinitializing the model
#####################################################################################
import os
import io
import torch
import torch.nn as nn
from collections import OrderedDict
import torch_mlir
from torch_mlir import TensorPlaceholder
import re
from transformers.models.bloom.configuration_bloom import BloomConfig
import json
import sys
import argparse
import json
import urllib.request
import subprocess
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from transformers import (
BloomTokenizerFast,
BloomForSequenceClassification,
BloomForCausalLM,
)
from transformers.models.bloom.modeling_bloom import (
BloomBlock,
build_alibi_tensor,
)
IS_CUDA = False
class ShardedBloom:
def __init__(self, src_folder):
f = open(f"{src_folder}/config.json")
config = json.load(f)
f.close()
self.layers_initialized = False
self.src_folder = src_folder
try:
self.n_embed = config["n_embed"]
except KeyError:
self.n_embed = config["hidden_size"]
self.vocab_size = config["vocab_size"]
self.n_layer = config["n_layer"]
try:
self.n_head = config["num_attention_heads"]
except KeyError:
self.n_head = config["n_head"]
def _init_layer(self, layer_name, device, replace, device_idx):
if replace or not os.path.exists(
f"{self.src_folder}/{layer_name}.vmfb"
):
f_ = open(f"{self.src_folder}/{layer_name}.mlir", encoding="utf-8")
module = f_.read()
f_.close()
module = bytes(module, "utf-8")
shark_module = SharkInference(
module,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
)
shark_module.save_module(
module_name=f"{self.src_folder}/{layer_name}",
extra_args=[
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-stream-resource-max-allocation-size=1000000000",
"--iree-codegen-check-ir-before-llvm-conversion=false",
],
)
else:
shark_module = SharkInference(
"",
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
)
return shark_module
def init_layers(self, device, replace=False, device_idx=[0]):
if device_idx is not None:
n_devices = len(device_idx)
self.word_embeddings_module = self._init_layer(
"word_embeddings",
device,
replace,
device_idx if device_idx is None else device_idx[0 % n_devices],
)
self.word_embeddings_layernorm_module = self._init_layer(
"word_embeddings_layernorm",
device,
replace,
device_idx if device_idx is None else device_idx[1 % n_devices],
)
self.ln_f_module = self._init_layer(
"ln_f",
device,
replace,
device_idx if device_idx is None else device_idx[2 % n_devices],
)
self.lm_head_module = self._init_layer(
"lm_head",
device,
replace,
device_idx if device_idx is None else device_idx[3 % n_devices],
)
self.block_modules = [
self._init_layer(
f"bloom_block_{i}",
device,
replace,
device_idx
if device_idx is None
else device_idx[(i + 4) % n_devices],
)
for i in range(self.n_layer)
]
self.layers_initialized = True
def load_layers(self):
assert self.layers_initialized
self.word_embeddings_module.load_module(
f"{self.src_folder}/word_embeddings.vmfb"
)
self.word_embeddings_layernorm_module.load_module(
f"{self.src_folder}/word_embeddings_layernorm.vmfb"
)
for block_module, i in zip(self.block_modules, range(self.n_layer)):
block_module.load_module(f"{self.src_folder}/bloom_block_{i}.vmfb")
self.ln_f_module.load_module(f"{self.src_folder}/ln_f.vmfb")
self.lm_head_module.load_module(f"{self.src_folder}/lm_head.vmfb")
def forward_pass(self, input_ids, device):
if IS_CUDA:
cudaSetDevice(self.word_embeddings_module.device_idx)
input_embeds = self.word_embeddings_module(
inputs=(input_ids,), function_name="forward"
)
input_embeds = torch.tensor(input_embeds).float()
if IS_CUDA:
cudaSetDevice(self.word_embeddings_layernorm_module.device_idx)
hidden_states = self.word_embeddings_layernorm_module(
inputs=(input_embeds,), function_name="forward"
)
hidden_states = torch.tensor(hidden_states).float()
attention_mask = torch.ones(
[hidden_states.shape[0], len(input_ids[0])]
)
alibi = build_alibi_tensor(
attention_mask,
self.n_head,
hidden_states.dtype,
hidden_states.device,
)
causal_mask = _prepare_attn_mask(
attention_mask, input_ids.size(), input_embeds, 0
)
causal_mask = torch.tensor(causal_mask).float()
presents = ()
all_hidden_states = tuple(hidden_states)
for block_module, i in zip(self.block_modules, range(self.n_layer)):
if IS_CUDA:
cudaSetDevice(block_module.device_idx)
output = block_module(
inputs=(
hidden_states.detach().numpy(),
alibi.detach().numpy(),
causal_mask.detach().numpy(),
),
function_name="forward",
)
hidden_states = torch.tensor(output[0]).float()
all_hidden_states = all_hidden_states + (hidden_states,)
presents = presents + (
tuple(
(
output[1],
output[2],
)
),
)
if IS_CUDA:
cudaSetDevice(self.ln_f_module.device_idx)
hidden_states = self.ln_f_module(
inputs=(hidden_states,), function_name="forward"
)
if IS_CUDA:
cudaSetDevice(self.lm_head_module.device_idx)
logits = self.lm_head_module(
inputs=(hidden_states,), function_name="forward"
)
logits = torch.tensor(logits).float()
return torch.argmax(logits[:, -1, :], dim=-1)
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
past_key_values_length: int = 0,
):
"""
Make causal mask used for bi-directional self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
mask_cond = torch.arange(mask.size(-1))
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
mask.masked_fill_(intermediate_mask, 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
target_length, past_key_values_length, dtype=dtype
),
mask,
],
dim=-1,
)
expanded_mask = mask[None, None, :, :].expand(
batch_size, 1, target_length, target_length + past_key_values_length
)
return expanded_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
batch_size, source_length = mask.size()
tgt_len = tgt_len if tgt_len is not None else source_length
expanded_mask = (
mask[:, None, None, :]
.expand(batch_size, 1, tgt_len, source_length)
.to(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def _prepare_attn_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
past_key_values_length=past_key_values_length,
).to(attention_mask.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def download_model(destination_folder, model_name):
download_public_file(
f"gs://shark_tank/sharded_bloom/{model_name}/", destination_folder
)
def compile_embeddings(embeddings_layer, input_ids, path):
input_ids_placeholder = torch_mlir.TensorPlaceholder.like(
input_ids, dynamic_axes=[1]
)
module = torch_mlir.compile(
embeddings_layer,
(input_ids_placeholder),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = io.BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(path, "w+")
f_.write(str(module))
f_.close()
return
def compile_word_embeddings_layernorm(
embeddings_layer_layernorm, embeds, path
):
embeds_placeholder = torch_mlir.TensorPlaceholder.like(
embeds, dynamic_axes=[1]
)
module = torch_mlir.compile(
embeddings_layer_layernorm,
(embeds_placeholder),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = io.BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(path, "w+")
f_.write(str(module))
f_.close()
return
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
def compile_to_mlir(
bblock,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
use_cache=None,
output_attentions=False,
alibi=None,
block_index=0,
path=".",
):
fx_g = make_fx(
bblock,
decomposition_table=get_decompositions(
[
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
tracing_mode="real",
_allow_non_fake_inputs=False,
)(hidden_states, alibi, attention_mask)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
strip_overloads(fx_g)
hidden_states_placeholder = TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
attention_mask, dynamic_axes=[2, 3]
)
alibi_placeholder = TensorPlaceholder.like(alibi, dynamic_axes=[2])
ts_g = torch.jit.script(fx_g)
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
alibi_placeholder,
attention_mask_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
module_placeholder = module
module_context = module_placeholder.context
def check_valid_line(line, line_n, mlir_file_len):
if "private" in line:
return False
if "attributes" in line:
return False
if mlir_file_len - line_n == 2:
return False
return True
mlir_file_len = len(str(module).split("\n"))
def remove_constant_dim(line):
if "17x" in line:
line = re.sub("17x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi eq" in line:
line = re.sub("c17", "dim", line)
if " 17," in line:
line = re.sub(" 17,", " %dim,", line)
return line
module = "\n".join(
[
remove_constant_dim(line)
for line, line_n in zip(
str(module).split("\n"), range(mlir_file_len)
)
if check_valid_line(line, line_n, mlir_file_len)
]
)
module = module_placeholder.parse(module, context=module_context)
bytecode_stream = io.BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(path, "w+")
f_.write(str(module))
f_.close()
return
def compile_ln_f(ln_f, hidden_layers, path):
hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like(
hidden_layers, dynamic_axes=[1]
)
module = torch_mlir.compile(
ln_f,
(hidden_layers_placeholder),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = io.BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(path, "w+")
f_.write(str(module))
f_.close()
return
def compile_lm_head(lm_head, hidden_layers, path):
hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like(
hidden_layers, dynamic_axes=[1]
)
module = torch_mlir.compile(
lm_head,
(hidden_layers_placeholder),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = io.BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(path, "w+")
f_.write(str(module))
f_.close()
return
def create_mlirs(destination_folder, model_name):
model_config = "bigscience/" + model_name
sample_input_ids = torch.ones([1, 17], dtype=torch.int64)
urllib.request.urlretrieve(
f"https://huggingface.co/bigscience/{model_name}/resolve/main/config.json",
filename=f"{destination_folder}/config.json",
)
urllib.request.urlretrieve(
f"https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json",
filename=f"{destination_folder}/tokenizer.json",
)
class HuggingFaceLanguage(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = BloomForCausalLM.from_pretrained(model_config)
def forward(self, tokens):
return self.model.forward(tokens)[0]
class HuggingFaceBlock(torch.nn.Module):
def __init__(self, block):
super().__init__()
self.model = block
def forward(self, tokens, alibi, attention_mask):
output = self.model(
hidden_states=tokens,
alibi=alibi,
attention_mask=attention_mask,
use_cache=True,
output_attentions=False,
)
return (output[0], output[1][0], output[1][1])
model = HuggingFaceLanguage()
compile_embeddings(
model.model.transformer.word_embeddings,
sample_input_ids,
f"{destination_folder}/word_embeddings.mlir",
)
inputs_embeds = model.model.transformer.word_embeddings(sample_input_ids)
compile_word_embeddings_layernorm(
model.model.transformer.word_embeddings_layernorm,
inputs_embeds,
f"{destination_folder}/word_embeddings_layernorm.mlir",
)
hidden_states = model.model.transformer.word_embeddings_layernorm(
inputs_embeds
)
input_shape = sample_input_ids.size()
current_sequence_length = hidden_states.shape[1]
past_key_values_length = 0
past_key_values = tuple([None] * len(model.model.transformer.h))
attention_mask = torch.ones(
(hidden_states.shape[0], current_sequence_length), device="cpu"
)
alibi = build_alibi_tensor(
attention_mask,
model.model.transformer.n_head,
hidden_states.dtype,
"cpu",
)
causal_mask = _prepare_attn_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
head_mask = model.model.transformer.get_head_mask(
None, model.model.transformer.config.n_layer
)
output_attentions = model.model.transformer.config.output_attentions
all_hidden_states = ()
for i, (block, layer_past) in enumerate(
zip(model.model.transformer.h, past_key_values)
):
all_hidden_states = all_hidden_states + (hidden_states,)
proxy_model = HuggingFaceBlock(block)
compile_to_mlir(
proxy_model,
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=True,
output_attentions=output_attentions,
alibi=alibi,
block_index=i,
path=f"{destination_folder}/bloom_block_{i}.mlir",
)
compile_ln_f(
model.model.transformer.ln_f,
hidden_states,
f"{destination_folder}/ln_f.mlir",
)
hidden_states = model.model.transformer.ln_f(hidden_states)
compile_lm_head(
model.model.lm_head,
hidden_states,
f"{destination_folder}/lm_head.mlir",
)
def run_large_model(
token_count,
recompile,
model_path,
prompt,
device_list,
script_path,
device,
):
f = open(f"{model_path}/prompt.txt", "w+")
f.write(prompt)
f.close()
for i in range(token_count):
if i == 0:
will_compile = recompile
else:
will_compile = False
f = open(f"{model_path}/prompt.txt", "r")
prompt = f.read()
f.close()
subprocess.run(
[
"python",
script_path,
model_path,
"start",
str(will_compile),
"cpu",
"None",
prompt,
]
)
for i in range(config["n_layer"]):
if device_list is not None:
device_idx = str(device_list[i % len(device_list)])
else:
device_idx = "None"
subprocess.run(
[
"python",
script_path,
model_path,
str(i),
str(will_compile),
device,
device_idx,
prompt,
]
)
subprocess.run(
[
"python",
script_path,
model_path,
"end",
str(will_compile),
"cpu",
"None",
prompt,
]
)
f = open(f"{model_path}/prompt.txt", "r")
output = f.read()
f.close()
print(output)
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog="Bloom-560m")
parser.add_argument("-p", "--model_path")
parser.add_argument("-dl", "--device_list", default=None)
parser.add_argument("-de", "--device", default="cpu")
parser.add_argument("-c", "--recompile", default=False, type=bool)
parser.add_argument("-d", "--download", default=False, type=bool)
parser.add_argument("-t", "--token_count", default=10, type=int)
parser.add_argument("-m", "--model_name", default="bloom-560m")
parser.add_argument("-cm", "--create_mlirs", default=False, type=bool)
parser.add_argument(
"-lm", "--large_model_memory_efficient", default=False, type=bool
)
parser.add_argument(
"-pr",
"--prompt",
default=None,
)
args = parser.parse_args()
if args.create_mlirs and args.large_model_memory_efficient:
print(
"Warning: If you need to use memory efficient mode, you probably want to use 'download' instead"
)
if not os.path.isdir(args.model_path):
os.mkdir(args.model_path)
if args.device_list is not None:
args.device_list = json.loads(args.device_list)
if args.device == "cuda" and args.device_list is not None:
IS_CUDA = True
from cuda.cudart import cudaSetDevice
if args.download and args.create_mlirs:
print(
"WARNING: It is not advised to turn on both download and create_mlirs"
)
if args.download:
download_model(args.model_path, args.model_name)
if args.create_mlirs:
create_mlirs(args.model_path, args.model_name)
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
if args.prompt is not None:
input_ids = tokenizer.encode(args.prompt, return_tensors="pt")
if args.large_model_memory_efficient:
f = open(f"{args.model_path}/config.json")
config = json.load(f)
f.close()
self_path = os.path.dirname(os.path.abspath(__file__))
script_path = os.path.join(self_path, "sharded_bloom_large_models.py")
if args.prompt is not None:
run_large_model(
args.token_count,
args.recompile,
args.model_path,
args.prompt,
args.device_list,
script_path,
args.device,
)
else:
while True:
prompt = input("Enter Prompt: ")
try:
token_count = int(
input("Enter number of tokens you want to generate: ")
)
except:
print(
"Invalid integer entered. Using default value of 10"
)
token_count = 10
run_large_model(
token_count,
args.recompile,
args.model_path,
prompt,
args.device_list,
script_path,
args.device,
)
else:
shardedbloom = ShardedBloom(args.model_path)
shardedbloom.init_layers(
device=args.device,
replace=args.recompile,
device_idx=args.device_list,
)
shardedbloom.load_layers()
if args.prompt is not None:
for _ in range(args.token_count):
next_token = shardedbloom.forward_pass(
torch.tensor(input_ids), device=args.device
)
input_ids = torch.cat(
[input_ids, next_token.unsqueeze(-1)], dim=-1
)
print(tokenizer.decode(input_ids.squeeze()))
else:
while True:
prompt = input("Enter Prompt: ")
try:
token_count = int(
input("Enter number of tokens you want to generate: ")
)
except:
print(
"Invalid integer entered. Using default value of 10"
)
token_count = 10
input_ids = tokenizer.encode(prompt, return_tensors="pt")
for _ in range(token_count):
next_token = shardedbloom.forward_pass(
torch.tensor(input_ids), device=args.device
)
input_ids = torch.cat(
[input_ids, next_token.unsqueeze(-1)], dim=-1
)
print(tokenizer.decode(input_ids.squeeze()))

View File

@@ -1,381 +0,0 @@
import sys
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
import re
from shark.shark_inference import SharkInference
import torch
import torch.nn as nn
from collections import OrderedDict
from transformers.models.bloom.modeling_bloom import (
BloomBlock,
build_alibi_tensor,
)
import time
import json
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
batch_size, source_length = mask.size()
tgt_len = tgt_len if tgt_len is not None else source_length
expanded_mask = (
mask[:, None, None, :]
.expand(batch_size, 1, tgt_len, source_length)
.to(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def _prepare_attn_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
past_key_values_length=past_key_values_length,
).to(attention_mask.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
past_key_values_length: int = 0,
):
"""
Make causal mask used for bi-directional self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
mask_cond = torch.arange(mask.size(-1))
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
mask.masked_fill_(intermediate_mask, 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
target_length, past_key_values_length, dtype=dtype
),
mask,
],
dim=-1,
)
expanded_mask = mask[None, None, :, :].expand(
batch_size, 1, target_length, target_length + past_key_values_length
)
return expanded_mask
if __name__ == "__main__":
working_dir = sys.argv[1]
layer_name = sys.argv[2]
will_compile = sys.argv[3]
device = sys.argv[4]
device_idx = sys.argv[5]
prompt = sys.argv[6]
if device_idx.lower().strip() == "none":
device_idx = None
else:
device_idx = int(device_idx)
if will_compile.lower().strip() == "true":
will_compile = True
else:
will_compile = False
f = open(f"{working_dir}/config.json")
config = json.load(f)
f.close()
layers_initialized = False
try:
n_embed = config["n_embed"]
except KeyError:
n_embed = config["hidden_size"]
vocab_size = config["vocab_size"]
n_layer = config["n_layer"]
try:
n_head = config["num_attention_heads"]
except KeyError:
n_head = config["n_head"]
if not os.path.isdir(working_dir):
os.mkdir(working_dir)
if layer_name == "start":
tokenizer = AutoTokenizer.from_pretrained(working_dir)
input_ids = tokenizer.encode(prompt, return_tensors="pt")
mlir_str = ""
if will_compile:
f = open(f"{working_dir}/word_embeddings.mlir", encoding="utf-8")
mlir_str = f.read()
f.close()
mlir_str = bytes(mlir_str, "utf-8")
shark_module = SharkInference(
mlir_str,
device="cpu",
mlir_dialect="tm_tensor",
device_idx=None,
)
if will_compile:
shark_module.save_module(
module_name=f"{working_dir}/word_embeddings",
extra_args=[
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-stream-resource-max-allocation-size=1000000000",
"--iree-codegen-check-ir-before-llvm-conversion=false",
],
)
shark_module.load_module(f"{working_dir}/word_embeddings.vmfb")
input_embeds = shark_module(
inputs=(input_ids,), function_name="forward"
)
input_embeds = torch.tensor(input_embeds).float()
mlir_str = ""
if will_compile:
f = open(
f"{working_dir}/word_embeddings_layernorm.mlir",
encoding="utf-8",
)
mlir_str = f.read()
f.close()
shark_module = SharkInference(
mlir_str,
device="cpu",
mlir_dialect="tm_tensor",
device_idx=None,
)
if will_compile:
shark_module.save_module(
module_name=f"{working_dir}/word_embeddings_layernorm",
extra_args=[
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-stream-resource-max-allocation-size=1000000000",
"--iree-codegen-check-ir-before-llvm-conversion=false",
],
)
shark_module.load_module(
f"{working_dir}/word_embeddings_layernorm.vmfb"
)
hidden_states = shark_module(
inputs=(input_embeds,), function_name="forward"
)
hidden_states = torch.tensor(hidden_states).float()
torch.save(hidden_states, f"{working_dir}/hidden_states_0.pt")
attention_mask = torch.ones(
[hidden_states.shape[0], len(input_ids[0])]
)
attention_mask = torch.tensor(attention_mask).float()
alibi = build_alibi_tensor(
attention_mask,
n_head,
hidden_states.dtype,
device="cpu",
)
torch.save(alibi, f"{working_dir}/alibi.pt")
causal_mask = _prepare_attn_mask(
attention_mask, input_ids.size(), input_embeds, 0
)
causal_mask = torch.tensor(causal_mask).float()
torch.save(causal_mask, f"{working_dir}/causal_mask.pt")
elif layer_name in [str(x) for x in range(n_layer)]:
hidden_states = torch.load(
f"{working_dir}/hidden_states_{layer_name}.pt"
)
alibi = torch.load(f"{working_dir}/alibi.pt")
causal_mask = torch.load(f"{working_dir}/causal_mask.pt")
mlir_str = ""
if will_compile:
f = open(
f"{working_dir}/bloom_block_{layer_name}.mlir",
encoding="utf-8",
)
mlir_str = f.read()
f.close()
mlir_str = bytes(mlir_str, "utf-8")
shark_module = SharkInference(
mlir_str,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
)
if will_compile:
shark_module.save_module(
module_name=f"{working_dir}/bloom_block_{layer_name}",
extra_args=[
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-stream-resource-max-allocation-size=1000000000",
"--iree-codegen-check-ir-before-llvm-conversion=false",
],
)
shark_module.load_module(
f"{working_dir}/bloom_block_{layer_name}.vmfb"
)
output = shark_module(
inputs=(
hidden_states.detach().numpy(),
alibi.detach().numpy(),
causal_mask.detach().numpy(),
),
function_name="forward",
)
hidden_states = torch.tensor(output[0]).float()
torch.save(
hidden_states,
f"{working_dir}/hidden_states_{int(layer_name) + 1}.pt",
)
elif layer_name == "end":
mlir_str = ""
if will_compile:
f = open(f"{working_dir}/ln_f.mlir", encoding="utf-8")
mlir_str = f.read()
f.close()
mlir_str = bytes(mlir_str, "utf-8")
shark_module = SharkInference(
mlir_str,
device="cpu",
mlir_dialect="tm_tensor",
device_idx=None,
)
if will_compile:
shark_module.save_module(
module_name=f"{working_dir}/ln_f",
extra_args=[
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-stream-resource-max-allocation-size=1000000000",
"--iree-codegen-check-ir-before-llvm-conversion=false",
],
)
shark_module.load_module(f"{working_dir}/ln_f.vmfb")
hidden_states = torch.load(f"{working_dir}/hidden_states_{n_layer}.pt")
hidden_states = shark_module(
inputs=(hidden_states,), function_name="forward"
)
mlir_str = ""
if will_compile:
f = open(f"{working_dir}/lm_head.mlir", encoding="utf-8")
mlir_str = f.read()
f.close()
mlir_str = bytes(mlir_str, "utf-8")
if config["n_embed"] == 14336:
def get_state_dict():
d = torch.load(
f"{working_dir}/pytorch_model_00001-of-00072.bin"
)
return OrderedDict(
(k.replace("word_embeddings.", ""), v)
for k, v in d.items()
)
def load_causal_lm_head():
linear = nn.utils.skip_init(
nn.Linear, 14336, 250880, bias=False, dtype=torch.float
)
linear.load_state_dict(get_state_dict(), strict=False)
return linear.float()
lm_head = load_causal_lm_head()
logits = lm_head(torch.tensor(hidden_states).float())
else:
shark_module = SharkInference(
mlir_str,
device="cpu",
mlir_dialect="tm_tensor",
device_idx=None,
)
if will_compile:
shark_module.save_module(
module_name=f"{working_dir}/lm_head",
extra_args=[
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-stream-resource-max-allocation-size=1000000000",
"--iree-codegen-check-ir-before-llvm-conversion=false",
],
)
shark_module.load_module(f"{working_dir}/lm_head.vmfb")
logits = shark_module(
inputs=(hidden_states,), function_name="forward"
)
logits = torch.tensor(logits).float()
tokenizer = AutoTokenizer.from_pretrained(working_dir)
next_token = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=-1))
f = open(f"{working_dir}/prompt.txt", "w+")
f.write(prompt + next_token)
f.close()

View File

@@ -1,43 +0,0 @@
# Stable Diffusion Fine Tuning
## Installation (Linux)
### Activate shark.venv Virtual Environment
```shell
source shark.venv/bin/activate
# Some older pip installs may not be able to handle the recent PyTorch deps
python -m pip install --upgrade pip
```
## Install dependencies
### Run the following installation commands:
```
pip install -U git+https://github.com/huggingface/diffusers.git
pip install accelerate transformers ftfy
```
### Build torch-mlir with the following branch:
Please cherry-pick this branch of torch-mlir: https://github.com/vivekkhandelwal1/torch-mlir/tree/sd-ops
and build it locally. You can find the instructions for using locally build Torch-MLIR,
here: https://github.com/nod-ai/SHARK#how-to-use-your-locally-built-iree--torch-mlir-with-shark
## Run the Stable diffusion fine tuning
To run the model with the default set of images and params, run:
```shell
python stable_diffusion_fine_tuning.py
```
By default the training is run through the PyTorch path. If you want to train the model using the Torchdynamo path of Torch-MLIR, you need to specify `--use_torchdynamo=True`.
The default number of training steps are `2000`, which would take many hours to complete based on your system config. You can pass the smaller value with the arg `--training_steps`. You can specify the number of images to be sampled for the result with the `--num_inference_samples` arg. For the number of inference steps you can use `--inference_steps` flag.
For example, you can run the training for a limited set of steps via the dynamo path by using the following command:
```
python stable_diffusion_fine_tuning.py --training_steps=1 --inference_steps=1 --num_inference_samples=1 --train_batch_size=1 --use_torchdynamo=True
```
You can also specify the device to be used via the flag `--device`. The default value is `cpu`, for GPU execution you can specify `--device="cuda"`.

View File

@@ -1,914 +0,0 @@
# Install the required libs
# pip install -U git+https://github.com/huggingface/diffusers.git
# pip install accelerate transformers ftfy
# Import required libraries
import argparse
import itertools
import math
import os
from typing import List
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
import logging
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from shark.shark_inference import SharkInference
torch._dynamo.config.verbose = True
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
)
# Enter your HuggingFace Token
# Note: You can comment this prompt and just set your token instead of passing it through cli for every execution.
hf_token = input("Please enter your huggingface token here: ")
YOUR_TOKEN = hf_token
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
# `pretrained_model_name_or_path` which Stable Diffusion checkpoint you want to use
# Options: 1.) "stabilityai/stable-diffusion-2"
# 2.) "stabilityai/stable-diffusion-2-base"
# 3.) "CompVis/stable-diffusion-v1-4"
# 4.) "runwayml/stable-diffusion-v1-5"
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2"
# Add here the URLs to the images of the concept you are adding. 3-5 should be fine
urls = [
"https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
"https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
"https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
"https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg",
## You can add additional images here
]
# Downloading Images
import requests
import glob
from io import BytesIO
def download_image(url):
try:
response = requests.get(url)
except:
return None
return Image.open(BytesIO(response.content)).convert("RGB")
images = list(filter(None, [download_image(url) for url in urls]))
save_path = "./my_concept"
if not os.path.exists(save_path):
os.mkdir(save_path)
[image.save(f"{save_path}/{i}.jpeg") for i, image in enumerate(images)]
p = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
p.add_argument(
"--input_dir",
type=str,
default="my_concept/",
help="the directory contains the images used for fine tuning",
)
p.add_argument(
"--output_dir",
type=str,
default="sd_result",
help="the directory contains the images used for fine tuning",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="the maximum number of training steps",
)
p.add_argument(
"--train_batch_size",
type=int,
default=4,
help="The batch size for training",
)
p.add_argument(
"--save_steps",
type=int,
default=250,
help="the number of steps after which to save the learned concept",
)
p.add_argument("--seed", type=int, default=42, help="the random seed")
p.add_argument(
"--what_to_teach",
type=str,
choices=["object", "style"],
default="object",
help="what is it that you are teaching?",
)
p.add_argument(
"--placeholder_token",
type=str,
default="<cat-toy>",
help="It is the token you are going to use to represent your new concept",
)
p.add_argument(
"--initializer_token",
type=str,
default="toy",
help="It is a word that can summarise what is your new concept",
)
p.add_argument(
"--inference_steps",
type=int,
default=50,
help="the number of steps for inference",
)
p.add_argument(
"--num_inference_samples",
type=int,
default=4,
help="the number of samples for inference",
)
p.add_argument(
"--prompt",
type=str,
default="a grafitti in a wall with a *s on it",
help="the text prompt to use",
)
p.add_argument(
"--device",
type=str,
default="cpu",
help="The device to use",
)
p.add_argument(
"--use_torchdynamo",
type=bool,
default=False,
help="This flag is used to determine whether the training has to be done through the torchdynamo path or not.",
)
args = p.parse_args()
torch.manual_seed(args.seed)
if "*s" not in args.prompt:
raise ValueError(
f'The prompt should have a "*s" which will be replaced by a placeholder token.'
)
prompt1, prompt2 = args.prompt.split("*s")
args.prompt = prompt1 + args.placeholder_token + prompt2
# `images_path` is a path to directory containing the training images.
images_path = args.input_dir
while not os.path.exists(str(images_path)):
print(
"The images_path specified does not exist, use the colab file explorer to copy the path :"
)
images_path = input("")
save_path = images_path
# Setup and check the images you have just added
images = []
for file_path in os.listdir(save_path):
try:
image_path = os.path.join(save_path, file_path)
images.append(Image.open(image_path).resize((512, 512)))
except:
print(
f"{image_path} is not a valid image, please make sure to remove this file from the directory otherwise the training could fail."
)
image_grid(images, 1, len(images))
########### Create Dataset ##########
# Setup the prompt templates for training
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
# Setup the dataset
class TextualInversionDataset(Dataset):
def __init__(
self,
data_root,
tokenizer,
learnable_property="object", # [object, style]
size=512,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
center_crop=False,
):
self.data_root = data_root
self.tokenizer = tokenizer
self.learnable_property = learnable_property
self.size = size
self.placeholder_token = placeholder_token
self.center_crop = center_crop
self.flip_p = flip_p
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
self.num_images = len(self.image_paths)
self._length = self.num_images
if set == "train":
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.templates = (
imagenet_style_templates_small
if learnable_property == "style"
else imagenet_templates_small
)
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
placeholder_string = self.placeholder_token
text = random.choice(self.templates).format(placeholder_string)
example["input_ids"] = self.tokenizer(
text,
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids[0]
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip_transform(image)
image = np.array(image).astype(np.uint8)
image = (image / 127.5 - 1.0).astype(np.float32)
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
return example
########## Setting up the model ##########
# Load the tokenizer and add the placeholder token as a additional special token.
tokenizer = CLIPTokenizer.from_pretrained(
pretrained_model_name_or_path,
subfolder="tokenizer",
)
# Add the placeholder token in tokenizer
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
if num_added_tokens == 0:
raise ValueError(
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# Get token ids for our placeholder and initializer token.
# This code block will complain if initializer string is not a single token
# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
# Check if initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1:
raise ValueError("The initializer token must be a single token.")
initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
# Load the Stable Diffusion model
# Load models and create wrapper for stable diffusion
# pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path)
# del pipeline
text_encoder = CLIPTextModel.from_pretrained(
pretrained_model_name_or_path, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(
pretrained_model_name_or_path, subfolder="vae"
)
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path, subfolder="unet"
)
# We have added the placeholder_token in the tokenizer so we resize the token embeddings here
# this will a new embedding vector in the token embeddings for our placeholder_token
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
# In Textual-Inversion we only train the newly added embedding vector
# so lets freeze rest of the model parameters here
def freeze_params(params):
for param in params:
param.requires_grad = False
# Freeze vae and unet
freeze_params(vae.parameters())
freeze_params(unet.parameters())
# Freeze all parameters except for the token embeddings in text encoder
params_to_freeze = itertools.chain(
text_encoder.text_model.encoder.parameters(),
text_encoder.text_model.final_layer_norm.parameters(),
text_encoder.text_model.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)
# Move vae and unet to device
# For the dynamo path default compilation device is `cpu`, since torch-mlir
# supports only that. Therefore, convert to device only for PyTorch path.
if not args.use_torchdynamo:
vae.to(args.device)
unet.to(args.device)
# Keep vae in eval mode as we don't train it
vae.eval()
# Keep unet in train mode to enable gradient checkpointing
unet.train()
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = vae
def forward(self, input):
x = self.vae.encode(input, return_dict=False)[0]
return x
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = unet
def forward(self, x, y, z):
return self.unet.forward(x, y, z, return_dict=False)[0]
shark_vae = VaeModel()
shark_unet = UnetModel()
####### Creating our training data ########
# Let's create the Dataset and Dataloader
train_dataset = TextualInversionDataset(
data_root=save_path,
tokenizer=tokenizer,
size=vae.sample_size,
placeholder_token=args.placeholder_token,
repeats=100,
learnable_property=args.what_to_teach, # Option selected above between object and style
center_crop=False,
set="train",
)
def create_dataloader(train_batch_size=1):
return torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, shuffle=True
)
# Create noise_scheduler for training
noise_scheduler = DDPMScheduler.from_config(
pretrained_model_name_or_path, subfolder="scheduler"
)
######## Training ###########
# Define hyperparameters for our training. If you are not happy with your results,
# you can tune the `learning_rate` and the `max_train_steps`
# Setting up all training args
hyperparameters = {
"learning_rate": 5e-04,
"scale_lr": True,
"max_train_steps": args.training_steps,
"save_steps": args.save_steps,
"train_batch_size": args.train_batch_size,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"mixed_precision": "fp16",
"seed": 42,
"output_dir": "sd-concept-output",
}
# creating output directory
cwd = os.getcwd()
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
while not os.path.exists(str(out_dir)):
try:
os.mkdir(out_dir)
except OSError as error:
print("Output directory not created")
###### Torch-MLIR Compilation ######
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
@make_simple_dynamo_backend
def refbackend_torchdynamo_backend(
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
):
# handling usage of empty tensor without initializing
transform_fx(fx_graph)
fx_graph.recompile()
if _returns_nothing(fx_graph):
return fx_graph
removed_none_indexes = _remove_nones(fx_graph)
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
mlir_module = torch_mlir.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors"
)
bytecode_stream = BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
def compiled_callable(*inputs):
inputs = [x.numpy() for x in inputs]
result = shark_module("forward", inputs)
if was_unwrapped:
result = [
result,
]
if not isinstance(result, list):
result = torch.from_numpy(result)
else:
result = tuple(torch.from_numpy(x) for x in result)
result = list(result)
for removed_index in removed_none_indexes:
result.insert(removed_index, None)
result = tuple(result)
return result
return compiled_callable
def predictions(torch_func, jit_func, batchA, batchB):
res = jit_func(batchA.numpy(), batchB.numpy())
if res is not None:
prediction = res
else:
prediction = None
return prediction
logger = logging.getLogger(__name__)
# def save_progress(text_encoder, placeholder_token_id, accelerator, save_path):
def save_progress(text_encoder, placeholder_token_id, save_path):
logger.info("Saving embeddings")
learned_embeds = (
# accelerator.unwrap_model(text_encoder)
text_encoder.get_input_embeddings().weight[placeholder_token_id]
)
learned_embeds_dict = {
args.placeholder_token: learned_embeds.detach().cpu()
}
torch.save(learned_embeds_dict, save_path)
train_batch_size = hyperparameters["train_batch_size"]
gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
learning_rate = hyperparameters["learning_rate"]
if hyperparameters["scale_lr"]:
learning_rate = (
learning_rate
* gradient_accumulation_steps
* train_batch_size
# * accelerator.num_processes
)
# Initialize the optimizer
optimizer = torch.optim.AdamW(
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
lr=learning_rate,
)
# Training function
def train_func(batch_pixel_values, batch_input_ids):
# Convert images to latent space
latents = shark_vae(batch_pixel_values).sample().detach()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.num_train_timesteps,
(bsz,),
device=latents.device,
).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch_input_ids)[0]
# Predict the noise residual
noise_pred = shark_unet(
noisy_latents,
timesteps,
encoder_hidden_states,
)
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
)
loss = (
F.mse_loss(noise_pred, target, reduction="none").mean([1, 2, 3]).mean()
)
loss.backward()
# Zero out the gradients for all token embeddings except the newly added
# embeddings for the concept, as we only want to optimize the concept embeddings
grads = text_encoder.get_input_embeddings().weight.grad
# Get the index for tokens that we want to zero the grads for
index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
grads.data[index_grads_to_zero, :] = grads.data[
index_grads_to_zero, :
].fill_(0)
optimizer.step()
optimizer.zero_grad()
return loss
def training_function():
max_train_steps = hyperparameters["max_train_steps"]
output_dir = hyperparameters["output_dir"]
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
train_dataloader = create_dataloader(train_batch_size)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
# Train!
total_batch_size = (
train_batch_size
* gradient_accumulation_steps
# train_batch_size * accelerator.num_processes * gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(
# range(max_train_steps), disable=not accelerator.is_local_main_process
range(max_train_steps)
)
progress_bar.set_description("Steps")
global_step = 0
params_ = [i for i in text_encoder.get_input_embeddings().parameters()]
if args.use_torchdynamo:
print("******** TRAINING STARTED - TORCHYDNAMO PATH ********")
else:
print("******** TRAINING STARTED - PYTORCH PATH ********")
print("Initial weights:")
print(params_, params_[0].shape)
for epoch in range(num_train_epochs):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
if args.use_torchdynamo:
dynamo_callable = dynamo.optimize(
refbackend_torchdynamo_backend
)(train_func)
lam_func = lambda x, y: dynamo_callable(
torch.from_numpy(x), torch.from_numpy(y)
)
loss = predictions(
train_func,
lam_func,
batch["pixel_values"],
batch["input_ids"],
# params[0].detach(),
)
else:
loss = train_func(batch["pixel_values"], batch["input_ids"])
print(loss)
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1)
global_step += 1
if global_step % hyperparameters["save_steps"] == 0:
save_path = os.path.join(
output_dir,
f"learned_embeds-step-{global_step}.bin",
)
save_progress(
text_encoder,
placeholder_token_id,
save_path,
)
logs = {"loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
if global_step >= max_train_steps:
break
# Create the pipeline using using the trained modules and save it.
params__ = [i for i in text_encoder.get_input_embeddings().parameters()]
print("******** TRAINING PROCESS FINISHED ********")
print("Updated weights:")
print(params__, params__[0].shape)
pipeline = StableDiffusionPipeline.from_pretrained(
pretrained_model_name_or_path,
# text_encoder=accelerator.unwrap_model(text_encoder),
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
unet=unet,
)
pipeline.save_pretrained(output_dir)
# Also save the newly trained embeddings
save_path = os.path.join(output_dir, f"learned_embeds.bin")
save_progress(text_encoder, placeholder_token_id, save_path)
training_function()
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
if param.grad is not None:
del param.grad # free some memory
torch.cuda.empty_cache()
# Set up the pipeline
from diffusers import DPMSolverMultistepScheduler
pipe = StableDiffusionPipeline.from_pretrained(
hyperparameters["output_dir"],
scheduler=DPMSolverMultistepScheduler.from_pretrained(
hyperparameters["output_dir"], subfolder="scheduler"
),
)
if not args.use_torchdynamo:
pipe.to(args.device)
# Run the Stable Diffusion pipeline
# Don't forget to use the placeholder token in your prompt
all_images = []
for _ in range(args.num_inference_samples):
images = pipe(
[args.prompt],
num_inference_steps=args.inference_steps,
guidance_scale=7.5,
).images
all_images.extend(images)
output_path = os.path.abspath(os.path.join(os.getcwd(), args.output_dir))
if not os.path.isdir(args.output_dir):
os.mkdir(args.output_dir)
[
image.save(f"{args.output_dir}/{i}.jpeg")
for i, image in enumerate(all_images)
]

View File

@@ -19,14 +19,10 @@ import sys
import subprocess
def run_cmd(cmd, debug=False):
def run_cmd(cmd):
"""
Inputs: cli command string.
"""
if debug:
print("IREE run command: \n\n")
print(cmd)
print("\n\n")
try:
result = subprocess.run(
cmd,
@@ -35,9 +31,8 @@ def run_cmd(cmd, debug=False):
stderr=subprocess.PIPE,
check=True,
)
stdout = result.stdout.decode()
stderr = result.stderr.decode()
return stdout, stderr
result_str = result.stdout.decode()
return result_str
except subprocess.CalledProcessError as e:
print(e.output)
sys.exit(f"Exiting program due to error running {cmd}")
@@ -45,15 +40,10 @@ def run_cmd(cmd, debug=False):
def iree_device_map(device):
uri_parts = device.split("://", 2)
iree_driver = (
_IREE_DEVICE_MAP[uri_parts[0]]
if uri_parts[0] in _IREE_DEVICE_MAP
else uri_parts[0]
)
if len(uri_parts) == 1:
return iree_driver
return _IREE_DEVICE_MAP[uri_parts[0]]
else:
return f"{iree_driver}://{uri_parts[1]}"
return f"{_IREE_DEVICE_MAP[uri_parts[0]]}://{uri_parts[1]}"
def get_supported_device_list():
@@ -73,7 +63,7 @@ _IREE_DEVICE_MAP = {
def iree_target_map(device):
if "://" in device:
device = device.split("://")[0]
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
return _IREE_TARGET_MAP[device]
_IREE_TARGET_MAP = {
@@ -115,8 +105,10 @@ def check_device_drivers(device):
subprocess.check_output("rocminfo")
except Exception:
return True
# Unknown device.
else:
return True
# Unknown device. We assume drivers are installed.
return False

View File

@@ -90,7 +90,6 @@ def build_benchmark_args(
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
# if time_extractor:
# benchmark_cl.append(time_extractor)
benchmark_cl.append(f"--print_statistics=true")
return benchmark_cl
@@ -130,8 +129,7 @@ def build_benchmark_args_non_tensor_input(
def run_benchmark_module(benchmark_cl):
"""
Run benchmark command, extract result and return iteration/seconds, host
peak memory, and device peak memory.
Run benchmark command, extract result and return iteration/seconds.
# TODO: Add an example of the benchmark command.
Input: benchmark command.
@@ -140,22 +138,10 @@ def run_benchmark_module(benchmark_cl):
assert os.path.exists(
benchmark_path
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
try:
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
match = regex_split.search(bench_stdout)
time_ms = float(match.group(1))
unit = match.group(3)
except AttributeError:
regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
match = regex_split.search(bench_stdout)
time_ms = float(match.group(1))
unit = match.group(2)
iter_per_second = 1.0 / (time_ms * 0.001)
# Extract peak memory.
host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
host_peak_b = int(host_regex.search(bench_stderr).group(1))
device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")
device_peak_b = int(device_regex.search(bench_stderr).group(1))
return iter_per_second, host_peak_b, device_peak_b
bench_result = run_cmd(" ".join(benchmark_cl))
print(bench_result)
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
unit = match.group(3)
return 1.0 / (time * 0.001)

View File

@@ -23,7 +23,6 @@ import re
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
@@ -31,9 +30,6 @@ def get_iree_device_args(device, extra_args=[]):
f"Specific device selection only supported for vulkan now."
f"Proceeding with {device} as device."
)
device_num = device_uri[1]
else:
device_num = 0
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
@@ -46,9 +42,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] in ["metal", "vulkan"]:
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
return get_iree_vulkan_args(extra_args=extra_args)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
@@ -58,11 +52,11 @@ def get_iree_device_args(device, extra_args=[]):
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
return ["--iree-llvmcpu-target-cpu-features=host"]
if frontend in ["torch", "pytorch", "linalg"]:
return ["--iree-llvm-target-cpu-features=host"]
elif frontend in ["tensorflow", "tf", "mhlo"]:
return [
"--iree-llvmcpu-target-cpu-features=host",
"--iree-llvm-target-cpu-features=host",
"--iree-mhlo-demote-i64-to-i32=false",
"--iree-flow-demote-i64-to-i32",
]
@@ -195,23 +189,21 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
benchmark_bash.write(" ".join(benchmark_cl))
benchmark_bash.close()
iter_per_second, _, _ = run_benchmark_module(
benchmark_cl
)
benchmark_data = run_benchmark_module(benchmark_cl)
benchmark_file = open(
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
)
benchmark_file.write(f"DISPATCH: {d_}\n")
benchmark_file.write(str(iter_per_second) + "\n")
benchmark_file.write(str(benchmark_data) + "\n")
benchmark_file.write(
"SHARK BENCHMARK RESULT: "
+ str(1 / (iter_per_second * 0.001))
+ str(1 / (benchmark_data * 0.001))
+ "\n"
)
benchmark_file.close()
benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001)
benchmark_runtimes[d_] = 1 / (benchmark_data * 0.001)
elif ".mlir" in f_ and "benchmark" not in f_:
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
@@ -302,8 +294,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
haldriver.query_available_devices()[device_idx]["device_id"]
)
config = ireert.Config(device=haldevice)
else:
@@ -313,7 +304,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
ModuleCompiled = getattr(ctx.modules, vm_module.name)
ModuleCompiled = ctx.modules.module
return ModuleCompiled, config
@@ -412,10 +403,5 @@ def get_results(
def get_iree_runtime_config(device):
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
config = ireert.Config(device=ireert.get_device(device))
return config

View File

@@ -44,4 +44,4 @@ def get_iree_cpu_args():
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
return [f"--iree-llvmcpu-target-triple={target_triple}"]
return [f"-iree-llvm-target-triple={target_triple}"]

View File

@@ -22,7 +22,7 @@ from shark.parser import shark_args
# Get the default gpu args given the architecture.
def get_iree_gpu_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
ireert.flags.parse_flags("--cuda_allow_inline_execution")
ireert.flags.parse_flags("--cuda_allow_inline_execution", "--device_allocator=caching")
# TODO: Give the user_interface to pass the sm_arch.
sm_arch = get_cuda_sm_cc()
if (
@@ -30,10 +30,11 @@ def get_iree_gpu_args():
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
) and (shark_args.enable_tf32 == True):
return [
"--iree-hal-cuda-disable-loop-nounroll-wa",
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
]
else:
return []
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
# Get the default gpu args given the architecture.

View File

@@ -131,9 +131,7 @@ def get_vendor(triple):
return "ARM"
if arch == "m1":
return "Apple"
if arch in ["arc", "UHD"]:
return "Intel"
if arch in ["turing", "ampere", "pascal"]:
if arch in ["turing", "ampere"]:
return "NVIDIA"
if arch == "ardeno":
return "Qualcomm"
@@ -151,7 +149,7 @@ def get_device_type(triple):
return "Unknown"
if arch == "cpu":
return "CPU"
if arch in ["turing", "ampere", "arc", "pascal"]:
if arch in ["turing", "ampere"]:
return "DiscreteGPU"
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]:
if product == "ivega10":
@@ -345,37 +343,6 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "arc":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
cap["subgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = False
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = False
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "cpu":
if product == "swiftshader":
cap["maxComputeSharedMemorySize"] = 16384
@@ -389,39 +356,6 @@ def get_vulkan_target_capabilities(triple):
"ShuffleRelative",
]
elif arch in ["pascal"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1536
cap["maxComputeWorkGroupSize"] = [1536, 1024, 64]
cap["subgroupSize"] = 32
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch in ["ampere", "turing"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1024

View File

@@ -21,9 +21,8 @@ from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_vulkan_device_name(device_num=0):
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
def get_vulkan_device_name():
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
@@ -31,8 +30,8 @@ def get_vulkan_device_name(device_num=0):
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: {vulkaninfo_list[device_num]}")
return vulkaninfo_list[device_num]
print(f"Choosing first one: {vulkaninfo_list[0]}")
return vulkaninfo_list[0]
def get_os_name():
@@ -107,26 +106,21 @@ def get_vulkan_target_triple(device_name):
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
triple = f"rdna3-w7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna2-unknown-{system_os}"
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"
else:
triple = None
return triple
def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
def get_vulkan_triple_flag(device_name="", extra_args=[]):
for flag in extra_args:
if "-iree-vulkan-target-triple=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
if device_name == "" or device_name == [] or device_name is None:
vulkan_device = get_vulkan_device_name(device_num=device_num)
vulkan_device = get_vulkan_device_name()
else:
vulkan_device = device_name
triple = get_vulkan_target_triple(vulkan_device)
@@ -144,10 +138,9 @@ def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
return None
def get_iree_vulkan_args(device_num=0, extra_args=[]):
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
def get_iree_vulkan_args(extra_args=[]):
res_vulkan_flag = ["--device_allocator=caching"]
res_vulkan_flag = []
vulkan_triple_flag = None
for arg in extra_args:
if "-iree-vulkan-target-triple=" in arg:
@@ -156,9 +149,7 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
break
if vulkan_triple_flag is None:
vulkan_triple_flag = get_vulkan_triple_flag(
device_num=device_num, extra_args=extra_args
)
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
if vulkan_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)

Some files were not shown because too many files have changed in this diff Show More