Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-11 14:58:11 -05:00)

Compare commits: enable_cac ... 20230224.5 (94 commits)
| SHA1 |
|---|
| bae208bcc4 |
| b6c14ad468 |
| 0064cc2a6e |
| 0a0567e944 |
| 694b1d43a8 |
| e7eb116bd2 |
| 596499a08c |
| 2a2e460df2 |
| a9039b35ed |
| a01154a507 |
| 1d9204282d |
| 5ff40a0d2d |
| fab6d2e4e0 |
| abab59c25f |
| c25840b585 |
| 1b3f9125bb |
| b5d9f5ba49 |
| 1c22aa9c8f |
| e1d7fb879c |
| e912c42bf0 |
| e6841acf36 |
| bc4459b6f4 |
| 9b544491e0 |
| 9c5415b598 |
| 040dbc317f |
| 65775046d8 |
| b18bc36127 |
| f01c526efd |
| 16168ab6b3 |
| 4233218629 |
| b63fb36dc0 |
| 4e92304b89 |
| 2ae047f1a8 |
| 6d2a485264 |
| 4f045db024 |
| 5b33597b6d |
| 962470f610 |
| ba8c116380 |
| ad7330eae4 |
| cf126e4839 |
| c96d25c3e2 |
| 006aa0dae2 |
| 5b204bee86 |
| d98b2afbe9 |
| 681332ef32 |
| c3a4fdcbfc |
| aac5de5b02 |
| 13a255afad |
| 3bffda52f9 |
| d4e62ce557 |
| 9738483b18 |
| 143492fe94 |
| ecc5c662c4 |
| d973ba191d |
| 0198b183a2 |
| 0d44a3527b |
| 2147b6a397 |
| 6b5b4ba27b |
| 67005bf57c |
| 0430c741c6 |
| 1ce02e365d |
| eae862adc2 |
| dffa89524a |
| 2af1102441 |
| c4b472842a |
| 750a7d806f |
| bc7333f1e5 |
| 55ae50f991 |
| a590c331ef |
| 8c241b06cb |
| 9c072c8068 |
| ebd8b5122a |
| 055e484a40 |
| 912c4a1d12 |
| c203b65bf1 |
| 307f0334ee |
| 5167df08b9 |
| dd2e482214 |
| 87fd13d8eb |
| dd423bc6de |
| 899cb9cc1f |
| 0464c7e558 |
| f64e1fb926 |
| ef7d31293d |
| 6d54eb68dc |
| 30eb10c990 |
| 591bbcd058 |
| 99aa77d036 |
| 9c13f1e635 |
| 24af983cfb |
| 67842a7525 |
| 3159a6f3e1 |
| b2f3c96835 |
| 6582475955 |
.github/workflows/nightly.yml (vendored, 21 lines changed)

@@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.11"]

steps:
- uses: actions/checkout@v2
@@ -44,18 +44,20 @@ jobs:
body: |
Automatic snapshot release of nod.ai SHARK.
draft: true
prerelease: false
prerelease: true

- name: Build Package
shell: powershell
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
python process_skipfiles.py
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe

# GHA windows VM OOMs so disable for now
@@ -65,9 +67,9 @@ jobs:
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html

- uses: actions/upload-artifact@v2
with:
path: dist/*
#- uses: actions/upload-artifact@v2
# with:
# path: dist/*

- name: Upload Release Assets
id: upload-release-assets
@@ -77,6 +79,7 @@ jobs:
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./dist/*
#asset_content_type: application/vnd.microsoft.portable-executable

- name: Publish Release
id: publish_release
@@ -92,7 +95,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.11"]
backend: [IREE, SHARK]

steps:
@@ -131,7 +134,7 @@ jobs:
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://iree-org.github.io/iree/pip-release-links.html
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
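For orientation, the `package_version` computed in the last hunk is a date stamp plus the workflow run number; the tag this compare targets, 20230224.5, has exactly that shape. A rough Python equivalent is sketched below; the run number is only a placeholder for `${{ github.run_number }}`.

```python
# Illustrative sketch of the version string built by the nightly workflow above.
import time

run_number = 5  # stands in for ${{ github.run_number }}
package_version = f"{time.strftime('%Y%m%d')}.{run_number}"  # e.g. "20230224.5"
```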
.github/workflows/test-models.yml (vendored, 3 lines changed)

@@ -31,7 +31,7 @@ jobs:
matrix:
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
suite: [cpu,cuda,vulkan]
python-version: ["3.10"]
python-version: ["3.11"]
include:
- os: ubuntu-latest
suite: lint
@@ -158,5 +158,4 @@ jobs:
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
./shark.venv/Scripts/activate
python build_tools/stable_diffusion_testing.py --device=vulkan
README.md (57 lines changed)

@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
<summary>Prerequisites - Drivers </summary>

#### Install your Windows hardware drivers
* [AMD RDNA Users] Download this specific driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mril-iree). Latest drivers may not work.
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)

@@ -25,18 +25,32 @@ Other users please ensure you have your latest vendor drivers and Vulkan SDK fro

### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users

Install Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above

Download the latest .exe https://github.com/nod-ai/SHARK/releases.
Download the [stable release](https://github.com/nod-ai/shark/releases/latest)

Double click the .exe and you should have the [UI]( http://localhost:8080/?__theme=dark) in the browser.
Double click the .exe and you should have the [UI](http://localhost:8080/) in the browser.

If you have custom models (ckpt, safetensors) put in a `models/` directory where the .exe is.
If you have custom models put them in a `models/` directory where the .exe is.

Enjoy.

Some known AMD Driver quirks and fixes with cursors are documented [here](https://github.com/nod-ai/SHARK/blob/main/apps/stable_diffusion/stable_diffusion_amd.md ).
<details>
<summary>More installation notes</summary>
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`

## Running

* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE)
* The first run may take few minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.

## Stopping

* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
</details>

<details>
<summary>Advanced Installation (Only for developers)</summary>
@@ -54,7 +68,7 @@ cd SHARK

### Windows 10/11 Users

* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)

* Install Git for Windows from [here](https://git-scm.com/download/win)

@@ -105,16 +119,15 @@ source shark.venv/bin/activate

#### Linux / macOS Users
```shell
python3.10 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
python3.11 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```

You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
</details>

The output on a 7900XTX would like:
The output on a AMD 7900XTX would look something like:

```shell
Stats for run 0:
```shell
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
@@ -140,7 +153,7 @@ Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any
This step sets up a new VirtualEnv for Python

```shell
python --version #Check you have 3.10 on Linux, macOS or Windows Powershell
python --version #Check you have 3.11 on Linux, macOS or Windows Powershell
python -m venv shark_venv
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows

@@ -154,7 +167,7 @@ python -m pip install --upgrade pip

### Install SHARK

This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11

```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
@@ -189,10 +202,10 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
<details>
<summary>Development, Testing and Benchmarks</summary>

If you want to use Python3.10 and with TF Import tools you can use the environment variables like:
If you want to use Python3.11 and with TF Import tools you can use the environment variables like:
Set `USE_IREE=1` to use upstream IREE
```
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
```

### Run any of the hundreds of SHARK tank models via the test framework
@@ -202,14 +215,14 @@ python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use g
pytest tank/test_models.py -k "MiniLM"
```


### How to use your locally built IREE / Torch-MLIR with SHARK
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
for Torch-MLIR.

### How to use your locally built Torch-MLIR with SHARK
How to use your locally built Torch-MLIR with SHARK:
```shell
1.) Run `./setup_venv.sh in SHARK` and activate `shark.venv` virtual env.
2.) Run `pip uninstall torch-mlir`.
@@ -227,9 +240,15 @@ Now the SHARK will use your locally build Torch-MLIR repo.

## Benchmarking Dispatches

To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your command line argument.
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`

For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
```
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
```
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.

if you want to instead incorporate this into a python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `SharkInference`, and the benchmarks will be generated when compiled. E.G:

```
@@ -253,7 +272,7 @@ Output will include:
- A .txt file containing benchmark output


See tank/README.md for instructions on how to run model tests and benchmarks from the SHARK tank.
See tank/README.md for further instructions on how to run model tests and benchmarks from the SHARK tank.

</details>
@@ -1 +1,4 @@
from apps.stable_diffusion.scripts.txt2img import txt2img_inf
from apps.stable_diffusion.scripts.img2img import img2img_inf
from apps.stable_diffusion.scripts.inpaint import inpaint_inf
from apps.stable_diffusion.scripts.outpaint import outpaint_inf
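The three added imports above (presumably the `scripts` package `__init__`, since the hunk has no file header) let a front end pull every generation entry point from one place; an illustrative import:

```python
# Illustrative only; assumes the hunk above edits apps/stable_diffusion/scripts/__init__.py.
from apps.stable_diffusion.scripts import (
    txt2img_inf,
    img2img_inf,
    inpaint_inf,
    outpaint_inf,
)
```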
@@ -0,0 +1,264 @@
import sys
import torch
import time
from PIL import Image
from dataclasses import dataclass
from apps.stable_diffusion.src import (
    args,
    Image2ImagePipeline,
    get_schedulers,
    set_init_device_flags,
    utils,
    clear_all,
    save_output_img,
)


@dataclass
class Config:
    model_id: str
    ckpt_loc: str
    precision: str
    batch_size: int
    max_length: int
    height: int
    width: int
    device: str


img2img_obj = None
config_obj = None
schedulers = None


# Exposed to UI.
def img2img_inf(
    prompt: str,
    negative_prompt: str,
    init_image: Image,
    height: int,
    width: int,
    steps: int,
    strength: float,
    guidance_scale: float,
    seed: int,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    custom_model: str,
    hf_model_id: str,
    precision: str,
    device: str,
    max_length: int,
    save_metadata_to_json: bool,
    save_metadata_to_png: bool,
):
    global img2img_obj
    global config_obj
    global schedulers

    args.prompts = [prompt]
    args.negative_prompts = [negative_prompt]
    args.guidance_scale = guidance_scale
    args.seed = seed
    args.steps = steps
    args.strength = strength
    args.scheduler = scheduler
    args.img_path = "not none"

    if init_image is None:
        return None, "An Initial Image is required"
    image = init_image.convert("RGB")

    # set ckpt_loc and hf_model_id.
    types = (
        ".ckpt",
        ".safetensors",
    )  # the tuple of file types
    args.ckpt_loc = ""
    args.hf_model_id = ""
    if custom_model == "None":
        if not hf_model_id:
            return (
                None,
                "Please provide either custom model or huggingface model ID, both must not be empty",
            )
        args.hf_model_id = hf_model_id
    elif ".ckpt" in custom_model or ".safetensors" in custom_model:
        args.ckpt_loc = custom_model
    else:
        args.hf_model_id = custom_model

    args.save_metadata_to_json = save_metadata_to_json
    args.write_metadata_to_png = save_metadata_to_png

    dtype = torch.float32 if precision == "fp32" else torch.half
    cpu_scheduling = not scheduler.startswith("Shark")
    new_config_obj = Config(
        args.hf_model_id,
        args.ckpt_loc,
        precision,
        batch_size,
        max_length,
        height,
        width,
        device,
    )
    if not img2img_obj or config_obj != new_config_obj:
        config_obj = new_config_obj
        args.precision = precision
        args.batch_size = batch_size
        args.max_length = max_length
        args.height = height
        args.width = width
        args.device = device.split("=>", 1)[1].strip()
        args.iree_vulkan_target_triple = ""
        args.use_tuned = True
        args.import_mlir = True
        set_init_device_flags()
        model_id = (
            args.hf_model_id
            if args.hf_model_id
            else "runwayml/stable-diffusion-inpainting"
        )
        schedulers = get_schedulers(model_id)
        scheduler_obj = schedulers[scheduler]
        img2img_obj = Image2ImagePipeline.from_pretrained(
            scheduler_obj,
            args.import_mlir,
            args.hf_model_id,
            args.ckpt_loc,
            args.custom_vae,
            args.precision,
            args.max_length,
            args.batch_size,
            args.height,
            args.width,
            args.use_base_vae,
            args.use_tuned,
            low_cpu_mem_usage=args.low_cpu_mem_usage,
        )

    img2img_obj.scheduler = schedulers[scheduler]

    start_time = time.time()
    img2img_obj.log = ""
    generated_imgs = []
    seeds = []
    img_seed = utils.sanitize_seed(seed)
    for current_batch in range(batch_count):
        if current_batch > 0:
            img_seed = utils.sanitize_seed(-1)
        out_imgs = img2img_obj.generate_images(
            prompt,
            negative_prompt,
            image,
            batch_size,
            height,
            width,
            steps,
            strength,
            guidance_scale,
            img_seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
        )
        save_output_img(out_imgs[0], img_seed)
        generated_imgs.extend(out_imgs)
        seeds.append(img_seed)
        img2img_obj.log += "\n"

    total_time = time.time() - start_time
    text_output = f"prompt={args.prompts}"
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
    text_output += f"\nscheduler={args.scheduler}, device={device}"
    text_output += f"\nsteps={steps}, strength={args.strength}, guidance_scale={guidance_scale}, seed={seeds}"
    text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
    text_output += img2img_obj.log
    text_output += f"\nTotal image generation time: {total_time:.4f}sec"

    return generated_imgs, text_output


if __name__ == "__main__":
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        exit()

    # When the models get uploaded, it should be default to False.
    args.import_mlir = True

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)
    if args.scheduler != "PNDM":
        if "Shark" in args.scheduler:
            print(
                f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
            )
            args.scheduler = "PNDM"
        else:
            sys.exit(
                "Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
            )

    scheduler_obj = schedulers[args.scheduler]
    image = Image.open(args.img_path).convert("RGB")
    seed = utils.sanitize_seed(args.seed)

    # Adjust for height and width based on model

    img2img_obj = Image2ImagePipeline.from_pretrained(
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
    )

    start_time = time.time()
    generated_imgs = img2img_obj.generate_images(
        args.prompts,
        args.negative_prompts,
        image,
        args.batch_size,
        args.height,
        args.width,
        args.steps,
        args.strength,
        args.guidance_scale,
        seed,
        args.max_length,
        dtype,
        args.use_base_vae,
        cpu_scheduling,
    )
    total_time = time.time() - start_time
    text_output = f"prompt={args.prompts}"
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
    text_output += f"\nscheduler={args.scheduler}, device={args.device}"
    text_output += f"\nsteps={args.steps}, strength={args.strength}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
    text_output += (
        f", batch size={args.batch_size}, max_length={args.max_length}"
    )
    text_output += img2img_obj.log
    text_output += f"\nTotal image generation time: {total_time:.4f}sec"

    save_output_img(generated_imgs[0], seed)
    print(text_output)
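`img2img_inf` above is the function the web UI calls. A minimal sketch of invoking it directly follows; every value is an illustrative placeholder, and only the parameter list, the return value (images plus a log string), and the fact that `device` is split on `=>` come from the code above.

```python
# Illustrative call; values are placeholders, the parameter list mirrors img2img_inf above.
from PIL import Image
from apps.stable_diffusion.scripts.img2img import img2img_inf

init = Image.open("input.png")  # any image PIL can open; it is converted to RGB inside img2img_inf
images, log = img2img_inf(
    prompt="a watercolor landscape",
    negative_prompt="",
    init_image=init,
    height=512,
    width=512,
    steps=50,
    strength=0.8,
    guidance_scale=7.5,
    seed=-1,
    batch_count=1,
    batch_size=1,
    scheduler="PNDM",
    custom_model="None",
    hf_model_id="stabilityai/stable-diffusion-2-1-base",
    precision="fp16",
    device="AMD Radeon RX 7900 XTX => vulkan",  # the text after "=>" becomes args.device
    max_length=64,
    save_metadata_to_json=False,
    save_metadata_to_png=False,
)
print(log)
```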
apps/stable_diffusion/scripts/inpaint.py (new file, 256 lines)

@@ -0,0 +1,256 @@
import sys
import torch
import time
from PIL import Image
from dataclasses import dataclass
from apps.stable_diffusion.src import (
    args,
    InpaintPipeline,
    get_schedulers,
    set_init_device_flags,
    utils,
    clear_all,
    save_output_img,
)


@dataclass
class Config:
    model_id: str
    ckpt_loc: str
    precision: str
    batch_size: int
    max_length: int
    height: int
    width: int
    device: str


inpaint_obj = None
config_obj = None
schedulers = None


# Exposed to UI.
def inpaint_inf(
    prompt: str,
    negative_prompt: str,
    image_dict,
    height: int,
    width: int,
    steps: int,
    guidance_scale: float,
    seed: int,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    custom_model: str,
    hf_model_id: str,
    precision: str,
    device: str,
    max_length: int,
    save_metadata_to_json: bool,
    save_metadata_to_png: bool,
):
    global inpaint_obj
    global config_obj
    global schedulers

    args.prompts = [prompt]
    args.negative_prompts = [negative_prompt]
    args.guidance_scale = guidance_scale
    args.steps = steps
    args.scheduler = scheduler
    args.img_path = "not none"
    args.mask_path = "not none"

    # set ckpt_loc and hf_model_id.
    types = (
        ".ckpt",
        ".safetensors",
    )  # the tuple of file types
    args.ckpt_loc = ""
    args.hf_model_id = ""
    if custom_model == "None":
        if not hf_model_id:
            return (
                None,
                "Please provide either custom model or huggingface model ID, both must not be empty",
            )
        args.hf_model_id = hf_model_id
    elif ".ckpt" in custom_model or ".safetensors" in custom_model:
        args.ckpt_loc = custom_model
    else:
        args.hf_model_id = custom_model

    args.save_metadata_to_json = save_metadata_to_json
    args.write_metadata_to_png = save_metadata_to_png

    dtype = torch.float32 if precision == "fp32" else torch.half
    cpu_scheduling = not scheduler.startswith("Shark")
    new_config_obj = Config(
        args.hf_model_id,
        args.ckpt_loc,
        precision,
        batch_size,
        max_length,
        height,
        width,
        device,
    )
    if not inpaint_obj or config_obj != new_config_obj:
        config_obj = new_config_obj
        args.precision = precision
        args.batch_size = batch_size
        args.max_length = max_length
        args.height = height
        args.width = width
        args.device = device.split("=>", 1)[1].strip()
        args.iree_vulkan_target_triple = ""
        args.use_tuned = True
        args.import_mlir = False
        set_init_device_flags()
        model_id = (
            args.hf_model_id
            if args.hf_model_id
            else "stabilityai/stable-diffusion-2-inpainting"
        )
        schedulers = get_schedulers(model_id)
        scheduler_obj = schedulers[scheduler]
        inpaint_obj = InpaintPipeline.from_pretrained(
            scheduler_obj,
            args.import_mlir,
            args.hf_model_id,
            args.ckpt_loc,
            args.custom_vae,
            args.precision,
            args.max_length,
            args.batch_size,
            args.height,
            args.width,
            args.use_base_vae,
            args.use_tuned,
        )

    inpaint_obj.scheduler = schedulers[scheduler]

    start_time = time.time()
    inpaint_obj.log = ""
    generated_imgs = []
    seeds = []
    img_seed = utils.sanitize_seed(seed)
    image = image_dict["image"]
    mask_image = image_dict["mask"]
    for i in range(batch_count):
        if i > 0:
            img_seed = utils.sanitize_seed(-1)
        out_imgs = inpaint_obj.generate_images(
            prompt,
            negative_prompt,
            image,
            mask_image,
            batch_size,
            height,
            width,
            steps,
            guidance_scale,
            img_seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
        )
        save_output_img(out_imgs[0], img_seed)
        generated_imgs.extend(out_imgs)
        seeds.append(img_seed)
        inpaint_obj.log += "\n"

    total_time = time.time() - start_time
    text_output = f"prompt={args.prompts}"
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
    text_output += f"\nscheduler={args.scheduler}, device={device}"
    text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
    text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
    text_output += inpaint_obj.log
    text_output += f"\nTotal image generation time: {total_time:.4f}sec"

    return generated_imgs, text_output


if __name__ == "__main__":
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        exit()
    if args.mask_path is None:
        print("Flag --mask_path is required.")
        exit()
    if "inpaint" not in args.hf_model_id:
        print("Please use inpainting model with --hf_model_id.")
        exit()

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    image = Image.open(args.img_path)
    mask_image = Image.open(args.mask_path)

    inpaint_obj = InpaintPipeline.from_pretrained(
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
    )

    for current_batch in range(args.batch_count):
        if current_batch > 0:
            seed = -1
        seed = utils.sanitize_seed(seed)

        start_time = time.time()
        generated_imgs = inpaint_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            image,
            mask_image,
            args.batch_size,
            args.height,
            args.width,
            args.steps,
            args.guidance_scale,
            seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
        )
        total_time = time.time() - start_time
        text_output = f"prompt={args.prompts}"
        text_output += f"\nnegative prompt={args.negative_prompts}"
        text_output += (
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
        )
        text_output += f"\nscheduler={args.scheduler}, device={args.device}"
        text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        text_output += inpaint_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"

        save_output_img(generated_imgs[0], seed)
        print(text_output)
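`inpaint_inf` above expects `image_dict` to carry the source image and its mask under the `"image"` and `"mask"` keys. A minimal sketch of assembling that input follows; file names and settings are illustrative, and the exact mask convention is whatever the pipeline expects.

```python
# Illustrative call; the "image"/"mask" keys match the dictionary access in inpaint_inf above.
from PIL import Image
from apps.stable_diffusion.scripts.inpaint import inpaint_inf

image_dict = {
    "image": Image.open("photo.png"),  # picture to edit
    "mask": Image.open("mask.png"),    # mask marking the region to repaint
}
images, log = inpaint_inf(
    prompt="a red brick wall",
    negative_prompt="",
    image_dict=image_dict,
    height=512, width=512, steps=50, guidance_scale=7.5, seed=-1,
    batch_count=1, batch_size=1, scheduler="PNDM",
    custom_model="None",
    hf_model_id="stabilityai/stable-diffusion-2-inpainting",  # the code's own fallback model
    precision="fp16",
    device="AMD Radeon RX 7900 XTX => vulkan",
    max_length=64,
    save_metadata_to_json=False,
    save_metadata_to_png=False,
)
```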
apps/stable_diffusion/scripts/outpaint.py (new file, 275 lines)

@@ -0,0 +1,275 @@
import sys
import torch
import time
from PIL import Image
from dataclasses import dataclass
from apps.stable_diffusion.src import (
    args,
    OutpaintPipeline,
    get_schedulers,
    set_init_device_flags,
    utils,
    clear_all,
    save_output_img,
)


@dataclass
class Config:
    model_id: str
    ckpt_loc: str
    precision: str
    batch_size: int
    max_length: int
    height: int
    width: int
    device: str


outpaint_obj = None
config_obj = None
schedulers = None


# Exposed to UI.
def outpaint_inf(
    prompt: str,
    negative_prompt: str,
    init_image: str,
    pixels: int,
    mask_blur: int,
    directions: list,
    noise_q: float,
    color_variation: float,
    height: int,
    width: int,
    steps: int,
    guidance_scale: float,
    seed: int,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    custom_model: str,
    hf_model_id: str,
    precision: str,
    device: str,
    max_length: int,
    save_metadata_to_json: bool,
    save_metadata_to_png: bool,
):
    global outpaint_obj
    global config_obj
    global schedulers

    args.prompts = [prompt]
    args.negative_prompts = [negative_prompt]
    args.guidance_scale = guidance_scale
    args.steps = steps
    args.scheduler = scheduler
    args.img_path = init_image

    # set ckpt_loc and hf_model_id.
    types = (
        ".ckpt",
        ".safetensors",
    )  # the tuple of file types
    args.ckpt_loc = ""
    args.hf_model_id = ""
    if custom_model == "None":
        if not hf_model_id:
            return (
                None,
                "Please provide either custom model or huggingface model ID, both must not be empty",
            )
        args.hf_model_id = hf_model_id
    elif ".ckpt" in custom_model or ".safetensors" in custom_model:
        args.ckpt_loc = custom_model
    else:
        args.hf_model_id = custom_model

    args.save_metadata_to_json = save_metadata_to_json
    args.write_metadata_to_png = save_metadata_to_png

    dtype = torch.float32 if precision == "fp32" else torch.half
    cpu_scheduling = not scheduler.startswith("Shark")
    new_config_obj = Config(
        args.hf_model_id,
        args.ckpt_loc,
        precision,
        batch_size,
        max_length,
        height,
        width,
        device,
    )
    if not outpaint_obj or config_obj != new_config_obj:
        config_obj = new_config_obj
        args.precision = precision
        args.batch_size = batch_size
        args.max_length = max_length
        args.height = height
        args.width = width
        args.device = device.split("=>", 1)[1].strip()
        args.iree_vulkan_target_triple = ""
        args.use_tuned = True
        args.import_mlir = False
        set_init_device_flags()
        model_id = (
            args.hf_model_id
            if args.hf_model_id
            else "stabilityai/stable-diffusion-2-inpainting"
        )
        schedulers = get_schedulers(model_id)
        scheduler_obj = schedulers[scheduler]
        outpaint_obj = OutpaintPipeline.from_pretrained(
            scheduler_obj,
            args.import_mlir,
            args.hf_model_id,
            args.ckpt_loc,
            args.custom_vae,
            args.precision,
            args.max_length,
            args.batch_size,
            args.height,
            args.width,
            args.use_base_vae,
            args.use_tuned,
        )

    outpaint_obj.scheduler = schedulers[scheduler]

    start_time = time.time()
    outpaint_obj.log = ""
    generated_imgs = []
    seeds = []
    img_seed = utils.sanitize_seed(seed)
    image = Image.open(args.img_path)

    left = True if "left" in directions else False
    right = True if "right" in directions else False
    top = True if "up" in directions else False
    bottom = True if "down" in directions else False

    for i in range(batch_count):
        if i > 0:
            img_seed = utils.sanitize_seed(-1)
        out_imgs = outpaint_obj.generate_images(
            prompt,
            negative_prompt,
            image,
            pixels,
            mask_blur,
            left,
            right,
            top,
            bottom,
            noise_q,
            color_variation,
            batch_size,
            height,
            width,
            steps,
            guidance_scale,
            img_seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
        )
        save_output_img(out_imgs[0], img_seed)
        generated_imgs.extend(out_imgs)
        seeds.append(img_seed)
        outpaint_obj.log += "\n"

    total_time = time.time() - start_time
    text_output = f"prompt={args.prompts}"
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
    text_output += f"\nscheduler={args.scheduler}, device={device}"
    text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
    text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
    text_output += outpaint_obj.log
    text_output += f"\nTotal image generation time: {total_time:.4f}sec"

    return generated_imgs, text_output


if __name__ == "__main__":
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        exit()
    if "inpaint" not in args.hf_model_id:
        print("Please use inpainting model with --hf_model_id.")
        exit()

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    image = Image.open(args.img_path)

    outpaint_obj = OutpaintPipeline.from_pretrained(
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
    )

    for current_batch in range(args.batch_count):
        if current_batch > 0:
            seed = -1
        seed = utils.sanitize_seed(seed)

        start_time = time.time()
        generated_imgs = outpaint_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            image,
            args.pixels,
            args.mask_blur,
            args.left,
            args.right,
            args.top,
            args.bottom,
            args.noise_q,
            args.color_variation,
            args.batch_size,
            args.height,
            args.width,
            args.steps,
            args.guidance_scale,
            seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
        )
        total_time = time.time() - start_time
        text_output = f"prompt={args.prompts}"
        text_output += f"\nnegative prompt={args.negative_prompts}"
        text_output += (
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
        )
        text_output += f"\nscheduler={args.scheduler}, device={args.device}"
        text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        text_output += outpaint_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"

        save_output_img(generated_imgs[0], seed)
        print(text_output)
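In `outpaint_inf` above, the `directions` argument is a list of keywords that gets folded into four booleans before the call to `generate_images`. A small illustration of that mapping, with illustrative values:

```python
# Mirrors the direction handling in outpaint_inf above; values are illustrative.
directions = ["left", "up"]        # extend the canvas to the left and upward
left = "left" in directions        # True
right = "right" in directions      # False
top = "up" in directions           # True  (keyword "up" drives the `top` flag)
bottom = "down" in directions      # False (keyword "down" drives the `bottom` flag)
```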
@@ -1,24 +1,15 @@
import os

if "AMD_ENABLE_LLPC" not in os.environ:
    os.environ["AMD_ENABLE_LLPC"] = "1"

import sys
import json
import torch
import re
import time
from pathlib import Path
from PIL import PngImagePlugin
from datetime import datetime as dt
from dataclasses import dataclass
from csv import DictWriter
from apps.stable_diffusion.src import (
    args,
    Text2ImagePipeline,
    get_schedulers,
    set_init_device_flags,
    utils,
    clear_all,
    save_output_img,
)

@@ -34,96 +25,6 @@ class Config:
    device: str


# This has to come before importing cache objects
if args.clear_all:
    print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
    from glob import glob
    import shutil

    vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
    for vmfb in vmfbs:
        if os.path.exists(vmfb):
            os.remove(vmfb)
    # Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
    # TODO: Remove this once we have better weight updation logic.
    inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
    for yaml in inference_yaml:
        if os.path.exists(yaml):
            os.remove(yaml)
    home = os.path.expanduser("~")
    if os.name == "nt":  # Windows
        appdata = os.getenv("LOCALAPPDATA")
        shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
        shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
    elif os.name == "unix":
        shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
        shutil.rmtree(os.path.join(home, ".local/shark_tank"))


# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed):
    output_path = args.output_dir if args.output_dir else Path.cwd()
    generated_imgs_path = Path(output_path, "generated_imgs")
    generated_imgs_path.mkdir(parents=True, exist_ok=True)
    csv_path = Path(generated_imgs_path, "imgs_details.csv")

    prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
    out_img_name = (
        f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
    )

    img_model = args.hf_model_id
    if args.ckpt_loc:
        img_model = os.path.basename(args.ckpt_loc)

    if args.output_img_format == "jpg":
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
        output_img.save(out_img_path, quality=95, subsampling=0)
    else:
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
        pngInfo = PngImagePlugin.PngInfo()

        if args.write_metadata_to_png:
            pngInfo.add_text(
                "parameters",
                f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
            )

        output_img.save(out_img_path, "PNG", pnginfo=pngInfo)

        if args.output_img_format not in ["png", "jpg"]:
            print(
                f"[ERROR] Format {args.output_img_format} is not supported yet."
                "Image saved as png instead. Supported formats: png / jpg"
            )

    new_entry = {
        "VARIANT": img_model,
        "SCHEDULER": args.scheduler,
        "PROMPT": args.prompts[0],
        "NEG_PROMPT": args.negative_prompts[0],
        "SEED": img_seed,
        "CFG_SCALE": args.guidance_scale,
        "PRECISION": args.precision,
        "STEPS": args.steps,
        "HEIGHT": args.height,
        "WIDTH": args.width,
        "MAX_LENGTH": args.max_length,
        "OUTPUT": out_img_path,
    }

    with open(csv_path, "a") as csv_obj:
        dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
        dictwriter_obj.writerow(new_entry)
        csv_obj.close()

    if args.save_metadata_to_json:
        del new_entry["OUTPUT"]
        json_path = Path(generated_imgs_path, f"{out_img_name}.json")
        with open(json_path, "w") as f:
            json.dump(new_entry, f, indent=4)


txt2img_obj = None
config_obj = None
schedulers = None
@@ -193,7 +94,7 @@ def txt2img_inf(
        width,
        device,
    )
    if config_obj != new_config_obj:
    if not txt2img_obj or config_obj != new_config_obj:
        config_obj = new_config_obj
        args.precision = precision
        args.batch_size = batch_size
@@ -201,8 +102,10 @@ def txt2img_inf(
        args.height = height
        args.width = width
        args.device = device.split("=>", 1)[1].strip()
        args.iree_vulkan_target_triple = ""
        args.use_tuned = True
        args.import_mlir = False
        args.img_path = None
        set_init_device_flags()
        model_id = (
            args.hf_model_id
@@ -216,6 +119,7 @@ def txt2img_inf(
            args.import_mlir,
            args.hf_model_id,
            args.ckpt_loc,
            args.custom_vae,
            args.precision,
            args.max_length,
            args.batch_size,
@@ -223,11 +127,9 @@ def txt2img_inf(
            args.width,
            args.use_base_vae,
            args.use_tuned,
            low_cpu_mem_usage=args.low_cpu_mem_usage,
        )

    if not txt2img_obj:
        sys.exit("text to image pipeline must not return a null value")

    txt2img_obj.scheduler = schedulers[scheduler]

    start_time = time.time()
@@ -262,8 +164,10 @@ def txt2img_inf(
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
    text_output += f"\nscheduler={args.scheduler}, device={device}"
    text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
    text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
    text_output += (
        f"\nsteps={steps}, guidance_scale={guidance_scale}, seed={seeds}"
    )
    text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
    text_output += txt2img_obj.log
    text_output += f"\nTotal image generation time: {total_time:.4f}sec"

@@ -271,6 +175,9 @@ def txt2img_inf(


if __name__ == "__main__":
    if args.clear_all:
        clear_all()

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
@@ -283,6 +190,7 @@ if __name__ == "__main__":
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
@@ -290,10 +198,11 @@ if __name__ == "__main__":
        args.width,
        args.use_base_vae,
        args.use_tuned,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
    )

    for run in range(args.runs):
        if run > 0:
    for current_batch in range(args.batch_count):
        if current_batch > 0:
            seed = -1
        seed = utils.sanitize_seed(seed)

@@ -323,7 +232,7 @@ if __name__ == "__main__":
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        # TODO: if using --runs=x txt2img_obj.log will output on each display every iteration infos from the start
        # TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
        text_output += txt2img_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"
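The change from `if config_obj != new_config_obj:` to `if not txt2img_obj or config_obj != new_config_obj:` above makes the rebuild trigger on the very first call as well as on any settings change. The comparison itself works because `Config` is a dataclass, so equality is field-by-field; a small illustration:

```python
# Illustration of the dataclass equality the caching check above relies on.
from dataclasses import dataclass

@dataclass
class Config:
    model_id: str
    height: int
    width: int

old = Config("stabilityai/stable-diffusion-2-1-base", 512, 512)
new = Config("stabilityai/stable-diffusion-2-1-base", 768, 512)
assert old != new  # any changed field forces the pipeline to be rebuilt
```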
@@ -15,12 +15,11 @@ datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
@@ -30,8 +29,8 @@ datas += [
    ( 'src/utils/resources/model_db.json', 'resources' ),
    ( 'src/utils/resources/opt_flags.json', 'resources' ),
    ( 'src/utils/resources/base_model.json', 'resources' ),
    ( 'web/css/*', 'css' ),
    ( 'web/logos/*', 'logos' )
    ( 'web/ui/css/*', 'ui/css' ),
    ( 'web/ui/logos/*', 'logos' )
]

binaries = []
@@ -44,7 +43,7 @@ a = Analysis(
    pathex=['.'],
    binaries=binaries,
    datas=datas,
    hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
    hiddenimports=['shark', 'shark.shark_inference', 'apps'],
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
@@ -15,12 +15,11 @@ datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
@@ -42,7 +41,7 @@ a = Analysis(
    pathex=['.'],
    binaries=binaries,
    datas=datas,
    hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
    hiddenimports=['shark', 'shark.shark_inference', 'apps'],
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
@@ -3,6 +3,13 @@ from apps.stable_diffusion.src.utils import (
    set_init_device_flags,
    prompt_examples,
    get_available_devices,
    clear_all,
    save_output_img,
)
from apps.stable_diffusion.src.pipelines import (
    Text2ImagePipeline,
    Image2ImagePipeline,
    InpaintPipeline,
    OutpaintPipeline,
)
from apps.stable_diffusion.src.pipelines import Text2ImagePipeline
from apps.stable_diffusion.src.schedulers import get_schedulers
@@ -2,6 +2,7 @@ from apps.stable_diffusion.src.models.model_wrappers import (
    SharkifyStableDiffusionModel,
)
from apps.stable_diffusion.src.models.opt_params import (
    get_vae_encode,
    get_vae,
    get_unet,
    get_clip,
@@ -2,8 +2,8 @@ from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from collections import defaultdict
import torch
import safetensors.torch
import traceback
import re
import sys
from apps.stable_diffusion.src.utils import (
    compile_through_fx,
@@ -14,6 +14,8 @@ from apps.stable_diffusion.src.utils import (
    preprocessCKPT,
    get_path_to_diffusers_checkpoint,
    fetch_and_update_base_model_id,
    get_path_stem,
    get_extended_name,
)

@@ -28,15 +30,19 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
        elif shape[i] == "width":
            new_shape.append(width)
        elif isinstance(shape[i], str):
            mul_val = int(shape[i].split("*")[0])
            if "batch_size" in shape[i]:
                mul_val = int(shape[i].split("*")[0])
                new_shape.append(batch_size * mul_val)
            elif "height" in shape[i]:
                new_shape.append(height * mul_val)
            elif "width" in shape[i]:
                new_shape.append(width * mul_val)
        else:
            new_shape.append(shape[i])
    return new_shape


# Get the input info for various models i.e. "unet", "clip", "vae".
# Get the input info for various models i.e. "unet", "clip", "vae", "vae_encode".
def get_input_info(model_info, max_len, width, height, batch_size):
    dtype_config = {"f32": torch.float32, "i64": torch.int64}
    input_map = defaultdict(list)
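The new branch above lets a shape entry carry a multiplier such as `"2*batch_size"`. A worked illustration of just that branch (a standalone paraphrase of the visible logic, not the full function):

```python
# Paraphrase of the multiplier branch shown above, for illustration only.
def resolve(entry, width, height, batch_size):
    mul_val = int(entry.split("*")[0])
    if "batch_size" in entry:
        return batch_size * mul_val
    elif "height" in entry:
        return height * mul_val
    elif "width" in entry:
        return width * mul_val

assert resolve("2*batch_size", width=512, height=512, batch_size=1) == 2
assert resolve("1*height", width=512, height=768, batch_size=1) == 768
```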
@@ -66,6 +72,7 @@ class SharkifyStableDiffusionModel:
|
||||
self,
|
||||
model_id: str,
|
||||
custom_weights: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
max_len: int = 64,
|
||||
width: int = 512,
|
||||
@@ -73,6 +80,7 @@ class SharkifyStableDiffusionModel:
|
||||
batch_size: int = 1,
|
||||
use_base_vae: bool = False,
|
||||
use_tuned: bool = False,
|
||||
low_cpu_mem_usage: bool = False
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
@@ -86,6 +94,10 @@ class SharkifyStableDiffusionModel:
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
self.model_id = "stabilityai/stable-diffusion-2-1-base"
|
||||
self.custom_vae = custom_vae
|
||||
self.precision = precision
|
||||
self.base_vae = use_base_vae
|
||||
self.model_name = (
|
||||
@@ -102,17 +114,22 @@ class SharkifyStableDiffusionModel:
|
||||
self.use_tuned = use_tuned
|
||||
if use_tuned:
|
||||
self.model_name = self.model_name + "_tuned"
|
||||
# We need a better naming convention for the .vmfbs because despite
|
||||
# using the custom model variant the .vmfb names remain the same and
|
||||
# it'll always pick up the compiled .vmfb instead of compiling the
|
||||
# custom model.
|
||||
# So, currently, we add `self.model_id` in the `self.model_name` of
|
||||
# .vmfb file.
|
||||
# TODO: Have a better way of naming the vmfbs using self.model_name.
|
||||
model_name = re.sub(r"\W+", "_", self.model_id)
|
||||
if model_name[0] == "_":
|
||||
model_name = model_name[1:]
|
||||
self.model_name = self.model_name + "_" + model_name
|
||||
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
|
||||
def get_extended_name_for_all_model(self):
|
||||
model_name = {}
|
||||
sub_model_list = ["clip", "unet", "vae", "vae_encode"]
|
||||
for model in sub_model_list:
|
||||
sub_model = model
|
||||
model_config = self.model_name
|
||||
if "vae" == model:
|
||||
if self.custom_vae != "":
|
||||
model_config = model_config + get_path_stem(self.custom_vae)
|
||||
if self.base_vae:
|
||||
sub_model = "base_vae"
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
return model_name
|
||||
|
||||
def check_params(self, max_len, width, height):
|
||||
if not (max_len >= 32 and max_len <= 77):
|
||||
@@ -122,14 +139,57 @@ class SharkifyStableDiffusionModel:
|
||||
if not (height % 8 == 0 and height >= 384):
|
||||
sys.exit("height should be greater than 384 and multiple of 8")
|
||||
|
||||
def get_vae(self):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, base_vae=self.base_vae):
|
||||
def get_vae_encode(self):
|
||||
class VaeEncodeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
latents = self.vae.encode(input).latent_dist.sample()
|
||||
return 0.18215 * latents
|
||||
|
||||
vae_encode = VaeEncodeModel()
|
||||
inputs = tuple(self.inputs["vae_encode"])
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
shark_vae_encode = compile_through_fx(
|
||||
vae_encode,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae_encode"],
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
)
|
||||
return shark_vae_encode
|
||||
|
||||
def get_vae(self):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.vae = None
|
||||
if custom_vae == "":
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
elif not isinstance(custom_vae, dict):
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.vae.load_state_dict(custom_vae)
|
||||
self.base_vae = base_vae
|
||||
|
||||
def forward(self, input):
|
||||
@@ -142,30 +202,35 @@ class SharkifyStableDiffusionModel:
|
||||
x = x * 255.0
|
||||
return x.round()
|
||||
|
||||
vae = VaeModel()
|
||||
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
is_f16 = self.precision == "fp16"
|
||||
vae_name = "base_vae" if self.base_vae else "vae"
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=vae_name + self.model_name,
|
||||
model_name=self.model_name["vae"],
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
def get_unet(self):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
if args.attention_slicing is not None and args.attention_slicing != "none":
if args.attention_slicing.isdigit():
|
||||
self.unet.set_attention_slice(int(args.attention_slicing))
|
||||
else:
|
||||
self.unet.set_attention_slice(args.attention_slicing)
|
||||
|
||||
def forward(
|
||||
self, latent, timestep, text_embedding, guidance_scale
|
||||
@@ -181,14 +246,14 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel()
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = self.precision == "fp16"
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
input_mask = [True, True, True, False]
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name="unet" + self.model_name,
|
||||
model_name=self.model_name["unet"],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
@@ -198,28 +263,50 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_clip(self):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
tuple(self.inputs["clip"]),
|
||||
model_name="clip" + self.model_name,
|
||||
model_name=self.model_name["clip"],
|
||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
def process_custom_vae(self):
|
||||
custom_vae = self.custom_vae.lower()
|
||||
if not custom_vae.endswith((".ckpt", ".safetensors")):
|
||||
return self.custom_vae
|
||||
try:
|
||||
preprocessCKPT(self.custom_vae)
|
||||
return get_path_to_diffusers_checkpoint(self.custom_vae)
|
||||
except Exception:
|
||||
print("Processing standalone Vae checkpoint")
|
||||
vae_checkpoint = None
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
if custom_vae.endswith(".ckpt"):
|
||||
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
|
||||
else:
|
||||
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
|
||||
if "state_dict" in vae_checkpoint:
|
||||
vae_checkpoint = vae_checkpoint["state_dict"]
|
||||
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict
|
||||
|
||||
|
||||
# Compiles Clip, Unet and Vae with `base_model_id` defining their input
# configuration.
|
||||
def compile_all(self, base_model_id):
|
||||
def compile_all(self, base_model_id, need_vae_encode):
|
||||
self.inputs = get_input_info(
|
||||
base_models[base_model_id],
|
||||
self.max_len,
|
||||
@@ -228,17 +315,22 @@ class SharkifyStableDiffusionModel:
|
||||
self.batch_size,
|
||||
)
|
||||
compiled_unet = self.get_unet()
|
||||
if self.custom_vae != "":
|
||||
print("Plugging in custom Vae")
|
||||
compiled_vae = self.get_vae()
|
||||
compiled_clip = self.get_clip()
|
||||
|
||||
if need_vae_encode:
|
||||
compiled_vae_encode = self.get_vae_encode()
|
||||
return compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode
|
||||
|
||||
return compiled_clip, compiled_unet, compiled_vae
|
||||
|
||||
def __call__(self):
|
||||
# Step 1:
|
||||
# -- Fetch all vmfbs for the model, if present, else delete the lot.
|
||||
vmfbs = fetch_or_delete_vmfbs(
|
||||
self.model_name, self.base_vae, self.precision
|
||||
)
|
||||
need_vae_encode = args.img_path is not None
|
||||
self.model_name = self.get_extended_name_for_all_model()
|
||||
vmfbs = fetch_or_delete_vmfbs(self.model_name, need_vae_encode, self.precision)
|
||||
if vmfbs[0]:
|
||||
# -- If all vmfbs are indeed present, we also try to fetch the base
# model configuration for running SD with custom checkpoints.
|
||||
@@ -247,6 +339,8 @@ class SharkifyStableDiffusionModel:
|
||||
if args.hf_model_id == "":
|
||||
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
|
||||
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
|
||||
if not need_vae_encode:
|
||||
return vmfbs[:3]
|
||||
return vmfbs
|
||||
|
||||
# Step 2:
|
||||
@@ -261,12 +355,15 @@ class SharkifyStableDiffusionModel:
|
||||
preprocessCKPT(self.custom_weights)
|
||||
else:
|
||||
model_to_run = args.hf_model_id
|
||||
# For a custom Vae the user can provide either a repo-id or a checkpoint file,
# and a checkpoint file first needs to be processed via Diffusers' conversion script.
|
||||
self.custom_vae = self.process_custom_vae()
|
||||
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
|
||||
if base_model_fetched != "":
|
||||
print("Compiling all the models with the fetched base model configuration.")
|
||||
if args.ckpt_loc != "":
|
||||
args.hf_model_id = base_model_fetched
|
||||
return self.compile_all(base_model_fetched)
|
||||
return self.compile_all(base_model_fetched, need_vae_encode)
|
||||
|
||||
# Step 3:
|
||||
# -- This is the retry mechanism where the base model's configuration is not
|
||||
@@ -274,10 +371,11 @@ class SharkifyStableDiffusionModel:
|
||||
print("Inferring base model configuration.")
|
||||
for model_id in base_models:
|
||||
try:
|
||||
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id)
|
||||
if need_vae_encode:
|
||||
compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, need_vae_encode)
|
||||
else:
|
||||
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode)
|
||||
except Exception as e:
|
||||
if args.enable_stack_trace:
|
||||
traceback.print_exc()
|
||||
print("Retrying with a different base model configuration")
|
||||
continue
|
||||
# -- Once a successful compilation has taken place we'd want to store
|
||||
@@ -289,7 +387,14 @@ class SharkifyStableDiffusionModel:
|
||||
# the knowledge of base model id accordingly into `args.hf_model_id`.
|
||||
if args.ckpt_loc != "":
|
||||
args.hf_model_id = model_id
|
||||
if need_vae_encode:
|
||||
return (
|
||||
compiled_clip,
|
||||
compiled_unet,
|
||||
compiled_vae,
|
||||
compiled_vae_encode,
|
||||
)
|
||||
return compiled_clip, compiled_unet, compiled_vae
|
||||
sys.exit(
|
||||
"Cannot compile the model. Please re-run the command with `--enable_stack_trace` flag and create an issue with detailed log at https://github.com/nod-ai/SHARK/issues"
|
||||
"Cannot compile the model. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
|
||||
)
|
||||
|
||||
@@ -9,13 +9,15 @@ from apps.stable_diffusion.src.utils import (
|
||||
|
||||
|
||||
hf_model_variant_map = {
|
||||
"Linaqruf/anything-v3.0": ["anythingv3", "v2_1base"],
|
||||
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"],
|
||||
"prompthero/openjourney": ["openjourney", "v2_1base"],
|
||||
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"],
|
||||
"Linaqruf/anything-v3.0": ["anythingv3", "v1_4"],
|
||||
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"],
|
||||
"prompthero/openjourney": ["openjourney", "v1_4"],
|
||||
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"],
|
||||
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
|
||||
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
|
||||
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
|
||||
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
|
||||
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
|
||||
}
|
||||
|
||||
|
||||
@@ -52,6 +54,23 @@ def get_unet():
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae_encode():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{variant}/{is_tuned}"
|
||||
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "vae", is_tuned, args.precision
|
||||
)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
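To make the key construction above concrete, a minimal sketch (assuming args.hf_model_id is "runwayml/stable-diffusion-inpainting", fp16 precision, and no tuning):

variant, version, is_tuned = "stablediffusion", "inpaint_v1", "untuned"
model_key = f"{variant}/{version}/vae_encode/fp16/length_77/{is_tuned}"
# -> "stablediffusion/inpaint_v1/vae_encode/fp16/length_77/untuned", which the
# variants table further down in this diff maps to "vae_encode_inpaint_fp16".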
|
||||
|
||||
def get_vae():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
|
||||
@@ -1,3 +1,12 @@
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
|
||||
Text2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
|
||||
InpaintPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import (
|
||||
OutpaintPipeline,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from random import randint
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
|
||||
class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae_encode: SharkInference,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
self.vae_encode = vae_encode
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
strength,
|
||||
dtype,
|
||||
):
|
||||
# Preprocess image -> encode image -> process latents
|
||||
|
||||
# TODO: process with variable HxW combos
|
||||
|
||||
# Preprocess image
|
||||
image = image.resize((width, height))
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
|
||||
image_arr = 2 * (image_arr - 0.5)
|
||||
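# For reference: dividing by 255 and applying 2 * (x - 0.5) maps 8-bit pixels
# into the [-1, 1] range the VAE encoder expects, e.g. 0 -> -1.0, 255 -> +1.0.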
|
||||
# set scheduler steps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
init_timestep = min(
|
||||
int(num_inference_steps * strength), num_inference_steps
|
||||
)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
# timesteps reduced as per strength
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
# the effective number of steps after applying strength is
# num_inference_steps - t_start
|
||||
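# Worked example (illustrative numbers): with num_inference_steps=50 and
# strength=0.8, init_timestep = min(int(50 * 0.8), 50) = 40 and
# t_start = 50 - 40 = 10, so only the last 40 scheduler timesteps are run
# over the encoded input image.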
|
||||
# image encode
|
||||
latents = self.encode_image((image_arr,))
|
||||
latents = torch.from_numpy(latents).to(dtype)
|
||||
# add noise to data
|
||||
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
|
||||
latents = self.scheduler.add_noise(
|
||||
latents, noise, timesteps[0].repeat(1)
|
||||
)
|
||||
|
||||
return latents, timesteps
|
||||
|
||||
def encode_image(self, input_image):
|
||||
vae_encode_start = time.time()
|
||||
latents = self.vae_encode("forward", input_image)
|
||||
vae_inf_time = (time.time() - vae_encode_start) * 1000
|
||||
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
|
||||
|
||||
return latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the initial latent noise. Also handle out-of-range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get text embeddings from prompts
|
||||
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Prepare input image latent
|
||||
image_latents, final_timesteps = self.prepare_image_latents(
|
||||
image=image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
strength=strength,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=image_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=final_timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
return all_imgs
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
|
||||
class InpaintPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae_encode: SharkInference,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
self.vae_encode = vae_encode
|
||||
|
||||
def prepare_mask_and_masked_image(self, image, mask, height, width):
|
||||
# preprocess image
|
||||
image = image.resize((width, height))
|
||||
mask = mask.resize((width, height))
|
||||
if isinstance(image, (Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], Image.Image):
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
|
||||
mask = np.concatenate(
|
||||
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
|
||||
)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image
|
||||
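# At this point (illustrative summary): `mask` is a float tensor of shape
# (1, 1, H, W) with values in {0, 1}, and `masked_image` is (1, 3, H, W) in
# [-1, 1] with the region to be inpainted zeroed out.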
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
):
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // 8, width // 8)
|
||||
)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
masked_image = masked_image.to(dtype)
|
||||
masked_image_latents = self.vae_encode("forward", (masked_image,))
|
||||
masked_image_latents = torch.from_numpy(masked_image_latents)
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(
|
||||
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
||||
)
|
||||
return mask, masked_image_latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
mask_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the initial latent noise. Also handle out-of-range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings from prompts
|
||||
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Preprocess mask and image
|
||||
mask, masked_image = self.prepare_mask_and_masked_image(
|
||||
image, mask_image, height, width
|
||||
)
|
||||
|
||||
# Prepare mask latent variables
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask=mask,
|
||||
masked_image=masked_image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
mask=mask,
|
||||
masked_image_latents=masked_image_latents,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
return all_imgs
|
||||
@@ -0,0 +1,542 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from PIL import Image, ImageDraw, ImageFilter
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
import math
|
||||
|
||||
|
||||
class OutpaintPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae_encode: SharkInference,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
self.vae_encode = vae_encode
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_mask_and_masked_image(
|
||||
self, image, mask, mask_blur, width, height
|
||||
):
|
||||
if mask_blur > 0:
|
||||
mask = mask.filter(ImageFilter.GaussianBlur(mask_blur))
|
||||
image = image.resize((width, height))
|
||||
mask = mask.resize((width, height))
|
||||
|
||||
# preprocess image
|
||||
if isinstance(image, (Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], Image.Image):
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
|
||||
mask = np.concatenate(
|
||||
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
|
||||
)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
):
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // 8, width // 8)
|
||||
)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
masked_image = masked_image.to(dtype)
|
||||
masked_image_latents = self.vae_encode("forward", (masked_image,))
|
||||
masked_image_latents = torch.from_numpy(masked_image_latents)
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(
|
||||
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
||||
)
|
||||
return mask, masked_image_latents
|
||||
|
||||
def get_matched_noise(
|
||||
self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
|
||||
):
|
||||
# helper fft routines that keep ortho normalization and auto-shift before and after fft
|
||||
def _fft2(data):
|
||||
if data.ndim > 2: # has channels
|
||||
out_fft = np.zeros(
|
||||
(data.shape[0], data.shape[1], data.shape[2]),
|
||||
dtype=np.complex128,
|
||||
)
|
||||
for c in range(data.shape[2]):
|
||||
c_data = data[:, :, c]
|
||||
out_fft[:, :, c] = np.fft.fft2(
|
||||
np.fft.fftshift(c_data), norm="ortho"
|
||||
)
|
||||
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
|
||||
else: # one channel
|
||||
out_fft = np.zeros(
|
||||
(data.shape[0], data.shape[1]), dtype=np.complex128
|
||||
)
|
||||
out_fft[:, :] = np.fft.fft2(
|
||||
np.fft.fftshift(data), norm="ortho"
|
||||
)
|
||||
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
|
||||
|
||||
return out_fft
|
||||
|
||||
def _ifft2(data):
|
||||
if data.ndim > 2: # has channels
|
||||
out_ifft = np.zeros(
|
||||
(data.shape[0], data.shape[1], data.shape[2]),
|
||||
dtype=np.complex128,
|
||||
)
|
||||
for c in range(data.shape[2]):
|
||||
c_data = data[:, :, c]
|
||||
out_ifft[:, :, c] = np.fft.ifft2(
|
||||
np.fft.fftshift(c_data), norm="ortho"
|
||||
)
|
||||
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
|
||||
else: # one channel
|
||||
out_ifft = np.zeros(
|
||||
(data.shape[0], data.shape[1]), dtype=np.complex128
|
||||
)
|
||||
out_ifft[:, :] = np.fft.ifft2(
|
||||
np.fft.fftshift(data), norm="ortho"
|
||||
)
|
||||
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
|
||||
|
||||
return out_ifft
|
||||
|
||||
def _get_gaussian_window(width, height, std=3.14, mode=0):
|
||||
window_scale_x = float(width / min(width, height))
|
||||
window_scale_y = float(height / min(width, height))
|
||||
|
||||
window = np.zeros((width, height))
|
||||
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
|
||||
for y in range(height):
|
||||
fy = (y / height * 2.0 - 1.0) * window_scale_y
|
||||
if mode == 0:
|
||||
window[:, y] = np.exp(-(x**2 + fy**2) * std)
|
||||
else:
|
||||
window[:, y] = (
|
||||
1 / ((x**2 + 1.0) * (fy**2 + 1.0))
|
||||
) ** (std / 3.14)
|
||||
|
||||
return window
|
||||
|
||||
def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
|
||||
np_mask_rgb = np.zeros(
|
||||
(np_mask_grey.shape[0], np_mask_grey.shape[1], 3)
|
||||
)
|
||||
if hardness != 1.0:
|
||||
hardened = np_mask_grey[:] ** hardness
|
||||
else:
|
||||
hardened = np_mask_grey[:]
|
||||
for c in range(3):
|
||||
np_mask_rgb[:, :, c] = hardened[:]
|
||||
return np_mask_rgb
|
||||
|
||||
def _match_cumulative_cdf(source, template):
|
||||
src_values, src_unique_indices, src_counts = np.unique(
|
||||
source.ravel(), return_inverse=True, return_counts=True
|
||||
)
|
||||
tmpl_values, tmpl_counts = np.unique(
|
||||
template.ravel(), return_counts=True
|
||||
)
|
||||
|
||||
# calculate normalized quantiles for each array
|
||||
src_quantiles = np.cumsum(src_counts) / source.size
|
||||
tmpl_quantiles = np.cumsum(tmpl_counts) / template.size
|
||||
|
||||
interp_a_values = np.interp(
|
||||
src_quantiles, tmpl_quantiles, tmpl_values
|
||||
)
|
||||
return interp_a_values[src_unique_indices].reshape(source.shape)
|
||||
|
||||
def _match_histograms(image, reference):
|
||||
if image.ndim != reference.ndim:
|
||||
raise ValueError(
|
||||
"Image and reference must have the same number of channels."
|
||||
)
|
||||
|
||||
if image.shape[-1] != reference.shape[-1]:
|
||||
raise ValueError(
|
||||
"Number of channels in the input image and reference image must match!"
|
||||
)
|
||||
|
||||
matched = np.empty(image.shape, dtype=image.dtype)
|
||||
for channel in range(image.shape[-1]):
|
||||
matched_channel = _match_cumulative_cdf(
|
||||
image[..., channel], reference[..., channel]
|
||||
)
|
||||
matched[..., channel] = matched_channel
|
||||
|
||||
matched = matched.astype(np.float64, copy=False)
|
||||
return matched
|
||||
|
||||
width = _np_src_image.shape[0]
|
||||
height = _np_src_image.shape[1]
|
||||
num_channels = _np_src_image.shape[2]
|
||||
|
||||
np_src_image = _np_src_image[:] * (1.0 - np_mask_rgb)
|
||||
np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
|
||||
img_mask = np_mask_grey > 1e-6
|
||||
ref_mask = np_mask_grey < 1e-3
|
||||
|
||||
# rather than leave the masked area black, we get better results from fft by filling the average unmasked color
|
||||
windowed_image = _np_src_image * (
|
||||
1.0 - _get_masked_window_rgb(np_mask_grey)
|
||||
)
|
||||
windowed_image /= np.max(windowed_image)
|
||||
windowed_image += np.average(_np_src_image) * np_mask_rgb
|
||||
|
||||
src_fft = _fft2(
|
||||
windowed_image
|
||||
) # get feature statistics from masked src img
|
||||
src_dist = np.absolute(src_fft)
|
||||
src_phase = src_fft / src_dist
|
||||
|
||||
# create a generator with a static seed so outpainting is deterministic and only follows the global seed
|
||||
rng = np.random.default_rng(0)
|
||||
|
||||
noise_window = _get_gaussian_window(
|
||||
width, height, mode=1
|
||||
) # start with simple gaussian noise
|
||||
noise_rgb = rng.random((width, height, num_channels))
|
||||
noise_grey = np.sum(noise_rgb, axis=2) / 3.0
|
||||
# the colorfulness of the starting noise is blended to greyscale with a parameter
|
||||
noise_rgb *= color_variation
|
||||
for c in range(num_channels):
|
||||
noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey
|
||||
|
||||
noise_fft = _fft2(noise_rgb)
|
||||
for c in range(num_channels):
|
||||
noise_fft[:, :, c] *= noise_window
|
||||
noise_rgb = np.real(_ifft2(noise_fft))
|
||||
shaped_noise_fft = _fft2(noise_rgb)
|
||||
shaped_noise_fft[:, :, :] = (
|
||||
np.absolute(shaped_noise_fft[:, :, :]) ** 2
|
||||
* (src_dist**noise_q)
|
||||
* src_phase
|
||||
) # perform the actual shaping
|
||||
|
||||
# color_variation
|
||||
brightness_variation = 0.0
|
||||
contrast_adjusted_np_src = (
|
||||
_np_src_image[:] * (brightness_variation + 1.0)
|
||||
- brightness_variation * 2.0
|
||||
)
|
||||
|
||||
shaped_noise = np.real(_ifft2(shaped_noise_fft))
|
||||
shaped_noise -= np.min(shaped_noise)
|
||||
shaped_noise /= np.max(shaped_noise)
|
||||
shaped_noise[img_mask, :] = _match_histograms(
|
||||
shaped_noise[img_mask, :] ** 1.0,
|
||||
contrast_adjusted_np_src[ref_mask, :],
|
||||
)
|
||||
shaped_noise = (
|
||||
_np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
|
||||
)
|
||||
|
||||
matched_noise = shaped_noise[:]
|
||||
|
||||
return np.clip(matched_noise, 0.0, 1.0)
|
||||
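A minimal standalone sketch of the quantile-matching idea implemented by _match_cumulative_cdf above (illustrative values only, not taken from the source):

import numpy as np

src = np.array([0.1, 0.2, 0.2, 0.9])
tmpl = np.array([0.4, 0.5, 0.6, 0.7])
src_vals, src_idx, src_counts = np.unique(src, return_inverse=True, return_counts=True)
tmpl_vals, tmpl_counts = np.unique(tmpl, return_counts=True)
src_q = np.cumsum(src_counts) / src.size      # [0.25, 0.75, 1.0]
tmpl_q = np.cumsum(tmpl_counts) / tmpl.size   # [0.25, 0.5, 0.75, 1.0]
matched = np.interp(src_q, tmpl_q, tmpl_vals)[src_idx].reshape(src.shape)
# matched == [0.4, 0.6, 0.6, 0.7]: src now follows the template's value
# distribution while preserving the ordering of its own values.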
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
pixels,
|
||||
mask_blur,
|
||||
is_left,
|
||||
is_right,
|
||||
is_top,
|
||||
is_bottom,
|
||||
noise_q,
|
||||
color_variation,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the initial latent noise. Also handle out-of-range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings from prompts
|
||||
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
process_width = width
|
||||
process_height = height
|
||||
left = pixels if is_left else 0
|
||||
right = pixels if is_right else 0
|
||||
up = pixels if is_top else 0
|
||||
down = pixels if is_bottom else 0
|
||||
target_w = math.ceil((image.width + left + right) / 64) * 64
|
||||
target_h = math.ceil((image.height + up + down) / 64) * 64
|
||||
|
||||
if left > 0:
|
||||
left = left * (target_w - image.width) // (left + right)
|
||||
if right > 0:
|
||||
right = target_w - image.width - left
|
||||
if up > 0:
|
||||
up = up * (target_h - image.height) // (up + down)
|
||||
if down > 0:
|
||||
down = target_h - image.height - up
|
||||
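# Worked example (hypothetical numbers): for a 512x512 image with pixels=128
# and only is_right set, right=128, target_w = ceil((512 + 0 + 128) / 64) * 64
# = 640 and target_h = 512; the redistribution above then leaves
# right = 640 - 512 - 0 = 128.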
|
||||
def expand(
|
||||
init_img,
|
||||
expand_pixels,
|
||||
is_left=False,
|
||||
is_right=False,
|
||||
is_top=False,
|
||||
is_bottom=False,
|
||||
):
|
||||
is_horiz = is_left or is_right
|
||||
is_vert = is_top or is_bottom
|
||||
pixels_horiz = expand_pixels if is_horiz else 0
|
||||
pixels_vert = expand_pixels if is_vert else 0
|
||||
|
||||
res_w = init_img.width + pixels_horiz
|
||||
res_h = init_img.height + pixels_vert
|
||||
process_res_w = math.ceil(res_w / 64) * 64
|
||||
process_res_h = math.ceil(res_h / 64) * 64
|
||||
|
||||
img = Image.new("RGB", (process_res_w, process_res_h))
|
||||
img.paste(
|
||||
init_img,
|
||||
(pixels_horiz if is_left else 0, pixels_vert if is_top else 0),
|
||||
)
|
||||
|
||||
msk = Image.new("RGB", (process_res_w, process_res_h), "white")
|
||||
draw = ImageDraw.Draw(msk)
|
||||
draw.rectangle(
|
||||
(
|
||||
expand_pixels + mask_blur if is_left else 0,
|
||||
expand_pixels + mask_blur if is_top else 0,
|
||||
msk.width - expand_pixels - mask_blur
|
||||
if is_right
|
||||
else res_w,
|
||||
msk.height - expand_pixels - mask_blur
|
||||
if is_bottom
|
||||
else res_h,
|
||||
),
|
||||
fill="black",
|
||||
)
|
||||
|
||||
np_image = (np.asarray(img) / 255.0).astype(np.float64)
|
||||
np_mask = (np.asarray(msk) / 255.0).astype(np.float64)
|
||||
noised = self.get_matched_noise(
|
||||
np_image, np_mask, noise_q, color_variation
|
||||
)
|
||||
output_image = Image.fromarray(
|
||||
np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8),
|
||||
mode="RGB",
|
||||
)
|
||||
|
||||
target_width = (
|
||||
min(width, init_img.width + pixels_horiz)
|
||||
if is_horiz
|
||||
else img.width
|
||||
)
|
||||
target_height = (
|
||||
min(height, init_img.height + pixels_vert)
|
||||
if is_vert
|
||||
else img.height
|
||||
)
|
||||
crop_region = (
|
||||
0 if is_left else output_image.width - target_width,
|
||||
0 if is_top else output_image.height - target_height,
|
||||
target_width if is_left else output_image.width,
|
||||
target_height if is_top else output_image.height,
|
||||
)
|
||||
mask_to_process = msk.crop(crop_region)
|
||||
image_to_process = output_image.crop(crop_region)
|
||||
|
||||
# Preprocess mask and image
|
||||
mask, masked_image = self.prepare_mask_and_masked_image(
|
||||
image_to_process, mask_to_process, mask_blur, width, height
|
||||
)
|
||||
|
||||
# Prepare mask latent variables
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask=mask,
|
||||
masked_image=masked_image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
mask=mask,
|
||||
masked_image_latents=masked_image_latents,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
res_img = all_imgs[0].resize(
|
||||
(image_to_process.width, image_to_process.height)
|
||||
)
|
||||
output_image.paste(
|
||||
res_img,
|
||||
(
|
||||
0 if is_left else output_image.width - res_img.width,
|
||||
0 if is_top else output_image.height - res_img.height,
|
||||
),
|
||||
)
|
||||
output_image = output_image.crop((0, 0, res_w, res_h))
|
||||
|
||||
return output_image
|
||||
|
||||
img = image.resize((width, height))
|
||||
if left > 0:
|
||||
img = expand(img, left, is_left=True)
|
||||
if right > 0:
|
||||
img = expand(img, right, is_right=True)
|
||||
if up > 0:
|
||||
img = expand(img, up, is_top=True)
|
||||
if down > 0:
|
||||
img = expand(img, down, is_bottom=True)
|
||||
|
||||
return [img]
|
||||
@@ -9,9 +9,11 @@ from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
@@ -30,10 +32,12 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
@@ -8,14 +9,17 @@ from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
get_vae,
|
||||
get_clip,
|
||||
get_unet,
|
||||
@@ -38,10 +42,12 @@ class StableDiffusionPipeline:
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
],
|
||||
):
|
||||
self.vae = vae
|
||||
@@ -112,6 +118,8 @@ class StableDiffusionPipeline:
|
||||
total_timesteps,
|
||||
dtype,
|
||||
cpu_scheduling,
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
step_time_sum = 0
|
||||
@@ -122,6 +130,15 @@ class StableDiffusionPipeline:
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, t)
|
||||
if mask is not None and masked_image_latents is not None:
|
||||
latent_model_input = torch.cat(
|
||||
[
|
||||
torch.from_numpy(np.asarray(latent_model_input)),
|
||||
mask,
|
||||
masked_image_latents,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype)
|
||||
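# Note (for orientation, not part of the original source): for the inpainting
# UNets this concatenation yields 4 latent + 1 mask + 4 masked-image-latent
# channels = 9, matching the 9-channel "latents" input declared for the
# inpaint models in the base-model config later in this diff.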
if cpu_scheduling:
|
||||
latent_model_input = latent_model_input.detach().numpy()
|
||||
|
||||
@@ -169,14 +186,17 @@ class StableDiffusionPipeline:
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
],
|
||||
import_mlir: bool,
|
||||
model_id: str,
|
||||
ckpt_loc: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
max_length: int,
|
||||
batch_size: int,
|
||||
@@ -184,13 +204,13 @@ class StableDiffusionPipeline:
|
||||
width: int,
|
||||
use_base_vae: bool,
|
||||
use_tuned: bool,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
):
|
||||
if import_mlir:
|
||||
# TODO: Delete this when on-the-fly tuning of models works.
|
||||
use_tuned = False
|
||||
mlir_import = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
ckpt_loc,
|
||||
custom_vae,
|
||||
precision,
|
||||
max_len=max_length,
|
||||
batch_size=batch_size,
|
||||
@@ -198,9 +218,59 @@ class StableDiffusionPipeline:
|
||||
width=width,
|
||||
use_base_vae=use_base_vae,
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
"InpaintPipeline",
|
||||
"OutpaintPipeline",
|
||||
]:
|
||||
clip, unet, vae, vae_encode = mlir_import()
|
||||
return cls(
|
||||
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
|
||||
)
|
||||
clip, unet, vae = mlir_import()
|
||||
return cls(vae, clip, get_tokenizer(), unet, scheduler)
|
||||
try:
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
"InpaintPipeline",
|
||||
"OutpaintPipeline",
|
||||
]:
|
||||
return cls(
|
||||
get_vae_encode(),
|
||||
get_vae(),
|
||||
get_clip(),
|
||||
get_tokenizer(),
|
||||
get_unet(),
|
||||
scheduler,
|
||||
)
|
||||
return cls(
|
||||
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
|
||||
)
|
||||
except Exception:
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
mlir_import = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
ckpt_loc,
|
||||
custom_vae,
|
||||
precision,
|
||||
max_len=max_length,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
use_base_vae=use_base_vae,
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
"InpaintPipeline",
|
||||
"OutpaintPipeline",
|
||||
]:
|
||||
clip, unet, vae, vae_encode = mlir_import()
|
||||
return cls(
|
||||
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
|
||||
)
|
||||
clip, unet, vae = mlir_import()
|
||||
return cls(vae, clip, get_tokenizer(), unet, scheduler)
|
||||
return cls(
|
||||
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
|
||||
)
|
||||
|
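A brief reading of the dispatch above (descriptive only; the enclosing classmethod's name is not visible in this hunk):

# - Image2ImagePipeline, InpaintPipeline and OutpaintPipeline also need the
#   VAE encoder, so vae_encode is threaded through as the first constructor
#   argument: cls(vae_encode, vae, clip, get_tokenizer(), unet, scheduler).
# - Any other pipeline (e.g. Text2ImagePipeline) falls through to the
#   five-argument form cls(vae, clip, get_tokenizer(), unet, scheduler).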
||||
@@ -3,8 +3,10 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
@@ -17,6 +19,10 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
@@ -41,6 +47,10 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"SharkEulerDiscrete"
|
||||
] = SharkEulerDiscreteScheduler.from_pretrained(
|
||||
|
||||
@@ -87,7 +87,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
if args.import_mlir:
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
self.scaling_model = compile_through_fx(
|
||||
scaling_model,
|
||||
@@ -105,15 +105,28 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
if args.import_mlir:
|
||||
_import(self)
|
||||
|
||||
else:
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_scale_model_input_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
|
||||
)
|
||||
try:
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_scale_model_input_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_step_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"failed to download model, falling back and using import_mlir"
|
||||
)
|
||||
args.import_mlir = True
|
||||
_import(self)
|
||||
|
||||
def scale_model_input(self, sample, timestep):
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
|
||||
@@ -24,4 +24,8 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
@@ -29,6 +29,14 @@
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
@@ -77,6 +85,126 @@
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
},
|
||||
"runwayml/stable-diffusion-inpainting": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
9,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-2-inpainting": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
9,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
|
||||
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
|
||||
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
|
||||
"stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting",
|
||||
"stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting",
|
||||
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
|
||||
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
|
||||
"openjourney/v1_4":"prompthero/openjourney",
|
||||
|
||||
@@ -22,8 +22,6 @@
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
|
||||
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
|
||||
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
|
||||
@@ -42,41 +40,52 @@
|
||||
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
|
||||
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
|
||||
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
|
||||
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
|
||||
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
|
||||
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
|
||||
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
|
||||
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
|
||||
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
|
||||
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
|
||||
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
|
||||
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
|
||||
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
|
||||
"openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
|
||||
"openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
|
||||
"openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
|
||||
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
|
||||
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
|
||||
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
|
||||
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
|
||||
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
|
||||
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
|
||||
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
|
||||
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
|
||||
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
|
||||
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
|
||||
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
|
||||
"stablediffusion/inpaint_v1/unet/fp16/length_77/untuned":"unet_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v1/unet/fp32/length_77/untuned":"unet_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v1/vae_encode/fp16/length_77/untuned":"vae_encode_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v1/vae_encode/fp32/length_77/untuned":"vae_encode_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v1/vae/fp16/length_77/untuned":"vae_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v1/vae/fp32/length_77/untuned":"vae_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v1/clip/fp32/length_77/untuned":"clip_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v2/unet/fp16/length_77/untuned":"unet_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v2/vae_encode/fp16/length_77/untuned":"vae_encode_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v2/vae/fp16/length_77/untuned":"vae_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v2/clip/fp32/length_77/untuned":"clip_inpaint_fp32",
|
||||
"anythingv3/v1_4/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
|
||||
"anythingv3/v1_4/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
|
||||
"anythingv3/v1_4/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
|
||||
"anythingv3/v1_4/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
|
||||
"anythingv3/v1_4/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
|
||||
"anythingv3/v1_4/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
|
||||
"anythingv3/v1_4/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
|
||||
"anythingv3/v1_4/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
|
||||
"anythingv3/v1_4/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
|
||||
"anythingv3/v1_4/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
|
||||
"anythingv3/v1_4/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
|
||||
"analogdiffusion/v1_4/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
|
||||
"analogdiffusion/v1_4/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
|
||||
"analogdiffusion/v1_4/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
|
||||
"analogdiffusion/v1_4/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
|
||||
"analogdiffusion/v1_4/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
|
||||
"analogdiffusion/v1_4/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
|
||||
"analogdiffusion/v1_4/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
|
||||
"analogdiffusion/v1_4/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
|
||||
"analogdiffusion/v1_4/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
|
||||
"analogdiffusion/v1_4/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
|
||||
"analogdiffusion/v1_4/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
|
||||
"openjourney/v1_4/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
|
||||
"openjourney/v1_4/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
|
||||
"openjourney/v1_4/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
|
||||
"openjourney/v1_4/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
|
||||
"openjourney/v1_4/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
|
||||
"openjourney/v1_4/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
|
||||
"openjourney/v1_4/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
|
||||
"dreamlike/v1_4/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
|
||||
"dreamlike/v1_4/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
|
||||
"dreamlike/v1_4/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
|
||||
"dreamlike/v1_4/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
|
||||
"dreamlike/v1_4/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
|
||||
"dreamlike/v1_4/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
|
||||
"dreamlike/v1_4/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
|
||||
}
|
||||
]
|
||||
|
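The keys in this mapping encode variant, base version, submodel, precision, token length, and tuning state as slash-separated segments, with an optional device suffix for device-specific tuned artifacts. A minimal sketch of how such a table could be queried once the resources JSON is loaded into a plain dict; the `model_db.json` path and the `lookup_artifact` helper are illustrative, not part of the actual SHARK API:

```python
import json


def lookup_artifact(table, variant, version, submodel,
                    precision, length, tuned, device=""):
    """Build the slash-separated key used by the mapping and return the
    shark_tank artifact name, preferring a device-specific tuned entry."""
    base = f"{variant}/{version}/{submodel}/{precision}/length_{length}"
    base += "/tuned" if tuned else "/untuned"
    if device and f"{base}/{device}" in table:
        return table[f"{base}/{device}"]
    return table[base]


# Example usage, assuming the mapping above was saved as model_db.json:
# table = json.load(open("model_db.json"))
# lookup_artifact(table, "stablediffusion", "v1_4", "vae", "fp16", 77, tuned=True)
# -> "vae_19dec_fp16_tuned"
```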
||||
@@ -45,12 +45,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}

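These entries pair a precision with the default IREE preprocessing pipeline to pass at compile time. A rough sketch of how a caller might pick the flag list for the active precision, assuming the fragment above sits inside an `opt_flags`-style dict keyed by tuning state and precision; the variable and function names here are illustrative:

```python
# Hypothetical in-memory shape mirroring the JSON fragment above.
opt_flags = {
    "untuned": {
        "fp16": {
            "default_compilation_flags": [
                "--iree-preprocessing-pass-pipeline=builtin.module("
                "func.func(iree-flow-detach-elementwise-from-named-ops,"
                "iree-preprocessing-pad-linalg-ops{pad-size=32}))"
            ]
        },
        "fp32": {
            "default_compilation_flags": [
                "--iree-preprocessing-pass-pipeline=builtin.module("
                "func.func(iree-flow-detach-elementwise-from-named-ops,"
                "iree-preprocessing-pad-linalg-ops{pad-size=16}))"
            ]
        },
    }
}


def flags_for(precision, tuned=False):
    """Return the default compilation flags for a precision/tuning combo."""
    key = "tuned" if tuned else "untuned"
    return opt_flags[key][precision]["default_compilation_flags"]


print(flags_for("fp16"))
```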
@@ -20,6 +20,22 @@ def get_device():
    return device


def get_device_args():
    device = get_device()
    device_spec_args = []
    if device == "cuda":
        from shark.iree_utils.gpu_utils import get_iree_gpu_args

        gpu_flags = get_iree_gpu_args()
        for flag in gpu_flags:
            device_spec_args.append(flag)
    elif device == "vulkan":
        device_spec_args.append(
            f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
        )
    return device, device_spec_args


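A short usage sketch of the helper above: the returned flag list can simply be appended to whatever extra arguments are handed to the compiler. `build_extra_args` is a hypothetical name used only for illustration and assumes `get_device_args` from the function above is in scope:

```python
def build_extra_args(base_flags):
    # get_device_args() returns the active device plus any device-specific
    # flags (CUDA SM flags or a Vulkan target triple).
    device, device_spec_args = get_device_args()
    return device, list(base_flags) + device_spec_args


# device, extra_args = build_extra_args(["--iree-vm-target-index-bits=64"])
```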
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
    from apps.stable_diffusion.src.models import (
@@ -54,7 +70,7 @@ def load_winograd_configs():
    config_bucket = "gs://shark_tank/sd_tuned/configs/"
    config_name = f"{args.annotation_model}_winograd_{device}.json"
    full_gs_url = config_bucket + config_name
    winograd_config_dir = f"{WORKDIR}configs/" + config_name
    winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
    print("Loading Winograd config file from ", winograd_config_dir)
    download_public_file(full_gs_url, winograd_config_dir, True)
    return winograd_config_dir
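The switch from string concatenation to `os.path.join` above avoids hard-coding `/` as the separator and a missing separator when `WORKDIR` does not end with a slash, which matters on Windows. A tiny, self-contained illustration of the difference; the paths are made up:

```python
import os

workdir = r"C:\Users\me\.local\shark_tank"   # hypothetical Windows WORKDIR
config_name = "unet_winograd_vulkan.json"

# Concatenation glues "configs/" straight onto WORKDIR and mixes separators.
concatenated = f"{workdir}configs/" + config_name

# os.path.join always inserts the platform-appropriate separator.
joined = os.path.join(workdir, "configs", config_name)

print(concatenated)
print(joined)
```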
@@ -62,20 +78,42 @@ def load_winograd_configs():
|
||||
|
||||
def load_lower_configs():
|
||||
from apps.stable_diffusion.src.models import get_variant_version
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
)
|
||||
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
base_model_id = args.hf_model_id
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
if base_model_id == "runwayml/stable-diffusion-v1-5":
|
||||
base_model_id = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
variant, version = get_variant_version(base_model_id)
|
||||
|
||||
config_bucket = "gs://shark_tank/sd_tuned_configs/"
|
||||
|
||||
device, device_spec_args = get_device_args()
|
||||
spec = ""
|
||||
if device_spec_args:
|
||||
spec = device_spec_args[-1].split("=")[-1].strip()
|
||||
if device == "vulkan":
|
||||
spec = spec.split("-")[0]
|
||||
|
||||
config_bucket = "gs://shark_tank/sd_tuned/configs/"
|
||||
config_version = version
|
||||
if variant in ["anythingv3", "analogdiffusion"]:
|
||||
args.max_length = 77
|
||||
config_version = "v1_4"
|
||||
if args.annotation_model == "vae":
|
||||
args.max_length = 77
|
||||
device = get_device()
|
||||
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_{args.precision}_{device}.json"
|
||||
)
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
|
||||
else:
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
|
||||
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading lowering config file from ", lowering_config_dir)
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
return lowering_config_dir
|
||||
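To make the naming scheme above concrete, here is the kind of file name this logic produces, sketched as plain string formatting. The values are examples only, not an exhaustive list of published configs:

```python
annotation_model = "unet"
precision = "fp16"
device = "vulkan"
version = "v2_1base"

# rdna3 and sm_80 are treated as the default target and get no spec suffix:
print(f"{annotation_model}_{version}_{precision}_{device}.json")
# -> unet_v2_1base_fp16_vulkan.json

# Other targets (e.g. rdna2) carry the spec in the file name:
spec = "rdna2"
print(f"{annotation_model}_{version}_{precision}_{device}_{spec}.json")
# -> unet_v2_1base_fp16_vulkan_rdna2.json

# Anything-v3 and Analog-Diffusion drop the version (they reuse the v1_4
# configs at max_length 77):
print(f"{annotation_model}_{precision}_{device}.json")
# -> unet_fp16_vulkan.json
```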
@@ -83,13 +121,6 @@ def load_lower_configs():
|
||||
|
||||
# Annotate the model with Winograd attribute on selected conv ops
|
||||
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
|
||||
if model_name.split("_")[-1] != "tuned":
|
||||
out_file_path = (
|
||||
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
)
|
||||
else:
|
||||
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
|
||||
|
||||
with create_context() as ctx:
|
||||
winograd_model = model_annotation(
|
||||
ctx,
|
||||
@@ -103,59 +134,41 @@ def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
|
||||
winograd_model.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(winograd_model))
|
||||
f.close()
|
||||
return bytecode, out_file_path
|
||||
if args.save_annotation:
|
||||
if model_name.split("_")[-1] != "tuned":
|
||||
out_file_path = os.path.join(
|
||||
args.annotation_output, model_name + "_tuned_torch.mlir"
|
||||
)
|
||||
else:
|
||||
out_file_path = os.path.join(
|
||||
args.annotation_output, model_name + "_torch.mlir"
|
||||
)
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(winograd_model))
|
||||
f.close()
|
||||
|
||||
return bytecode
|
||||
|
||||
|
||||
def dump_after_mlir(input_mlir, model_name, use_winograd):
|
||||
def dump_after_mlir(input_mlir, use_winograd):
|
||||
import iree.compiler as ireec
|
||||
|
||||
device, device_spec_args = get_device_args()
|
||||
if use_winograd:
|
||||
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline='builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
|
||||
"iree-linalg-ext-convert-conv2d-to-winograd))' "
|
||||
)
|
||||
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
else:
|
||||
dump_after = "iree-preprocessing-pad-linalg-ops"
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline='builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32}))' "
|
||||
)
|
||||
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
|
||||
device_spec_args = ""
|
||||
device = get_device()
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
gpu_flags = get_iree_gpu_args()
|
||||
for flag in gpu_flags:
|
||||
device_spec_args += flag + " "
|
||||
elif device == "vulkan":
|
||||
device_spec_args = (
|
||||
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
|
||||
)
|
||||
print("Applying tuned configs on", model_name)
|
||||
|
||||
run_cmd(
|
||||
f"iree-compile {input_mlir} "
|
||||
"--iree-input-type=tm_tensor "
|
||||
f"--iree-hal-target-backends={iree_target_map(device)} "
|
||||
f"{device_spec_args}"
|
||||
f"{preprocess_flag}"
|
||||
"--iree-stream-resource-index-bits=64 "
|
||||
"--iree-vm-target-index-bits=64 "
|
||||
f"--mlir-print-ir-after={dump_after} "
|
||||
"--compile-to=flow "
|
||||
f"2>{args.annotation_output}/dump_after_winograd.mlir "
|
||||
dump_module = ireec.compile_str(
|
||||
input_mlir,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=device_spec_args
|
||||
+ [
|
||||
preprocess_flag,
|
||||
"--compile-to=preprocessing",
|
||||
],
|
||||
)
|
||||
return dump_module
|
||||
|
||||
|
||||
# For Unet annotate the model with tuned lowering configs
|
||||
@@ -163,72 +176,63 @@ def annotate_with_lower_configs(
|
||||
input_mlir, lowering_config_dir, model_name, use_winograd
|
||||
):
|
||||
# Dump IR after padding/img2col/winograd passes
|
||||
dump_after_mlir(input_mlir, model_name, use_winograd)
|
||||
dump_module = dump_after_mlir(input_mlir, use_winograd)
|
||||
print("Applying tuned configs on", model_name)
|
||||
|
||||
# Annotate the model with lowering configs in the config file
|
||||
with create_context() as ctx:
|
||||
tuned_model = model_annotation(
|
||||
ctx,
|
||||
input_contents=f"{args.annotation_output}/dump_after_winograd.mlir",
|
||||
input_contents=dump_module,
|
||||
config_path=lowering_config_dir,
|
||||
search_op="all",
|
||||
)
|
||||
|
||||
# Remove the intermediate mlir and save the final annotated model
|
||||
os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
|
||||
if model_name.split("_")[-1] != "tuned":
|
||||
out_file_path = (
|
||||
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
)
|
||||
else:
|
||||
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
tuned_model.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(tuned_model))
|
||||
f.close()
|
||||
return bytecode, out_file_path
|
||||
if args.save_annotation:
|
||||
if model_name.split("_")[-1] != "tuned":
|
||||
out_file_path = (
|
||||
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
)
|
||||
else:
|
||||
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(tuned_model))
|
||||
f.close()
|
||||
|
||||
return bytecode
|
||||
|
||||
|
||||
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
|
||||
def sd_model_annotation(mlir_model, model_name):
|
||||
device = get_device()
|
||||
if args.annotation_model == "unet" and device == "vulkan":
|
||||
use_winograd = True
|
||||
winograd_config_dir = load_winograd_configs()
|
||||
winograd_model, model_path = annotate_with_winograd(
|
||||
winograd_model = annotate_with_winograd(
|
||||
mlir_model, winograd_config_dir, model_name
|
||||
)
|
||||
lowering_config_dir = load_lower_configs()
|
||||
tuned_model, output_path = annotate_with_lower_configs(
|
||||
model_path, lowering_config_dir, model_name, use_winograd
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
winograd_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
elif args.annotation_model == "vae" and device == "vulkan":
|
||||
use_winograd = True
|
||||
winograd_config_dir = load_winograd_configs()
|
||||
tuned_model, output_path = annotate_with_winograd(
|
||||
tuned_model = annotate_with_winograd(
|
||||
mlir_model, winograd_config_dir, model_name
|
||||
)
|
||||
else:
|
||||
use_winograd = False
|
||||
if model_from_tank:
|
||||
mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
|
||||
else:
|
||||
# Just use this function to convert bytecode to string
|
||||
orig_model, model_path = annotate_with_winograd(
|
||||
mlir_model, "", model_name
|
||||
)
|
||||
mlir_model = model_path
|
||||
lowering_config_dir = load_lower_configs()
|
||||
tuned_model, output_path = annotate_with_lower_configs(
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
mlir_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
print(f"Saved the annotated mlir in {output_path}.")
|
||||
return tuned_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlir_model, model_name = load_model_from_tank()
|
||||
sd_model_annotation(mlir_model, model_name, model_from_tank=True)
|
||||
sd_model_annotation(mlir_model, model_name)
|
||||
|
||||
@@ -17,18 +17,24 @@ p = argparse.ArgumentParser(
|
||||
p.add_argument(
|
||||
"-p",
|
||||
"--prompts",
|
||||
action="append",
|
||||
default=[],
|
||||
nargs="+",
|
||||
default=["cyberpunk forest by Salvador Dali"],
|
||||
help="text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative_prompts",
|
||||
nargs="+",
|
||||
default=[""],
|
||||
default=["trees, green"],
|
||||
help="text you don't want to see in the generated image.",
|
||||
)
|
||||
|
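The change from `action="append"` to `nargs="+"` alters the shape of `args.prompts`: appending produced a list of lists (one inner list per `-p` occurrence), while `nargs="+"` collects all space-separated values after a single flag into one flat list, and the default is now a non-empty prompt. A small, stand-alone illustration; the parser here is throwaway, not the app's real parser:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "-p",
    "--prompts",
    nargs="+",
    default=["cyberpunk forest by Salvador Dali"],
)

print(p.parse_args([]).prompts)
# ['cyberpunk forest by Salvador Dali']

print(p.parse_args(["-p", "a red fox", "a blue bird"]).prompts)
# ['a red fox', 'a blue bird']
```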
||||
p.add_argument(
|
||||
"--img_path",
|
||||
type=str,
|
||||
help="Path to the image input for img2img/inpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
@@ -39,8 +45,8 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed to use.",
|
||||
default=-1,
|
||||
help="the seed to use. -1 for a random one.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -48,13 +54,14 @@ p.add_argument(
|
||||
type=int,
|
||||
default=1,
|
||||
choices=range(1, 4),
|
||||
help="the number of inferences to be made in a single `run`.",
|
||||
help="the number of inferences to be made in a single `batch_count`.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(384, 768, 8),
|
||||
help="the height of the output image.",
|
||||
)
|
||||
|
||||
@@ -62,6 +69,7 @@ p.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(384, 768, 8),
|
||||
help="the width of the output image.",
|
||||
)
|
||||
|
||||
@@ -79,6 +87,81 @@ p.add_argument(
|
||||
help="max length of the tokenizer output, options are 64 and 77.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="the strength of change applied on the given input image for img2img",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Inpainting and Outpainting Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--mask_path",
|
||||
type=str,
|
||||
help="Path to the mask image input for inpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--pixels",
|
||||
type=int,
|
||||
default=128,
|
||||
choices=range(8, 256, 8),
|
||||
help="Number of expended pixels for one direction for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--mask_blur",
|
||||
type=int,
|
||||
default=8,
|
||||
choices=range(0, 64),
|
||||
help="Number of blur pixels for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--left",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend left for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--right",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend right for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--top",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend top for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--bottom",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend bottom for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--noise_q",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--color_variation",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Color variation for outpainting (min=0.0, max=1.0)",
|
||||
)
|
||||
|
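Taken together, the direction flags and `--pixels` determine how much the canvas grows before the outpainting pass fills the new border. A back-of-the-envelope sketch of the resulting canvas size under the defaults above; this is an illustration of the parameters, not the app's actual outpainting code:

```python
def expanded_canvas(width, height, pixels, left, right, top, bottom):
    """Return the canvas size after growing each selected edge by `pixels`."""
    new_w = width + pixels * (int(left) + int(right))
    new_h = height + pixels * (int(top) + int(bottom))
    return new_w, new_h


# 512x512 image, expand 128 px to the left and right only:
print(expanded_canvas(512, 512, 128, left=True, right=True, top=False, bottom=False))
# (768, 512)
```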
||||
##############################################################################
|
||||
### Model Config and Usage Params
|
||||
##############################################################################
|
||||
@@ -148,10 +231,10 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--runs",
|
||||
"--batch_count",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of images to be generated with random seeds in single execution",
|
||||
help="number of batch to be generated with random seeds in single execution",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -161,6 +244,13 @@ p.add_argument(
|
||||
help="Path to SD's .ckpt file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--custom_vae",
|
||||
type=str,
|
||||
default="",
|
||||
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hf_model_id",
|
||||
type=str,
|
||||
@@ -169,10 +259,17 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--enable_stack_trace",
|
||||
"--low_cpu_mem_usage",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Enable showing the stack trace when retrying the base model configuration",
|
||||
help="Use the accelerate package to reduce cpu memory consumption",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--attention_slicing",
|
||||
type=str,
|
||||
default="none",
|
||||
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
@@ -180,7 +277,7 @@ p.add_argument(
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree-vulkan-target-triple",
|
||||
"--iree_vulkan_target_triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan",
|
||||
@@ -279,7 +376,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--write_metadata_to_png",
|
||||
default=False,
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
|
||||
)
|
||||
@@ -292,7 +389,7 @@ p.add_argument(
|
||||
"--progress_bar",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for removing the pregress bar animation during image generation",
|
||||
help="flag for removing the progress bar animation during image generation",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -336,10 +433,10 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_winograd",
|
||||
"--save_annotation",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Apply Winograd on selected conv ops.",
|
||||
help="Save annotated mlir file",
|
||||
)
|
||||
|
||||
args, unknown = p.parse_known_args()
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import os
|
||||
import gc
|
||||
import json
|
||||
import re
|
||||
from PIL import PngImagePlugin
|
||||
from datetime import datetime as dt
|
||||
from csv import DictWriter
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from random import randint
|
||||
@@ -14,26 +18,30 @@ from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
import sys, functools, operator
|
||||
import sys
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
|
||||
|
||||
def get_vmfb_path_name(model_name):
|
||||
def get_extended_name(model_name):
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
return [vmfb_path, extended_name]
|
||||
return extended_name
|
||||
|
||||
|
||||
def get_vmfb_path_name(model_name):
|
||||
vmfb_path = os.path.join(os.getcwd(), model_name + ".vmfb")
|
||||
return vmfb_path
|
||||
|
||||
|
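A quick illustration of the two helpers above: device URIs such as `vulkan://0` are flattened so the result is a filesystem-safe name, and the `.vmfb` is cached next to the current working directory. The model name and device below are made up; `extended_name` mirrors `get_extended_name` rather than replacing it:

```python
import os


def extended_name(model_name, device):
    # Mirrors get_extended_name: "vulkan://0" -> "vulkan-0"
    safe_device = device if "://" not in device else "-".join(device.split("://"))
    return f"{model_name}_{safe_device}"


name = extended_name("unet_512_512_fp16", "vulkan://0")
print(name)                                       # unet_512_512_fp16_vulkan-0
print(os.path.join(os.getcwd(), name + ".vmfb"))  # where the vmfb is cached
```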
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
[vmfb_path, extended_name] = get_vmfb_path_name(model_name)
|
||||
vmfb_path = get_vmfb_path_name(model_name)
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
@@ -47,7 +55,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_name, extra_args
|
||||
os.getcwd(), model_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
@@ -231,9 +239,12 @@ def set_init_device_flags():
|
||||
args.max_length = 64
|
||||
|
||||
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
|
||||
base_model_id = args.hf_model_id
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
|
||||
if (
|
||||
args.hf_model_id == "prompthero/openjourney"
|
||||
or args.ckpt_loc != ""
|
||||
"inpainting" in args.hf_model_id
|
||||
or args.precision != "fp16"
|
||||
or args.height != 512
|
||||
or args.width != 512
|
||||
@@ -242,13 +253,24 @@ def set_init_device_flags():
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif (
|
||||
"vulkan" in args.device
|
||||
and "rdna3" not in args.iree_vulkan_target_triple
|
||||
elif args.ckpt_loc != "" and base_model_id not in [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
]:
|
||||
args.use_tuned = False
|
||||
|
||||
elif "vulkan" in args.device and not any(
|
||||
x in args.iree_vulkan_target_triple for x in ["rdna2", "rdna3"]
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
|
||||
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
|
||||
args.use_tuned = False
|
||||
|
||||
elif args.use_base_vae and args.hf_model_id not in [
|
||||
@@ -258,7 +280,7 @@ def set_init_device_flags():
|
||||
args.use_tuned = False
|
||||
|
||||
if args.use_tuned:
|
||||
print(f"Using tuned models for {args.hf_model_id}/fp16/{args.device}.")
|
||||
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
|
||||
else:
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
@@ -274,12 +296,35 @@ def set_init_device_flags():
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
]:
|
||||
args.import_mlir = True
|
||||
|
||||
elif args.height != 512 or args.width != 512 or args.batch_size != 1:
|
||||
args.import_mlir = True
|
||||
|
||||
elif args.use_tuned and args.hf_model_id in [
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
"prompthero/openjourney",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
]:
|
||||
args.import_mlir = True
|
||||
|
||||
elif (
|
||||
args.use_tuned
|
||||
and "vulkan" in args.device
|
||||
and "rdna2" in args.iree_vulkan_target_triple
|
||||
):
|
||||
args.import_mlir = True
|
||||
|
||||
elif (
|
||||
args.use_tuned
|
||||
and "cuda" in args.device
|
||||
and get_cuda_sm_cc() == "sm_89"
|
||||
):
|
||||
args.import_mlir = True
|
||||
|
||||
|
||||
# Utility to get list of devices available.
|
||||
def get_available_devices():
|
||||
@@ -354,6 +399,11 @@ def get_opt_flags(model, precision="fp16"):
|
||||
return iree_flags
|
||||
|
||||
|
||||
def get_path_stem(path):
|
||||
path = Path(path)
|
||||
return path.stem
|
||||
|
||||
|
||||
def get_path_to_diffusers_checkpoint(custom_weights):
|
||||
path = Path(custom_weights)
|
||||
diffusers_path = path.parent.absolute()
|
||||
@@ -395,7 +445,7 @@ def preprocessCKPT(custom_weights):
|
||||
|
||||
|
||||
def load_vmfb(vmfb_path, model, precision):
|
||||
model = "vae" if "base_vae" in model else model
|
||||
model = "vae" if "base_vae" in model or "vae_encode" in model else model
|
||||
precision = "fp32" if "clip" in model else precision
|
||||
extra_args = get_opt_flags(model, precision)
|
||||
shark_module = SharkInference(mlir_module=None, device=args.device)
|
||||
@@ -403,24 +453,34 @@ def load_vmfb(vmfb_path, model, precision):
|
||||
return shark_module
|
||||
|
||||
|
||||
# This utility returns vmfbs of Clip, Unet and Vae, in case all three of them
|
||||
# This utility returns vmfbs of Clip, Unet, Vae and Vae_encode, in case all of them
|
||||
# are present; deletes them otherwise.
|
||||
def fetch_or_delete_vmfbs(basic_model_name, use_base_vae, precision="fp32"):
|
||||
model_name = ["clip", "unet", "base_vae" if use_base_vae else "vae"]
|
||||
def fetch_or_delete_vmfbs(
|
||||
extended_model_name, need_vae_encode, precision="fp32"
|
||||
):
|
||||
vmfb_path = [
|
||||
get_vmfb_path_name(model + basic_model_name)[0] for model in model_name
|
||||
get_vmfb_path_name(extended_model_name[model])
|
||||
for model in extended_model_name
|
||||
]
|
||||
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
|
||||
all_vmfb_present = functools.reduce(operator.__and__, vmfb_present)
|
||||
compiled_models = [None] * 3
|
||||
all_vmfb_present = True
|
||||
compiled_models = []
|
||||
for i in range(3):
|
||||
all_vmfb_present = all_vmfb_present and vmfb_present[i]
|
||||
compiled_models.append(None)
|
||||
if need_vae_encode:
|
||||
all_vmfb_present = all_vmfb_present and vmfb_present[3]
|
||||
compiled_models.append(None)
|
||||
|
||||
# We need to delete vmfbs only if some of the models were compiled.
|
||||
if not all_vmfb_present:
|
||||
for i in range(len(vmfb_path)):
|
||||
for i in range(len(compiled_models)):
|
||||
if vmfb_present[i]:
|
||||
os.remove(vmfb_path[i])
|
||||
print("Deleted: ", vmfb_path[i])
|
||||
else:
|
||||
for i in range(len(vmfb_path)):
|
||||
model_name = [model for model in extended_model_name.keys()]
|
||||
for i in range(len(compiled_models)):
|
||||
compiled_models[i] = load_vmfb(
|
||||
vmfb_path[i], model_name[i], precision
|
||||
)
|
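For context, a sketch of how the reworked `fetch_or_delete_vmfbs` is meant to be called: `extended_model_name` is now a dict mapping each submodel to its device-extended name, and the fourth entry is only consulted when the pipeline needs the VAE encoder. The exact names below are examples, not real artifact names:

```python
# Hypothetical call site; the extended names are examples only.
extended_model_name = {
    "clip": "clip_512_512_fp32_vulkan-0",
    "unet": "unet_512_512_fp16_vulkan-0",
    "vae": "vae_512_512_fp16_vulkan-0",
    "vae_encode": "vae_encode_512_512_fp16_vulkan-0",
}
# compiled = fetch_or_delete_vmfbs(extended_model_name, need_vae_encode=True,
#                                  precision="fp16")
# `compiled` holds loaded SharkInference modules, or Nones (and the partial
# vmfbs are deleted) when not every expected vmfb was found on disk.
```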
||||
@@ -458,3 +518,95 @@ def sanitize_seed(seed):
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
return seed
|
||||
|
||||
|
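This pairs with the new `--seed` default of `-1`: any value outside the unsigned 32-bit range (including `-1`) is replaced with a random seed, so repeated runs differ unless an explicit seed is given. A self-contained sketch of the same clamping behaviour, with the range bounds spelled out since they are outside the hunk shown above:

```python
from random import randint


def sanitize_seed(seed):
    uint32_min, uint32_max = 0, 2**32 - 1
    if seed < uint32_min or seed >= uint32_max:
        seed = randint(uint32_min, uint32_max)
    return seed


print(sanitize_seed(-1))   # some random uint32 seed
print(sanitize_seed(42))   # 42, passed through unchanged
```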
||||
# clear all the cached objects to recompile cleanly.
|
||||
def clear_all():
|
||||
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
|
||||
from glob import glob
|
||||
import shutil
|
||||
|
||||
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
|
||||
for vmfb in vmfbs:
|
||||
if os.path.exists(vmfb):
|
||||
os.remove(vmfb)
|
||||
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
|
||||
# TODO: Remove this once we have better weight update logic.
|
||||
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
|
||||
for yaml in inference_yaml:
|
||||
if os.path.exists(yaml):
|
||||
os.remove(yaml)
|
||||
home = os.path.expanduser("~")
|
||||
if os.name == "nt": # Windows
|
||||
appdata = os.getenv("LOCALAPPDATA")
|
||||
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
|
||||
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
|
||||
elif os.name == "posix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed):
|
||||
output_path = args.output_dir if args.output_dir else Path.cwd()
|
||||
generated_imgs_path = Path(
|
||||
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
|
||||
)
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
|
||||
out_img_name = (
|
||||
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
img_model = args.hf_model_id
|
||||
if args.ckpt_loc:
|
||||
img_model = os.path.basename(args.ckpt_loc)
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
else:
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if args.write_metadata_to_png:
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
if args.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {args.output_img_format} is not supported yet."
|
||||
"Image saved as png instead. Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
new_entry = {
|
||||
"VARIANT": img_model,
|
||||
"SCHEDULER": args.scheduler,
|
||||
"PROMPT": args.prompts[0],
|
||||
"NEG_PROMPT": args.negative_prompts[0],
|
||||
"SEED": img_seed,
|
||||
"CFG_SCALE": args.guidance_scale,
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
"HEIGHT": args.height,
|
||||
"WIDTH": args.width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
}
|
||||
|
||||
with open(csv_path, "a") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
if args.save_metadata_to_json:
|
||||
del new_entry["OUTPUT"]
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
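Since the generation parameters are written into a `parameters` PNG text chunk, they can be read back later with Pillow, which exposes PNG text chunks through the opened image's `info` mapping. A small sketch; the file name is an example:

```python
from PIL import Image

with Image.open("generated_imgs/20230224/cyberpunk_forest_42_230224_120000.png") as im:
    # tEXt/iTXt chunks written via PngInfo show up in the info dict.
    print(im.info.get("parameters", "<no metadata found>"))
```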
||||
@@ -1,70 +0,0 @@
# Stable Diffusion optimized for AMD RDNA2/RDNA3 GPUs

Before you start, please be aware that this is beta software that relies on a special AMD driver. Like all Stable Diffusion GUIs published so far, you need some technical expertise to set it up. We apologize in advance if you bump into issues. If that happens, please don't hesitate to ask our Discord community for help! Please be assured that we (Nod and AMD) are working hard to improve the user experience in the coming months.
If it works well for you, please "star" the following GitHub projects... this is one of the best ways to help and spread the word!

* https://github.com/nod-ai/SHARK
* https://github.com/iree-org/iree

## Install this specific AMD driver (the latest AMD driver may not have all the fixes)

### AMD KB drivers for RDNA2 and RDNA3

*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*

First, for RDNA2 users, download this special driver to a folder of your choice. We recommend you keep the installation files around, since you may need to re-install it later if Windows Update decides to overwrite it:
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree

For RDNA3, the latest driver 23.1.2 supports MLIR/IREE as well: https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-1-2-kb

KNOWN ISSUES with this special AMD driver:
* `Windows Update` may (depending on how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work and then, a few days later, slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix this problem, please check the installed driver version and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
* Some people using this special driver experience mouse pointer accuracy issues, especially if using a larger-than-default mouse pointer: the clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".

## Installation

Download the latest Windows SHARK SD binary [492 here](https://github.com/nod-ai/SHARK/releases/download/20230203.492/shark_sd_20230203_492.exe) to a folder of your choice. If you want nightly builds, you can look for them on the GitHub releases page.

Notes:
* We recommend that you download each new EXE version into a new folder. If you download it into the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR which can be outdated if you run a new EXE from the same folder. You can use the `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you:
  * clear all the local artifacts with `--clear_all`, OR
  * clear the Vulkan shader cache: for Windows users this can be done by clearing the contents of `C:\Users\%username%\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
  * clear the `huggingface` cache. In Windows, this is `C:\Users\%username%\.cache\huggingface`.

## Running

* Open a Command Prompt or PowerShell terminal and change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (If it always works for you without error, you may simply double-click the EXE to start the web browser.)
* The first run may take about 10-15 minutes while the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* If successful, you will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/?__theme=dark.

## Stopping

* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment. The application should stop.
* Please make sure to do the above step before you attempt to update the EXE to a new version.

# Results

<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">

Here are some samples generated:





The output on a 7900XTX would look like:

```shell
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590

Total image generation time: 2.5788655281066895sec
```

Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
@@ -1,14 +1,21 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import glob
|
||||
|
||||
if "AMD_ENABLE_LLPC" not in os.environ:
|
||||
os.environ["AMD_ENABLE_LLPC"] = "1"
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
|
||||
import gradio as gr
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
clear_gradio_tmp_imgs_folder,
|
||||
)
|
||||
|
||||
# clear all gradio tmp images from the last session
|
||||
clear_gradio_tmp_imgs_folder()
|
||||
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
@@ -18,245 +25,23 @@ def resource_path(relative_path):
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.src import (
|
||||
prompt_examples,
|
||||
args,
|
||||
get_available_devices,
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_web,
|
||||
img2img_web,
|
||||
inpaint_web,
|
||||
outpaint_web,
|
||||
)
|
||||
from apps.stable_diffusion.scripts import txt2img_inf
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
|
||||
sd_web = gr.TabbedInterface(
|
||||
[txt2img_web, img2img_web, inpaint_web, outpaint_web],
|
||||
["Text-to-Image", "Image-to-Image", "Inpainting", "Outpainting"],
|
||||
css=dark_theme,
|
||||
)
|
||||
|
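The web UI is now assembled from the per-task Blocks via `gr.TabbedInterface`, which the tail of this file queues and launches. A condensed, stand-alone sketch of that flow with placeholder Blocks; the real app passes the imported txt2img/img2img/inpaint/outpaint Blocks and its CSS file, and the port shown matches the documented 8080 default:

```python
import gradio as gr

# Placeholder tabs standing in for txt2img_web, img2img_web, etc.
with gr.Blocks() as tab_a:
    gr.Markdown("Text-to-Image placeholder")
with gr.Blocks() as tab_b:
    gr.Markdown("Image-to-Image placeholder")

sd_web = gr.TabbedInterface([tab_a, tab_b], ["Text-to-Image", "Image-to-Image"])
sd_web.queue()
sd_web.launch(server_name="0.0.0.0", server_port=8080, inbrowser=False)
```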
||||
|
||||
demo_css = resource_path("css/sd_dark_theme.css")
|
||||
|
||||
|
||||
with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
logo2 = Image.open(sdlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=100)
|
||||
with gr.Column(scale=5, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=logo2,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="demo_title",
|
||||
).style(width=150, height=100)
|
||||
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value="None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value="cyberpunk forest by Salvador Dali",
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value="trees, green",
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value="SharkEulerDiscrete",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=512, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=512, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp16",
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=64,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=50, step=1, label="Steps"
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=7.5,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
10,
|
||||
value=1,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=1,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(value=-1, precision=0, label="Seed")
|
||||
available_devices = get_available_devices()
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image")
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2], height="auto")
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=4,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=txt2img_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
|
||||
shark_web.queue()
|
||||
shark_web.launch(
|
||||
sd_web.queue()
|
||||
sd_web.launch(
|
||||
share=args.share,
|
||||
inbrowser=True,
|
||||
server_name="0.0.0.0",
|
||||
|
||||
Binary file not shown. (Before: 33 KiB)
Binary file not shown. (Before: 5.0 KiB)

apps/stable_diffusion/web/ui/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_web
from apps.stable_diffusion.web.ui.img2img_ui import img2img_web
from apps.stable_diffusion.web.ui.inpaint_ui import inpaint_web
from apps.stable_diffusion.web.ui.outpaint_ui import outpaint_web
@@ -144,19 +144,30 @@
|
||||
--dataset-table-border-hover: var(--color-grey-800);
|
||||
}
|
||||
|
||||
/* SHARK theme customization */
|
||||
|
||||
.gradio-container {
|
||||
/* SHARK theme */
|
||||
body {
|
||||
background-color: var(--color-background-primary);
|
||||
}
|
||||
|
||||
/* display in full width for desktop devices */
|
||||
@media (min-width: 1536px)
|
||||
{
|
||||
.gradio-container {
|
||||
max-width: var(--size-full) !important;
|
||||
}
|
||||
}
|
||||
|
||||
.gradio-container .contain {
|
||||
padding: 0 var(--size-4) !important;
|
||||
}
|
||||
|
||||
.container {
|
||||
background-color: black !important;
|
||||
padding-top: 20px !important;
|
||||
padding-top: var(--size-5) !important;
|
||||
}
|
||||
|
||||
#ui_title {
|
||||
padding: 10px !important;
|
||||
padding: var(--size-2) 0 0 var(--size-1);
|
||||
}
|
||||
|
||||
#top_logo {
|
||||
@@ -165,15 +176,6 @@
|
||||
border: 0;
|
||||
}
|
||||
|
||||
#demo_title {
|
||||
background-color: var(--color-background-primary);
|
||||
border-radius: 0 !important;
|
||||
border: 0;
|
||||
padding-top: 15px;
|
||||
padding-bottom: 0px;
|
||||
width: 350px !important;
|
||||
}
|
||||
|
||||
#demo_title_outer {
|
||||
border-radius: 0;
|
||||
}
|
||||
@@ -182,7 +184,7 @@
|
||||
border-radius: 0 !important
|
||||
}
|
||||
|
||||
#prompt_box textarea {
|
||||
#prompt_box textarea, #negative_prompt_box textarea {
|
||||
background-color: var(--color-background-primary) !important;
|
||||
}
|
||||
|
||||
@@ -196,7 +198,7 @@
|
||||
|
||||
#ui_body {
|
||||
background-color: var(--color-background-secondary) !important;
|
||||
padding: 10px !important;
|
||||
padding: var(--size-2) !important;
|
||||
border-radius: 0.5em !important;
|
||||
}
|
||||
|
||||
@@ -207,3 +209,7 @@
|
||||
footer {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
#gallery + div {
|
||||
border-radius: 0 !important;
|
||||
}
|
||||
apps/stable_diffusion/web/ui/img2img_ui.py (new file, 231 lines)
@@ -0,0 +1,231 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import img2img_inf
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
init_image = gr.Image(label="Input Image", type="pil")
|
||||
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value="PNDM",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"DPMSolverMultistep",
|
||||
"EulerAncestralDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.strength,
|
||||
step=0.01,
|
||||
label="Strength",
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2])
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=img2img_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
init_image,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
apps/stable_diffusion/web/ui/inpaint_ui.py (new file, 224 lines)
@@ -0,0 +1,224 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import inpaint_inf
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
init_image = gr.Image(
|
||||
label="Masked Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
type="pil",
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value="PNDM",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"DPMSolverMultistep",
|
||||
"EulerAncestralDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2])
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=inpaint_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
init_image,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
|
(binary image file changed: 10 KiB before, 10 KiB after)
apps/stable_diffusion/web/ui/outpaint_ui.py (new file, 260 lines)
@@ -0,0 +1,260 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import outpaint_inf
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
init_image = gr.Image(label="Input Image", type="filepath")
|
||||
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value="PNDM",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"DPMSolverMultistep",
|
||||
"EulerAncestralDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
pixels = gr.Slider(
|
||||
8,
|
||||
256,
|
||||
value=args.pixels,
|
||||
step=8,
|
||||
label="Pixels to expand",
|
||||
)
|
||||
mask_blur = gr.Slider(
|
||||
0,
|
||||
64,
|
||||
value=args.mask_blur,
|
||||
step=1,
|
||||
label="Mask blur",
|
||||
)
|
||||
with gr.Row():
|
||||
directions = gr.CheckboxGroup(
|
||||
label="Outpainting direction",
|
||||
choices=["left", "right", "up", "down"],
|
||||
value=["left", "right", "up", "down"],
|
||||
)
|
||||
with gr.Row():
|
||||
noise_q = gr.Slider(
|
||||
0.0,
|
||||
4.0,
|
||||
value=1.0,
|
||||
step=0.01,
|
||||
label="Fall-off exponent (lower=higher detail)",
|
||||
)
|
||||
color_variation = gr.Slider(
|
||||
0.0,
|
||||
1.0,
|
||||
value=0.05,
|
||||
step=0.01,
|
||||
label="Color variation",
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=20, step=1, label="Steps"
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2])
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=outpaint_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
init_image,
|
||||
pixels,
|
||||
mask_blur,
|
||||
directions,
|
||||
noise_q,
|
||||
color_variation,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
apps/stable_diffusion/web/ui/txt2img_ui.py (new file, 229 lines)
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import txt2img_inf
|
||||
from apps.stable_diffusion.src import prompt_examples, args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"KDPM2Discrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2])
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=txt2img_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
apps/stable_diffusion/web/ui/utils.py (new file, 15 lines)
@@ -0,0 +1,15 @@
|
||||
import os
|
||||
import sys
|
||||
from apps.stable_diffusion.src import get_available_devices
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
available_devices = get_available_devices()
|
||||
apps/stable_diffusion/web/utils/gradio_configs.py (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
import os
|
||||
import tempfile
|
||||
import gradio
|
||||
from os import listdir
|
||||
|
||||
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
|
||||
|
||||
# Clear all gradio tmp images
|
||||
def clear_gradio_tmp_imgs_folder():
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
return
|
||||
for fileName in listdir(gradio_tmp_imgs_folder):
|
||||
# Delete tmp png files
|
||||
if fileName.startswith("tmp") and fileName.endswith(".png"):
|
||||
os.remove(gradio_tmp_imgs_folder + fileName)
|
||||
|
||||
|
||||
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
os.mkdir(gradio_tmp_imgs_folder)
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
|
||||
)
|
||||
pil_image.save(file_obj)
|
||||
return file_obj
|
||||
|
||||
|
||||
# Register save_pil_to_file override
|
||||
gradio.processing_utils.save_pil_to_file = save_pil_to_file
|
||||
@@ -30,9 +30,15 @@ def compare_images(new_filename, golden_filename):
|
||||
diff = np.abs(new - golden)
|
||||
mean = np.mean(diff)
|
||||
if mean > 0.1:
|
||||
subprocess.run(
|
||||
["gsutil", "cp", new_filename, "gs://shark_tank/testdata/builder/"]
|
||||
)
|
||||
if os.name != "nt":
|
||||
subprocess.run(
|
||||
[
|
||||
"gsutil",
|
||||
"cp",
|
||||
new_filename,
|
||||
"gs://shark_tank/testdata/builder/",
|
||||
]
|
||||
)
|
||||
raise SystemExit("new and golden not close")
|
||||
else:
|
||||
print("SUCCESS")
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import os
|
||||
from sys import executable
|
||||
import subprocess
|
||||
from apps.stable_diffusion.src.utils.resources import (
|
||||
get_json_file,
|
||||
)
|
||||
from datetime import datetime as dt
|
||||
from shark.shark_downloader import download_public_file
|
||||
from image_comparison import compare_images
|
||||
import argparse
|
||||
from glob import glob
|
||||
import shutil
|
||||
import requests
|
||||
|
||||
model_config_dicts = get_json_file(
|
||||
os.path.join(
|
||||
@@ -17,51 +20,179 @@ model_config_dicts = get_json_file(
|
||||
)
|
||||
|
||||
|
||||
def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir):
|
||||
with open(filename, "r+") as f:
|
||||
lines = f.readlines()
|
||||
metrics = {}
|
||||
vals_to_read = [
|
||||
"Clip Inference time",
|
||||
"Average step",
|
||||
"VAE Inference time",
|
||||
"Total image generation",
|
||||
]
|
||||
for line in lines:
|
||||
for val in vals_to_read:
|
||||
if val in line:
|
||||
metrics[val] = line.split(" ")[-1].strip("\n")
|
||||
|
||||
metrics["Average step"] = metrics["Average step"].strip("ms/it")
|
||||
metrics["Total image generation"] = metrics[
|
||||
"Total image generation"
|
||||
].strip("sec")
|
||||
metrics["device"] = device
|
||||
metrics["use_tune"] = use_tune
|
||||
metrics["model_name"] = model_name
|
||||
metrics["import_mlir"] = import_mlir
|
||||
metrics["command"] = command
|
||||
return metrics
|
||||
|
||||
|
||||
def get_inpaint_inputs():
|
||||
os.mkdir("./test_images/inputs")
|
||||
img_url = (
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve"
|
||||
"/main/stable_diffusion_inpaint/input_bench_image.png"
|
||||
)
|
||||
mask_url = (
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve"
|
||||
"/main/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
)
|
||||
img = requests.get(img_url)
|
||||
mask = requests.get(mask_url)
|
||||
open("./test_images/inputs/image.png", "wb").write(img.content)
|
||||
open("./test_images/inputs/mask.png", "wb").write(mask.content)
|
||||
|
||||
|
||||
def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
# Get golden values from tank
|
||||
shutil.rmtree("./test_images", ignore_errors=True)
|
||||
model_metrics = []
|
||||
os.mkdir("./test_images")
|
||||
os.mkdir("./test_images/golden")
|
||||
get_inpaint_inputs()
|
||||
hf_model_names = model_config_dicts[0].values()
|
||||
tuned_options = ["--no-use_tuned", "use_tuned"]
|
||||
tuned_options = ["--no-use_tuned", "--use_tuned"]
|
||||
import_options = ["--import_mlir", "--no-import_mlir"]
|
||||
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
|
||||
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
if os.name == "nt":
|
||||
prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
|
||||
inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
|
||||
if beta:
|
||||
extra_flags.append("--beta_models=True")
|
||||
for model_name in hf_model_names:
|
||||
for use_tune in tuned_options:
|
||||
command = [
|
||||
"python",
|
||||
"apps/stable_diffusion/scripts/txt2img.py",
|
||||
"--device=" + device,
|
||||
"--prompt=cyberpunk forest by Salvador Dali",
|
||||
"--output_dir="
|
||||
+ os.path.join(os.getcwd(), "test_images", model_name),
|
||||
"--hf_model_id=" + model_name,
|
||||
use_tune,
|
||||
extra_flags.append("--no-progress_bar")
|
||||
to_skip = [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
]
|
||||
for import_opt in import_options:
|
||||
for model_name in hf_model_names:
|
||||
if model_name in to_skip:
|
||||
continue
|
||||
for use_tune in tuned_options:
|
||||
command = (
|
||||
[
|
||||
executable, # executable is the python from the venv used to run this
|
||||
"apps/stable_diffusion/scripts/txt2img.py",
|
||||
"--device=" + device,
|
||||
prompt_text,
|
||||
"--negative_prompts=" + '""',
|
||||
"--seed=42",
|
||||
import_opt,
|
||||
"--output_dir="
|
||||
+ os.path.join(os.getcwd(), "test_images", model_name),
|
||||
"--hf_model_id=" + model_name,
|
||||
use_tune,
|
||||
]
|
||||
if "inpainting" not in model_name
|
||||
else [
|
||||
executable,
|
||||
"apps/stable_diffusion/scripts/inpaint.py",
|
||||
"--device=" + device,
|
||||
inpaint_prompt_text,
|
||||
"--negative_prompts=" + '""',
|
||||
"--img_path=./test_images/inputs/image.png",
|
||||
"--mask_path=./test_images/inputs/mask.png",
|
||||
"--seed=42",
|
||||
"--import_mlir",
|
||||
"--output_dir="
|
||||
+ os.path.join(os.getcwd(), "test_images", model_name),
|
||||
"--hf_model_id=" + model_name,
|
||||
use_tune,
|
||||
]
|
||||
)
|
||||
command += extra_flags
|
||||
if os.name == "nt":
|
||||
command = " ".join(command)
|
||||
dumpfile_name = "_".join(model_name.split("/")) + ".txt"
|
||||
dumpfile_name = os.path.join(os.getcwd(), dumpfile_name)
|
||||
with open(dumpfile_name, "w+") as f:
|
||||
generated_image = not subprocess.call(
|
||||
command,
|
||||
stdout=f,
|
||||
stderr=f,
|
||||
)
|
||||
if os.name != "nt":
|
||||
command = " ".join(command)
|
||||
if generated_image:
|
||||
model_metrics.append(
|
||||
parse_sd_out(
|
||||
dumpfile_name,
|
||||
command,
|
||||
device,
|
||||
use_tune,
|
||||
model_name,
|
||||
import_opt,
|
||||
)
|
||||
)
|
||||
print(command)
|
||||
print("Successfully generated image")
|
||||
os.makedirs(
|
||||
"./test_images/golden/" + model_name, exist_ok=True
|
||||
)
|
||||
download_public_file(
|
||||
"gs://shark_tank/testdata/golden/" + model_name,
|
||||
"./test_images/golden/" + model_name,
|
||||
)
|
||||
test_file_path = os.path.join(
|
||||
os.getcwd(),
|
||||
"test_images",
|
||||
model_name,
|
||||
"generated_imgs",
|
||||
dt.now().strftime("%Y%m%d"),
|
||||
"*.png",
|
||||
)
|
||||
test_file = glob(test_file_path)[0]
|
||||
|
||||
golden_path = (
|
||||
"./test_images/golden/" + model_name + "/*.png"
|
||||
)
|
||||
golden_file = glob(golden_path)[0]
|
||||
compare_images(test_file, golden_file)
|
||||
else:
|
||||
print(command)
|
||||
print("failed to generate image for this configuration")
|
||||
if "2_1_base" in model_name:
|
||||
print("failed a known successful model.")
|
||||
exit(1)
|
||||
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
|
||||
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
|
||||
f.write(header)
|
||||
for metric in model_metrics:
|
||||
output = [
|
||||
metric["model_name"],
|
||||
metric["device"],
|
||||
metric["use_tune"],
|
||||
metric["import_mlir"],
|
||||
metric["Clip Inference time"],
|
||||
metric["Average step"],
|
||||
metric["VAE Inference time"],
|
||||
metric["Total image generation"],
|
||||
metric["command"],
|
||||
]
|
||||
command += extra_flags
|
||||
generated_image = not subprocess.call(
|
||||
command, stdout=subprocess.DEVNULL
|
||||
)
|
||||
if generated_image:
|
||||
print(" ".join(command))
|
||||
print("Successfully generated image")
|
||||
os.makedirs(
|
||||
"./test_images/golden/" + model_name, exist_ok=True
|
||||
)
|
||||
download_public_file(
|
||||
"gs://shark_tank/testdata/golden/" + model_name,
|
||||
"./test_images/golden/" + model_name,
|
||||
)
|
||||
test_file_path = os.path.join(
|
||||
os.getcwd(), "test_images", model_name, "generated_imgs"
|
||||
)
|
||||
test_file = glob(test_file_path + "/*.png")[0]
|
||||
golden_path = "./test_images/golden/" + model_name + "/*.png"
|
||||
golden_file = glob(golden_path)[0]
|
||||
compare_images(test_file, golden_file)
|
||||
else:
|
||||
print(" ".join(command))
|
||||
print("failed to generate image for this configuration")
|
||||
f.write(";".join(output) + "\n")
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
conftest.py (10 lines changed)
@@ -60,3 +60,13 @@ def pytest_addoption(parser):
|
||||
default="gs://shark_tank/latest",
|
||||
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
|
||||
)
|
||||
parser.addoption(
|
||||
"--benchmark_dispatches",
|
||||
default=None,
|
||||
help="Benchmark individual dispatch kernels produced by IREE compiler. Use 'All' for all, or specific dispatches e.g. '0 1 2 10'",
|
||||
)
|
||||
parser.addoption(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="./temp_dispatch_benchmarks",
|
||||
help="Directory in which dispatch benchmarks are saved.",
|
||||
)
|
||||
|
||||
@@ -162,13 +162,13 @@ def save_tf_model(tf_model_list):
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
input,
|
||||
inputs=input,
|
||||
frontend="tf",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
|
||||
process_skipfiles.py (new file, 44 lines)
@@ -0,0 +1,44 @@
|
||||
# This script will toggle the comment/uncommenting aspect for dealing
|
||||
# with __file__ AttributeError arising in case of a few modules in
|
||||
# `torch/_dynamo/skipfiles.py` (within shark.venv)
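# Concretely (a sketch of the behaviour implemented below): while scanning the
# SKIP_DIRS list in skipfiles.py, any line containing one of the entries in
# modules_to_comment ("abc,", "os,", "posixpath,", "_collections_abc,") gets a
# leading "#" prepended if it is not commented yet, or the leading "#" stripped
# if it already is, so repeated runs toggle the state back and forth.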
|
||||
|
||||
from distutils.sysconfig import get_python_lib
|
||||
import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
# Diffusers 0.13.1 fails with transformers __init__.py errors in BLIP, so remove it for now until we fork it
|
||||
pix2pix_file = Path(
|
||||
get_python_lib()
|
||||
+ "/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py"
|
||||
)
|
||||
if pix2pix_file.exists():
|
||||
print("Removing..%s", pix2pix_file)
|
||||
pix2pix_file.unlink()
|
||||
|
||||
|
||||
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
|
||||
|
||||
modules_to_comment = ["abc,", "os,", "posixpath,", "_collections_abc,"]
|
||||
startMonitoring = 0
|
||||
for line in fileinput.input(path_to_skipfiles, inplace=True):
|
||||
if "SKIP_DIRS = " in line:
|
||||
startMonitoring = 1
|
||||
print(line, end="")
|
||||
elif startMonitoring in [1, 2]:
|
||||
if "]" in line:
|
||||
startMonitoring += 1
|
||||
print(line, end="")
|
||||
else:
|
||||
flag = True
|
||||
for module in modules_to_comment:
|
||||
if module in line:
|
||||
if not line.startswith("#"):
|
||||
print(f"#{line}", end="")
|
||||
else:
|
||||
print(f"{line[1:]}", end="")
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
print(line, end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
@@ -1,7 +1,7 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
|
||||
numpy==1.22.4
|
||||
numpy>1.22.4
|
||||
torchvision
|
||||
pytorch-triton
|
||||
tabulate
|
||||
@@ -15,8 +15,8 @@ iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow==2.10.1
|
||||
keras==2.10
|
||||
tf-nightly
|
||||
keras>=2.10
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
|
||||
@@ -16,7 +16,7 @@ parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
transformers
|
||||
diffusers
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@4453536887526756afae14261d3335b5b12b86d8
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
|
||||
@@ -1,19 +1,54 @@
|
||||
<#
|
||||
.SYNOPSIS
|
||||
A script to update and install the SHARK runtime and its dependencies.
|
||||
|
||||
.DESCRIPTION
|
||||
This script updates and installs the SHARK runtime and its dependencies.
|
||||
It checks the Python version installed and installs any required build
|
||||
dependencies into a Python virtual environment.
|
||||
If that environment does not exist, it creates it.
|
||||
|
||||
.PARAMETER update-src
|
||||
git pulls latest version
|
||||
|
||||
.PARAMETER force
|
||||
removes and recreates venv to force update of all dependencies
|
||||
|
||||
.EXAMPLE
|
||||
.\setup_venv.ps1 --force
|
||||
|
||||
.EXAMPLE
|
||||
.\setup_venv.ps1 --update-src
|
||||
|
||||
.INPUTS
|
||||
None
|
||||
|
||||
.OUTPUTS
|
||||
None
|
||||
|
||||
#>
|
||||
|
||||
param([string]$arguments)
|
||||
|
||||
if ($arguments -eq "--update-src"){
|
||||
git pull
|
||||
}
|
||||
|
||||
#Write-Host "Installing python"
|
||||
|
||||
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow
|
||||
|
||||
#Write-Host "python installation completed successfully"
|
||||
|
||||
#Write-Host "Reload environment variables"
|
||||
#$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||
#Write-Host "Reloaded environment variables"
|
||||
|
||||
if ($arguments -eq "--force"){
|
||||
if (Test-Path env:VIRTUAL_ENV) {
|
||||
Write-Host "deactivating..."
|
||||
Deactivate
|
||||
}
|
||||
|
||||
if (Test-Path .\shark.venv\) {
|
||||
Write-Host "removing and recreating venv..."
|
||||
Remove-Item .\shark.venv -Force -Recurse
|
||||
if (Test-Path .\shark.venv\) {
|
||||
Write-Host 'could not remove .\shark.venv - please try running ".\setup_venv.ps1 --force" again!'
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# redirect stderr into stdout
|
||||
$p = &{python -V} 2>&1
|
||||
@@ -25,19 +60,36 @@ $version = if($p -is [System.Management.Automation.ErrorRecord])
|
||||
}
|
||||
else
|
||||
{
|
||||
# otherwise return as is
|
||||
$p
|
||||
# otherwise return complete Python list
|
||||
$ErrorActionPreference = 'SilentlyContinue'
|
||||
$PyVer = py --list
|
||||
}
|
||||
|
||||
Write-Host "Python version found is"
|
||||
Write-Host $p
|
||||
# deactivate any activated venvs
|
||||
if ($PyVer -like "*venv*")
|
||||
{
|
||||
deactivate # make sure we don't update the wrong venv
|
||||
$PyVer = py --list # update list
|
||||
}
|
||||
|
||||
Write-Host "Python versions found are"
|
||||
Write-Host ($PyVer | Out-String) # formatted output with line breaks
|
||||
if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is unavailable
|
||||
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
|
||||
{
|
||||
Write-Host "Please install Python 3.11 and try again"
|
||||
break
|
||||
}
|
||||
|
||||
Write-Host "Installing Build Dependencies"
|
||||
python -m venv .\shark.venv\
|
||||
# make sure we really use 3.11 from list, even if it's not the default.
|
||||
if (!($PyVer.length -ne 0)) {py -3.11 -m venv .\shark.venv\}
|
||||
else {python -m venv .\shark.venv\}
|
||||
.\shark.venv\Scripts\activate
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
@@ -42,7 +42,7 @@ Green=`tput setaf 2`
|
||||
Yellow=`tput setaf 3`
|
||||
|
||||
# Assume no binary torch-mlir.
|
||||
# Currently available for macOS m1&intel (3.10) and Linux(3.7,3.8,3.9,3.10)
|
||||
# Currently available for macOS m1&intel (3.11) and Linux(3.8,3.10,3.11)
|
||||
torch_mlir_bin=false
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "${Yellow}Apple macOS detected"
|
||||
@@ -60,12 +60,12 @@ if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
fi
|
||||
echo "${Yellow}Run the following commands to setup your SSL certs for your Python version if you see SSL errors with tests"
|
||||
echo "${Yellow}/Applications/Python\ 3.XX/Install\ Certificates.command"
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.10" ]; then
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.11" ]; then
|
||||
torch_mlir_bin=true
|
||||
fi
|
||||
elif [[ $(uname -s) = 'Linux' ]]; then
|
||||
echo "${Yellow}Linux detected"
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.7" ] || [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.9" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] ; then
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] || [ "$PYTHON_VERSION_X_Y" == "3.11" ] ; then
|
||||
torch_mlir_bin=true
|
||||
fi
|
||||
else
|
||||
@@ -89,7 +89,7 @@ if [ "$torch_mlir_bin" = true ]; then
|
||||
fi
|
||||
else
|
||||
echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
|
||||
echo "${Yello}Python 3.10 supported on macOS and 3.7,3.8,3.9 and 3.10 on Linux"
|
||||
echo "${Yello}Python 3.11 supported on macOS and 3.8,3.10 and 3.11 on Linux"
|
||||
echo "${Red}Please build torch-mlir from source in your environment"
|
||||
exit 1
|
||||
fi
|
||||
@@ -98,7 +98,7 @@ if [[ -z "${USE_IREE}" ]]; then
|
||||
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
|
||||
else
|
||||
touch ./.use-iree
|
||||
RUNTIME="https://iree-org.github.io/iree/pip-release-links.html"
|
||||
RUNTIME="https://openxla.github.io/iree/pip-release-links.html"
|
||||
fi
|
||||
if [[ -z "${NO_BACKEND}" ]]; then
|
||||
echo "Installing ${RUNTIME}..."
|
||||
@@ -112,7 +112,7 @@ if [[ ! -z "${IMPORTER}" ]]; then
|
||||
if [[ $(uname -s) = 'Linux' ]]; then
|
||||
echo "${Yellow}Linux detected.. installing Linux importer tools"
|
||||
#Always get the importer tools from upstream IREE
|
||||
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://iree-org.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://openxla.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
elif [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "${Yellow}macOS detected.. installing macOS importer tools"
|
||||
#Conda seems to have some problems installing these packages and hope they get resolved upstream.
|
||||
@@ -129,7 +129,7 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu117."
|
||||
else
|
||||
|
||||
shark/examples/shark_inference/sharded_bloom.py (new file, 698 lines)
@@ -0,0 +1,698 @@
|
||||
####################################################################################
|
||||
# Please make sure you have transformers 4.21.2 installed before running this demo
|
||||
#
|
||||
# -p --model_path: the directory in which you want to store the bloom files.
|
||||
# -dl --device_list: the list of device indices you want to use. If you only want to use the first device, or you are running on CPU, leave this blank.
|
||||
# Otherwise, please give this argument in this format: "[0, 1, 2]"
|
||||
# -de --device: the device you want to run bloom on. E.G. cpu, cuda
|
||||
# -c, --recompile: set to true if you want to recompile to vmfb.
|
||||
# -d, --download: set to true if you want to redownload the mlir files
|
||||
# -t --token_count: the number of tokens you want to generate
|
||||
# -pr --prompt: the prompt you want to feed to the model
|
||||
# -m --model_name: the name of the model, e.g. bloom-560m
|
||||
#####################################################################################
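# Example invocation (a hedged sketch based only on the flags documented above;
# the paths, device list, token count, and prompt below are illustrative
# assumptions, not values taken from this repository):
#
#   python shark/examples/shark_inference/sharded_bloom.py \
#       -p ./bloom_artifacts \
#       -de cpu \
#       -dl "[0]" \
#       -t 20 \
#       -m bloom-560m \
#       -pr "The SHARK runtime is"
#
# Add -c/--recompile to force recompilation to vmfb, or -d/--download to
# re-download the mlir files, as described in the header comment.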
|
||||
|
||||
import os
|
||||
import io
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
import torch_mlir
|
||||
from torch_mlir import TensorPlaceholder
|
||||
import re
|
||||
from transformers.models.bloom.configuration_bloom import BloomConfig
|
||||
import json
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
import urllib.request
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from transformers import (
|
||||
BloomTokenizerFast,
|
||||
BloomForSequenceClassification,
|
||||
BloomForCausalLM,
|
||||
)
|
||||
from transformers.models.bloom.modeling_bloom import (
|
||||
BloomBlock,
|
||||
build_alibi_tensor,
|
||||
)
|
||||
|
||||
IS_CUDA = False
|
||||
|
||||
|
||||
class ShardedBloom:
|
||||
def __init__(self, src_folder):
|
||||
f = open(f"{src_folder}/config.json")
|
||||
config = json.load(f)
|
||||
f.close()
|
||||
|
||||
self.layers_initialized = False
|
||||
|
||||
self.src_folder = src_folder
|
||||
try:
|
||||
self.n_embed = config["n_embed"]
|
||||
except KeyError:
|
||||
self.n_embed = config["hidden_size"]
|
||||
self.vocab_size = config["vocab_size"]
|
||||
self.n_layer = config["n_layer"]
|
||||
try:
|
||||
self.n_head = config["num_attention_heads"]
|
||||
except KeyError:
|
||||
self.n_head = config["n_head"]
|
||||
|
||||
def _init_layer(self, layer_name, device, replace, device_idx):
|
||||
if replace or not os.path.exists(
|
||||
f"{self.src_folder}/{layer_name}.vmfb"
|
||||
):
|
||||
f_ = open(f"{self.src_folder}/{layer_name}.mlir", encoding="utf-8")
|
||||
module = f_.read()
|
||||
f_.close()
|
||||
module = bytes(module, "utf-8")
|
||||
shark_module = SharkInference(
|
||||
module,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
shark_module.save_module(
|
||||
module_name=f"{self.src_folder}/{layer_name}",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
"",
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
return shark_module
|
||||
|
||||
def init_layers(self, device, replace=False, device_idx=[0]):
|
||||
if device_idx is not None:
|
||||
n_devices = len(device_idx)
|
||||
|
||||
self.word_embeddings_module = self._init_layer(
|
||||
"word_embeddings",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[0 % n_devices],
|
||||
)
|
||||
self.word_embeddings_layernorm_module = self._init_layer(
|
||||
"word_embeddings_layernorm",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[1 % n_devices],
|
||||
)
|
||||
self.ln_f_module = self._init_layer(
|
||||
"ln_f",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[2 % n_devices],
|
||||
)
|
||||
self.lm_head_module = self._init_layer(
|
||||
"lm_head",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[3 % n_devices],
|
||||
)
|
||||
self.block_modules = [
|
||||
self._init_layer(
|
||||
f"bloom_block_{i}",
|
||||
device,
|
||||
replace,
|
||||
device_idx
|
||||
if device_idx is None
|
||||
else device_idx[(i + 4) % n_devices],
|
||||
)
|
||||
for i in range(self.n_layer)
|
||||
]
|
||||
|
||||
self.layers_initialized = True
|
||||
|
||||
def load_layers(self):
|
||||
assert self.layers_initialized
|
||||
|
||||
self.word_embeddings_module.load_module(
|
||||
f"{self.src_folder}/word_embeddings.vmfb"
|
||||
)
|
||||
self.word_embeddings_layernorm_module.load_module(
|
||||
f"{self.src_folder}/word_embeddings_layernorm.vmfb"
|
||||
)
|
||||
for block_module, i in zip(self.block_modules, range(self.n_layer)):
|
||||
block_module.load_module(f"{self.src_folder}/bloom_block_{i}.vmfb")
|
||||
self.ln_f_module.load_module(f"{self.src_folder}/ln_f.vmfb")
|
||||
self.lm_head_module.load_module(f"{self.src_folder}/lm_head.vmfb")
|
||||
|
||||
def forward_pass(self, input_ids, device):
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.word_embeddings_module.device_idx)
|
||||
|
||||
input_embeds = self.word_embeddings_module(
|
||||
inputs=(input_ids,), function_name="forward"
|
||||
)
|
||||
|
||||
input_embeds = torch.tensor(input_embeds).float()
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.word_embeddings_layernorm_module.device_idx)
|
||||
hidden_states = self.word_embeddings_layernorm_module(
|
||||
inputs=(input_embeds,), function_name="forward"
|
||||
)
|
||||
|
||||
hidden_states = torch.tensor(hidden_states).float()
|
||||
|
||||
attention_mask = torch.ones(
|
||||
[hidden_states.shape[0], len(input_ids[0])]
|
||||
)
|
||||
alibi = build_alibi_tensor(
|
||||
attention_mask,
|
||||
self.n_head,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
)
|
||||
|
||||
causal_mask = _prepare_attn_mask(
|
||||
attention_mask, input_ids.size(), input_embeds, 0
|
||||
)
|
||||
causal_mask = torch.tensor(causal_mask).float()
|
||||
|
||||
presents = ()
|
||||
all_hidden_states = tuple(hidden_states)
|
||||
|
||||
for block_module, i in zip(self.block_modules, range(self.n_layer)):
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(block_module.device_idx)
|
||||
|
||||
output = block_module(
|
||||
inputs=(
|
||||
hidden_states.detach().numpy(),
|
||||
alibi.detach().numpy(),
|
||||
causal_mask.detach().numpy(),
|
||||
),
|
||||
function_name="forward",
|
||||
)
|
||||
hidden_states = torch.tensor(output[0]).float()
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
presents = presents + (
|
||||
tuple(
|
||||
(
|
||||
output[1],
|
||||
output[2],
|
||||
)
|
||||
),
|
||||
)
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.ln_f_module.device_idx)
|
||||
|
||||
hidden_states = self.ln_f_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.lm_head_module.device_idx)
|
||||
|
||||
logits = self.lm_head_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
logits = torch.tensor(logits).float()
|
||||
|
||||
return torch.argmax(logits[:, -1, :], dim=-1)
|
||||
|
||||
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
past_key_values_length: int = 0,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
|
||||
mask.masked_fill_(intermediate_mask, 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
target_length, past_key_values_length, dtype=dtype
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
expanded_mask = mask[None, None, :, :].expand(
|
||||
batch_size, 1, target_length, target_length + past_key_values_length
|
||||
)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
batch_size, source_length = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else source_length
|
||||
|
||||
expanded_mask = (
|
||||
mask[:, None, None, :]
|
||||
.expand(batch_size, 1, tgt_len, source_length)
|
||||
.to(dtype)
|
||||
)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _prepare_attn_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
past_key_values_length=past_key_values_length,
|
||||
).to(attention_mask.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
|
||||
def download_model(destination_folder, model_name):
|
||||
download_public_file(
|
||||
f"https://{model_name}/config.json", destination_folder
|
||||
)
|
||||
f = open(f"{destination_folder}/config.json")
|
||||
config = json.load(f)
|
||||
f.close()
|
||||
n_blocks = config["n_layer"]
|
||||
download_public_file(
|
||||
f"https://{model_name}/lm_head.mlir", destination_folder
|
||||
)
|
||||
download_public_file(f"https://{model_name}/ln_f.mlir", destination_folder)
|
||||
download_public_file(
|
||||
f"https://{model_name}/word_embeddings.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
f"https://{model_name}/word_embeddings_layernorm.mlir",
|
||||
destination_folder,
|
||||
)
|
||||
download_public_file(
|
||||
f"https://{model_name}/tokenizer.json", destination_folder
|
||||
)
|
||||
|
||||
for i in range(n_blocks):
|
||||
download_public_file(
|
||||
f"https://{model_name}/bloom_block_{i}.mlir", destination_folder
|
||||
)
|
||||
|
||||
|
||||
def compile_embeddings(embeddings_layer, input_ids, path):
|
||||
input_ids_placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
input_ids, dynamic_axes=[1]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
embeddings_layer,
|
||||
(input_ids_placeholder),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def compile_word_embeddings_layernorm(
|
||||
embeddings_layer_layernorm, embeds, path
|
||||
):
|
||||
embeds_placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
embeds, dynamic_axes=[1]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
embeddings_layer_layernorm,
|
||||
(embeds_placeholder),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()


def compile_to_mlir(
    bblock,
    hidden_states,
    layer_past=None,
    attention_mask=None,
    head_mask=None,
    use_cache=None,
    output_attentions=False,
    alibi=None,
    block_index=0,
    path=".",
):
    fx_g = make_fx(
        bblock,
        decomposition_table=get_decompositions(
            [
                torch.ops.aten.split.Tensor,
                torch.ops.aten.split_with_sizes,
            ]
        ),
        tracing_mode="real",
        _allow_non_fake_inputs=False,
    )(hidden_states, alibi, attention_mask)

    fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
    fx_g.recompile()

    strip_overloads(fx_g)

    hidden_states_placeholder = TensorPlaceholder.like(
        hidden_states, dynamic_axes=[1]
    )
    attention_mask_placeholder = TensorPlaceholder.like(
        attention_mask, dynamic_axes=[2, 3]
    )
    alibi_placeholder = TensorPlaceholder.like(alibi, dynamic_axes=[2])

    ts_g = torch.jit.script(fx_g)

    module = torch_mlir.compile(
        ts_g,
        (
            hidden_states_placeholder,
            alibi_placeholder,
            attention_mask_placeholder,
        ),
        torch_mlir.OutputType.LINALG_ON_TENSORS,
        use_tracing=False,
        verbose=False,
    )

    module_placeholder = module
    module_context = module_placeholder.context

    # Filter out "private" declarations, the "attributes" line, and the
    # second-to-last line of the printed module before the textual rewrite.
    def check_valid_line(line, line_n, mlir_file_len):
        if "private" in line:
            return False
        if "attributes" in line:
            return False
        if mlir_file_len - line_n == 2:
            return False

        return True

    mlir_file_len = len(str(module).split("\n"))

    # Rewrite the fixed trace-time sequence length (17) into a dynamic
    # dimension (? / %dim) in the printed MLIR text.
    def remove_constant_dim(line):
        if "17x" in line:
            line = re.sub("17x", "?x", line)
            line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim)", line)
        if "tensor.empty" in line and "?x?" in line:
            line = re.sub(
                r"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
            )
        if "arith.cmpi eq" in line:
            line = re.sub("c17", "dim", line)
        if " 17," in line:
            line = re.sub(" 17,", " %dim,", line)
        return line

    module = "\n".join(
        [
            remove_constant_dim(line)
            for line, line_n in zip(
                str(module).split("\n"), range(mlir_file_len)
            )
            if check_valid_line(line, line_n, mlir_file_len)
        ]
    )

    # Re-parse the edited text back into a module in the original MLIR context.
    module = module_placeholder.parse(module, context=module_context)
    bytecode_stream = io.BytesIO()
    module.operation.write_bytecode(bytecode_stream)
    bytecode = bytecode_stream.getvalue()

    f_ = open(path, "w+")
    f_.write(str(module))
    f_.close()
    return


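A worked example of the dynamic-dimension rewrite performed inside compile_to_mlir (the MLIR line is hypothetical; the fixed trace-time sequence length of 17 becomes a dynamic dimension):

import re

line = "%1 = tensor.empty() : tensor<1x17x1024xf32>"
line = re.sub("17x", "?x", line)
line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim)", line)
print(line)  # %1 = tensor.empty(%dim) : tensor<1x?x1024xf32>
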
def compile_ln_f(ln_f, hidden_layers, path):
    hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like(
        hidden_layers, dynamic_axes=[1]
    )
    module = torch_mlir.compile(
        ln_f,
        (hidden_layers_placeholder),
        torch_mlir.OutputType.LINALG_ON_TENSORS,
        use_tracing=False,
        verbose=False,
    )

    bytecode_stream = io.BytesIO()
    module.operation.write_bytecode(bytecode_stream)
    bytecode = bytecode_stream.getvalue()

    f_ = open(path, "w+")
    f_.write(str(module))
    f_.close()
    return


def compile_lm_head(lm_head, hidden_layers, path):
    hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like(
        hidden_layers, dynamic_axes=[1]
    )
    module = torch_mlir.compile(
        lm_head,
        (hidden_layers_placeholder),
        torch_mlir.OutputType.LINALG_ON_TENSORS,
        use_tracing=False,
        verbose=False,
    )

    bytecode_stream = io.BytesIO()
    module.operation.write_bytecode(bytecode_stream)
    bytecode = bytecode_stream.getvalue()

    f_ = open(path, "w+")
    f_.write(str(module))
    f_.close()
    return


def create_mlirs(destination_folder, model_name):
    model_config = "bigscience/" + model_name
    # Fixed sample sequence length of 17; compile_to_mlir later rewrites this
    # dimension to be dynamic.
    sample_input_ids = torch.ones([1, 17], dtype=torch.int64)

    urllib.request.urlretrieve(
        f"https://huggingface.co/bigscience/{model_name}/resolve/main/config.json",
        filename=f"{destination_folder}/config.json",
    )
    urllib.request.urlretrieve(
        "https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json",
        filename=f"{destination_folder}/tokenizer.json",
    )

    class HuggingFaceLanguage(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.model = BloomForCausalLM.from_pretrained(model_config)

        def forward(self, tokens):
            return self.model.forward(tokens)[0]

    class HuggingFaceBlock(torch.nn.Module):
        def __init__(self, block):
            super().__init__()
            self.model = block

        def forward(self, tokens, alibi, attention_mask):
            output = self.model(
                hidden_states=tokens,
                alibi=alibi,
                attention_mask=attention_mask,
                use_cache=True,
                output_attentions=False,
            )
            return (output[0], output[1][0], output[1][1])

    model = HuggingFaceLanguage()

    compile_embeddings(
        model.model.transformer.word_embeddings,
        sample_input_ids,
        f"{destination_folder}/word_embeddings.mlir",
    )

    inputs_embeds = model.model.transformer.word_embeddings(sample_input_ids)

    compile_word_embeddings_layernorm(
        model.model.transformer.word_embeddings_layernorm,
        inputs_embeds,
        f"{destination_folder}/word_embeddings_layernorm.mlir",
    )

    hidden_states = model.model.transformer.word_embeddings_layernorm(
        inputs_embeds
    )

    input_shape = sample_input_ids.size()

    current_sequence_length = hidden_states.shape[1]
    past_key_values_length = 0
    past_key_values = tuple([None] * len(model.model.transformer.h))

    attention_mask = torch.ones(
        (hidden_states.shape[0], current_sequence_length), device="cpu"
    )

    alibi = build_alibi_tensor(
        attention_mask,
        model.model.transformer.n_head,
        hidden_states.dtype,
        "cpu",
    )

    causal_mask = _prepare_attn_mask(
        attention_mask, input_shape, inputs_embeds, past_key_values_length
    )

    head_mask = model.model.transformer.get_head_mask(
        None, model.model.transformer.config.n_layer
    )
    output_attentions = model.model.transformer.config.output_attentions

    all_hidden_states = ()

    for i, (block, layer_past) in enumerate(
        zip(model.model.transformer.h, past_key_values)
    ):
        all_hidden_states = all_hidden_states + (hidden_states,)

        proxy_model = HuggingFaceBlock(block)

        compile_to_mlir(
            proxy_model,
            hidden_states,
            layer_past=layer_past,
            attention_mask=causal_mask,
            head_mask=head_mask[i],
            use_cache=True,
            output_attentions=output_attentions,
            alibi=alibi,
            block_index=i,
            path=f"{destination_folder}/bloom_block_{i}.mlir",
        )

    compile_ln_f(
        model.model.transformer.ln_f,
        hidden_states,
        f"{destination_folder}/ln_f.mlir",
    )
    hidden_states = model.model.transformer.ln_f(hidden_states)
    compile_lm_head(
        model.model.lm_head,
        hidden_states,
        f"{destination_folder}/lm_head.mlir",
    )


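A small sanity check of the artifacts create_mlirs writes (a sketch only: the destination folder is illustrative, and the block count of 24 is an assumption based on the published bloom-560m configuration rather than anything in this listing):

import os

dest = "./bloom-560m-mlir"  # hypothetical destination_folder
expected = [
    "config.json",
    "tokenizer.json",
    "word_embeddings.mlir",
    "word_embeddings_layernorm.mlir",
    "ln_f.mlir",
    "lm_head.mlir",
] + [f"bloom_block_{i}.mlir" for i in range(24)]  # one file per transformer block
missing = [name for name in expected if not os.path.exists(os.path.join(dest, name))]
print("missing artifacts:", missing)
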
if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog="Bloom-560m")
    parser.add_argument("-p", "--model_path")
    parser.add_argument("-dl", "--device_list", default=None)
    parser.add_argument("-de", "--device", default="cpu")
    # Note: with type=bool, any non-empty string (even "False") parses as True;
    # omit the flag to keep the default of False.
    parser.add_argument("-c", "--recompile", default=False, type=bool)
    parser.add_argument("-d", "--download", default=False, type=bool)
    parser.add_argument("-t", "--token_count", default=10, type=int)
    parser.add_argument("-m", "--model_name", default="bloom-560m")
    parser.add_argument(
        "-pr",
        "--prompt",
        default="The SQL command to extract all the users whose name starts with A is: ",
    )
    args = parser.parse_args()

    if args.device_list is not None:
        args.device_list = json.loads(args.device_list)

    if args.device == "cuda" and args.device_list is not None:
        IS_CUDA = True
        from cuda.cudart import cudaSetDevice

    if args.download:
        # download_model(args.model_path, args.model_name)
        create_mlirs(args.model_path, args.model_name)

    from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    input_ids = tokenizer.encode(args.prompt, return_tensors="pt")

    shardedbloom = ShardedBloom(args.model_path)
    shardedbloom.init_layers(
        device=args.device, replace=args.recompile, device_idx=args.device_list
    )
    shardedbloom.load_layers()

    for _ in range(args.token_count):
        next_token = shardedbloom.forward_pass(
            torch.tensor(input_ids), device=args.device
        )
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)

    print(tokenizer.decode(input_ids.squeeze()))

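As a usage note, the flags above combine into an invocation along the lines of the following (the script filename and the model_path directory are assumptions, not taken from this listing):

python bloom.py --model_path ./bloom-560m-mlir --model_name bloom-560m --download True --device cpu --token_count 20

Because --download and --recompile are declared with type=bool, any non-empty value (including the string "False") enables them; leave the flag off to keep the default of False.
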
@@ -139,9 +139,14 @@ def run_benchmark_module(benchmark_cl):
         benchmark_path
     ), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
     bench_result = run_cmd(" ".join(benchmark_cl))
-    print(bench_result)
-    regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
-    match = regex_split.search(bench_result)
-    time = float(match.group(1))
-    unit = match.group(3)
+    try:
+        regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
+        match = regex_split.search(bench_result)
+        time = float(match.group(1))
+        unit = match.group(3)
+    except AttributeError:
+        regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
+        match = regex_split.search(bench_result)
+        time = float(match.group(1))
+        unit = match.group(2)
     return 1.0 / (time * 0.001)

@@ -99,6 +99,7 @@ else:
    print(
        f"shark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache= flag"
    )
    os.makedirs(WORKDIR, exist_ok=True)


# Checks whether the directory and files exists.

@@ -18,11 +18,12 @@ alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,False,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,True,True,True,"disabled until generateable",""
bert-large-uncased,mhlo,tf,1e-2,1e-3,default,None,True,True,True,"disabled until generatedable",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390",""
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
microsoft/resnet-50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344",""
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/343","macos"

@@ -31,3 +31,4 @@ xlm-roberta-base,False,False,-,-,-
facebook/convnext-tiny-224,False,False,-,-,-
efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv"
mnasnet1_0,False,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"

@@ -15,6 +15,7 @@ keras_models = ["resnet50", "efficientnet-v2-s"]
maskedlm_models = [
    "albert-base-v2",
    "bert-base-uncased",
    "bert-large-uncased",
    "camembert-base",
    "dbmdz/convbert-base-turkish-cased",
    "deberta-base",

@@ -137,6 +137,19 @@ class SharkModuleTester:
    def create_and_check_module(self, dynamic, device):
        shark_args.local_tank_cache = self.local_tank_cache
        shark_args.force_update_tank = self.update_tank
        shark_args.dispatch_benchmarks = self.benchmark_dispatches
        if self.benchmark_dispatches is not None:
            _m = self.config["model_name"].split("/")
            _m.extend([self.config["framework"], str(dynamic), device])
            _m = "_".join(_m)
            shark_args.dispatch_benchmarks_dir = os.path.join(
                self.dispatch_benchmarks_dir,
                _m,
            )
            if not os.path.exists(self.dispatch_benchmarks_dir):
                os.mkdir(self.dispatch_benchmarks_dir)
            if not os.path.exists(shark_args.dispatch_benchmarks_dir):
                os.mkdir(shark_args.dispatch_benchmarks_dir)
        if "nhcw-nhwc" in self.config["flags"] and not os.path.isfile(
            ".use-iree"
        ):

@@ -278,6 +291,12 @@ class SharkModuleTest(unittest.TestCase):
            "update_tank"
        )
        self.module_tester.tank_url = self.pytestconfig.getoption("tank_url")
        self.module_tester.benchmark_dispatches = self.pytestconfig.getoption(
            "benchmark_dispatches"
        )
        self.module_tester.dispatch_benchmarks_dir = (
            self.pytestconfig.getoption("dispatch_benchmarks_dir")
        )

        if config["xfail_cpu"] == "True" and device == "cpu":
            pytest.xfail(reason=config["xfail_reason"])

@@ -18,3 +18,4 @@ microsoft/mpnet-base,hf
facebook/convnext-tiny-224,img
google/vit-base-patch16-224,img
efficientnet-v2-s,keras
bert-large-uncased,hf

@@ -18,3 +18,4 @@ nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encod
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"