Compare commits


94 Commits

Author SHA1 Message Date
jinchen62
696df349cb Fix curl issue (#1369) 2023-04-28 09:31:14 -07:00
jinchen62
cb54cb1348 Add model manager tab for SD webui (#1368) 2023-04-28 02:43:40 -07:00
Daniel Garvey
9bdb86637d add tkinter launch for webui (#1364) 2023-04-27 19:17:55 -05:00
jinchen62
fb6f26517f Fix webui note (#1367) 2023-04-27 16:14:43 -07:00
Chi_Liu
aa8ada9da9 Add support for torch to stablehlo and tosa in shark_importer (#1360) 2023-04-27 08:09:45 -07:00
powderluv
1db906a373 Revert "Add model manager tab for webui (#1359)" (#1362)
This reverts commit 9d1d1617d8.
2023-04-26 22:25:26 -07:00
jinchen62
9d1d1617d8 Add model manager tab for webui (#1359) 2023-04-26 13:38:18 -07:00
jinchen62
7112789cb8 Add support of using civitai model download url (#1357) 2023-04-25 23:39:52 -07:00
jinchen62
d6b8be2849 Add drawing canvas for img2img stencil scribble (#1355) 2023-04-25 14:41:01 -07:00
powderluv
822171277c Revert "[SD] Add FastChat as part of SD WebUI (#1349)" (#1350)
This reverts commit a5ae9d9f02.
2023-04-24 15:22:25 -07:00
Abhishek Varma
a5ae9d9f02 [SD] Add FastChat as part of SD WebUI (#1349)
-- This commit includes FastChat as part of SD WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-24 11:12:58 -07:00
powderluv
09e3f63d5b Fix pascal (#1346)
* Add fp32 for upscaler VAE

* Plumb Pascal vulkan support
2023-04-23 20:28:25 -07:00
powderluv
d60a5a9396 Add fp32 for upscaler VAE (#1345) 2023-04-23 15:27:55 -07:00
m68k-fr
90df0ee365 [Web] Gallery set to a 768px reference for high-end desktop users (#1344) 2023-04-23 11:48:06 -07:00
nirvedhmeshram
133c1bcadd add device to scheduler model names (#1338) 2023-04-22 20:13:56 -05:00
powderluv
caadbe14e9 Revert VAE to use im2col (#1339) 2023-04-22 15:23:41 -07:00
Ean Garvey
5f5823ccd9 Fix inference object imports for SD apps. (#1334) 2023-04-21 13:40:48 -05:00
Vivek Khandelwal
d2f7e03b7e Add StableLM model (#1331) 2023-04-21 09:51:02 -07:00
Gaurav Shukla
0b01bbe479 [SD] Add txt2img/upscaler/inpaint/outpaint Rest API (#1325)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-21 09:06:06 -07:00
yzhang93
25c5fc44ae Modify tuner.py to take vulkan target triple flag (#1328) 2023-04-20 14:31:32 -07:00
Daniel Garvey
7330729c92 enable sd pytest (#1322) 2023-04-19 22:11:30 -05:00
Ean Garvey
ce16cd5431 Create local shark_tank if needed for tuning configs. (#1321)
Now that --clear_all successfully deletes local shark_tank cache, we need to make sure it exists before trying to use it.
2023-04-19 11:44:21 -05:00
Ean Garvey
598dc5f79d Don't dump image data on img2img api call. (#1320) 2023-04-19 21:24:46 +05:30
Abhishek Varma
1f8e332cbe [SD] Fix img2img API bug for custom_vae argument (#1319)
-- https://github.com/nod-ai/SHARK/pull/1314 misses to add `custom_vae`
   parameter to img2img_if's invocation within img2img_api.
-- This commit adds a fix to the same.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-19 10:39:52 -05:00
Abhishek Varma
17b9632659 [SD] Adapted SHARK's v1 img2img API for SdPaint + updated Stencil model ID (#1318) 2023-04-19 06:29:36 -07:00
jinchen62
bda92a54ab Fix custom vae path (#1317) 2023-04-18 20:50:43 -07:00
jinchen62
747ed383b1 Add custom vae dropdown in webui (#1314) 2023-04-18 17:24:02 -07:00
Ean Garvey
1afe07c296 Disable winograd on VAE with rdna2 and fix unet tuning. (#1313)
* Disable winograd on VAE with rdna2 and fix unet tuning.

* Fix batch size 1 downloads and clear_all on windows.
2023-04-18 15:55:10 -05:00
jinchen62
b70919b38d Fix memory leak with ondemand (#1312)
support ondemand for outpainting and multi batch_count
2023-04-18 13:03:16 -05:00
m68k-fr
4e513d647f Update list of scheduler available for inferences (#1298) 2023-04-17 22:37:00 -05:00
jinchen62
94cd2a0fed Fix outpainting config (#1310) 2023-04-17 10:48:52 -07:00
Kyle Herndon
606029c01c Fix LoRA device format bug and allow LoRA to resume from a previous training 2023-04-17 13:19:46 +05:30
powderluv
1aa85222e9 Add AMD W7900 target triple (#1304)
This maps to RDNA3
2023-04-16 00:14:21 -07:00
m68k-fr
1b3f468c04 [Web] Style Fixes for Gradio V3.25.0 (#1300) 2023-04-13 18:40:42 -05:00
m68k-fr
35de7e27fa [Web] remove txt2img ui dependencies from png import metadata (#1275) 2023-04-12 07:32:47 -10:00
yzhang93
467f900759 Add auto-tuner to SD apps (#1291) 2023-04-12 09:21:17 -07:00
Ean Garvey
0bd9d582c7 Add documentation for using SHARK with AI-Render (#1296) 2023-04-12 03:09:34 -10:00
jinchen62
428cfe8dae Fix low vram mode issues (#1295)
- add ondemand back to img2img
- workaround memory leak for batch count
2023-04-11 17:59:09 -07:00
Ean Garvey
f17915bedc Fix batch size appending to model name. (#1294)
* Update shark_downloader.py

* Update shark_downloader.py
2023-04-11 15:34:25 -05:00
Gaurav Shukla
1b49b5149a [SD] Add Img2Img rest API
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-11 23:06:58 +05:30
jinchen62
3002793301 Unload clip on demand and workaround memory leak (#1283) 2023-04-10 16:59:03 -07:00
Phaneesh Barwaria
d25ef5529f Add fix for vae fp32 Upscalar (#1284)
- fixes size mismatch error for upscalar vae
2023-04-07 14:36:40 -05:00
Ean Garvey
308856a947 Touch unet if base cfg needed for SD pipeline init (#1281) 2023-04-05 03:02:29 -05:00
m68k-fr
151b4e142f [SD] Fix encoder error for model_max_length not beeing 77 (#1278)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-04-04 22:39:29 -07:00
Ean Garvey
e5a69a7c36 pin diffusers to e47459c (#1279) 2023-04-04 18:29:21 -07:00
m68k-fr
450b6cafc4 [SD] Add weight emphasis to prompts encoder (#1276) 2023-04-04 09:47:04 -07:00
Daniel Garvey
237d26baa2 update model db to reflect changes (#1277)
* remove 1/1 tqdm progress bar

* update model_db to reflect changes
2023-04-04 11:46:55 -05:00
Daniel Garvey
67d6ee1104 remove 1/1 tqdm progress bar (#1274) 2023-04-03 22:30:09 -05:00
Ean Garvey
98b069488e Add tank_version.json (#1272) 2023-04-03 18:36:23 -07:00
jinchen62
e0f227643a Fix webui circular import issue (#1271) 2023-04-03 16:00:10 -07:00
jinchen62
a0af3bb0cb xload and unload models (#1242) 2023-04-03 14:42:18 -07:00
powderluv
2cd61a5b96 strip source map (#1270) 2023-04-03 14:41:32 -07:00
Gaurav Shukla
f49d41a807 [SD] Add Stable diffusion text2image rest API (#1265)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-03 12:02:24 -07:00
Ean Garvey
2191fc8952 Separate pytest benchmark modes and fix model updates for SHARK downloader / pytest. (#1264)
* Only xfail windows models in CI

* downloader: make model updates more robust.

* Separate baseline and native benchmarks in pytest.

* Fix native benchmarks

* Fix torchvision model utils.
2023-04-03 08:24:21 -07:00
PhaneeshB
aea7796e60 add gradio client to spec 2023-04-03 18:57:19 +05:30
Abhishek Varma
a376619f1e [SD] Improve vmfb caching algo and retry mechanism (#1248)
-- This commit gets rid of the all-or-nothing vmfb caching mechanism
   and improves the retry mechanism by providing lower-level granularity
   for compiling each model units.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-03-31 09:38:14 -07:00
powderluv
02d52bb626 Add Intel ARC A770 target triple (#1263)
This just enables the plumbing. It generates black images.
2023-03-29 14:49:05 -07:00
Abhishek Varma
3b63645f79 [SD] Fix custom model path for WebUI (#1260)
-- This commit fixes custom model path for WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-29 09:48:11 -07:00
Ean Garvey
d6f740b998 allow pytest to retry getting model artifacts + disable autotuning for pytorch benchmarks (#1257)
* Adds a few xfails to enable macOS builder

* Convert string batch sizes to ints where needed.

* allow pytest to retry getting model artifacts

* Reduce attempts and add assert msg.
2023-03-28 23:38:45 -05:00
Daniel Garvey
594c6b8ea2 fix ckpt dir (#1258) 2023-03-28 14:31:01 -07:00
Ean Garvey
96b1560da5 Make batch size configurable via pytest and fix sharktank generation. (#1227)
* Fix sharktank generation and add batch_size pytest option for torch.

* Disable torch dynamo until py3.11 supported

* Compile torchmodel without dynamo if torch.compile fails

* Use release versions of TF/Keras for importer.

* Pin torchvision and remove debug prints.

* Remove duplicates from torch model list.

* Update generate_sharktank.py

* xfail a few models that fail sharktank generation/ numerics
2023-03-28 14:33:39 -05:00
Abhishek Varma
0ef6a0e234 [SD] Fix Stencil scribble crash by updating image resize (#1255)
-- This commit updates Stencil resize feature to cap the size of
   images within [128,768] as supported by the SD pipeline.
-- This solves the issue of scribble crashing on larger image.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-28 10:13:11 -07:00
Gaurav Shukla
641d535f44 [SD] Fix device path issue for cpu (#1256)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-28 10:09:49 -07:00
Daniel Garvey
5bb7846227 single entry point exe for all cli apps (#1158)
usage:
add --app="img2img" (or "inpaint" "outpaint" "txt2img")
2023-03-28 11:15:21 -05:00
yzhang93
8f84258fb8 Fix check for use_tuned conditions (#1252) 2023-03-27 11:21:25 -07:00
Ean Garvey
7619e76bbd Disable and xfail some models that fail validation/compilation. (#1251)
* Rollback T5 models for torch as the inputs give some issues that aren't trivial to resolve
* xfail efficientnet-b0 on torch+cuda -- see CUDA requesting shared memory size larger than allowed size openxla/iree#12771
2023-03-27 12:42:53 -05:00
Daniel Garvey
9267eadbfa disable openjourney gen for nightly (#1249) 2023-03-27 11:55:34 -05:00
Phaneesh Barwaria
431132b8ee Fix img2img mode switch (#1247)
* add updated scheduler value in global config

* clear scheduler global variable with others
2023-03-27 07:01:22 -07:00
cstueckrath
fb35e13e7a fix Python version detection bug (#1246)
* fix Python version detection bug

* Update setup_venv.ps1
2023-03-27 07:00:40 -07:00
yzhang93
17a67897d1 Add SD v2.1 768x768 tuned model (#1244)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-24 10:39:15 -07:00
Gaurav Shukla
da449b73aa [SD] Disable lora training tab for now (#1241)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-24 09:16:24 -07:00
Kyle Herndon
0b0526699a Fix incorrect device argument initialization for LoRA training by extracting the device type and number and formatting it for pytorch (#1237)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
2023-03-24 01:10:50 -07:00
Boian Petkantchin
4fac46f7bb In models testing fix paths to be relative to the script dir not cwd (#1128)
authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-03-22 15:26:52 -05:00
Daniel Garvey
49925950f1 fix false positives (#1193) 2023-03-22 15:25:39 -05:00
Thomas
807947c0c8 Remove deprecated cli option iree-hal-cuda-disable-loop-nounroll-wa (#1235) 2023-03-22 12:05:15 -05:00
Abhishek Varma
593428bda4 [SD] Fix for transformers/__init__.py issue in PyInstaller (#1233)
-- This commit fixes the transformers/__init__.py issue in PyInstaller.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-22 08:43:53 -07:00
Abhishek Varma
cede9b4fec [SD] Fix custom_vae as a required parameter in inpaint (#1232) 2023-03-22 04:30:17 -07:00
Prashant Kumar
c2360303f0 Add the int8 quantized model. 2023-03-22 16:28:13 +05:30
jinchen62
420366c1b8 Move schedulers to global obj (#1225) 2023-03-21 22:40:43 -07:00
Ean Garvey
d31bae488c Set iree-input-type to tm_tensor for SD (#1228) 2023-03-21 19:07:31 -07:00
Kyle Herndon
c23fcf3748 Fix incorrect device argument initialization for LoRA training (#1231)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-21 19:07:18 -07:00
jinchen62
7dbbb1726a Fix SD obj not defined if fail to get models from pretrained (#1222) 2023-03-21 07:55:17 -07:00
Abhishek Varma
8b8cc7fd33 [SD] Update LoRA inference to handle various checkpoints (#1215) 2023-03-21 06:52:20 -07:00
Ean Garvey
e3c96a2b9d Move sentencepiece to importer requirements. (#1218) 2023-03-21 00:39:57 -05:00
Ean Garvey
5e3f50647d Set --vulkan_large_heap_block_size default to 2gb. (#1220) 2023-03-20 21:07:09 -07:00
gpetters94
7899e1803a Add fix for attention slicing fp16 (#1217) 2023-03-20 19:11:29 -07:00
mariecwhite
d105246b9c Fix t5 models 2023-03-21 10:39:59 +11:00
mariecwhite
90c958bca2 Add T5-base and T5-large Torch and TF Models (#1116) 2023-03-20 17:32:50 -05:00
mariecwhite
f99903e023 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:22:05 +11:00
mariecwhite
c6f44ef1b3 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:14:45 +11:00
mariecwhite
8dcd4d5aeb Make batch size configurable 2023-03-20 18:03:17 -04:00
Phoenix Meadowlark
d319f4684e Add peak memory reporting for IREE, TF and PyTorch (#1216) 2023-03-20 15:40:49 -05:00
Ean Garvey
54d7b6d83e Generate model artifacts in pytests if they don't exist in the cloud. (#1121)
* Add gen_shark_files fn to shark_downloader for OTF artifact generation

* add generate_sharktank as a tank/ python module.

* Fix some paths in tank generation.
2023-03-20 12:13:19 -05:00
m68k-fr
4a622532e5 [Web] Stop images (#1212) 2023-03-19 14:37:30 -07:00
81 changed files with 5523 additions and 2617 deletions


@@ -112,7 +112,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -120,9 +120,9 @@ jobs:
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
@@ -145,17 +145,19 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest -k vulkan -s
pytest -k vulkan -s --ci
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
python build_tools/stable_diffusion_testing.py --device=vulkan


@@ -114,12 +114,12 @@ source shark.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
```shell
python3.11 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
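
The switch from per-task scripts to a single `main.py` selected with `--app` can be pictured as a small dispatch table. The sketch below is only an illustration under assumptions: `run_txt2img`, `run_img2img`, `APPS` and `dispatch` are hypothetical names, not functions from the SHARK repository, and the real `main.py` may be organized quite differently. Only two of the four apps named in the commit message (`txt2img`, `img2img`, `inpaint`, `outpaint`) are stubbed out.

```python
import argparse


# Hypothetical stand-ins for the per-app entry points; the real scripts
# build and run Stable Diffusion pipelines instead of returning strings.
def run_txt2img(device: str) -> str:
    return f"txt2img on {device}"


def run_img2img(device: str) -> str:
    return f"img2img on {device}"


# Assumed layout: map each accepted --app value to the callable that runs it.
APPS = {
    "txt2img": run_txt2img,
    "img2img": run_img2img,
}


def dispatch(argv):
    """Parse --app/--device from argv and call the selected app."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--app", choices=sorted(APPS), default="txt2img")
    parser.add_argument("--device", default="vulkan")
    ns = parser.parse_args(argv)
    return APPS[ns.app](ns.device)


if __name__ == "__main__":
    print(dispatch(["--app=img2img", "--device=vulkan://1"]))
```

Keeping the mapping in one dictionary means a new app only needs one extra entry, and `argparse` rejects unknown `--app` values before any model is loaded.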


@@ -0,0 +1,207 @@
import torch
import shark
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
import torch_mlir
from apps.stable_diffusion.src.utils import (
    base_models,
    get_opt_flags,
    get_vmfb_path_name,
)
from apps.stable_diffusion.src.models.model_wrappers import replace_shape_str
import os
from io import BytesIO

tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stablelm-tuned-alpha-7b"
)


class StopOnTokens(StoppingCriteria):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        stop_ids = [50278, 50279, 50277, 1, 0]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
inputs = tokenizer(prompt, return_tensors="pt")
class SLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            "stabilityai/stablelm-tuned-alpha-7b"
        )

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids, attention_mask)[0]


slm_model = SLM()
res_pytorch = slm_model(inputs["input_ids"], inputs["attention_mask"])
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
fx_g = make_fx(
    slm_model,
    decomposition_table=get_decompositions(
        [
            torch.ops.aten.embedding_dense_backward,
            torch.ops.aten.native_layer_norm_backward,
            torch.ops.aten.slice_backward,
            torch.ops.aten.select_backward,
            torch.ops.aten.norm.ScalarOpt_dim,
            torch.ops.aten.native_group_norm,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.split.Tensor,
            torch.ops.aten.split_with_sizes,
        ]
    ),
)(inputs["input_ids"], inputs["attention_mask"])


def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
    removed_indexes = []
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, (list, tuple)):
                node_arg = list(node_arg)
                node_args_len = len(node_arg)
                for i in range(node_args_len):
                    curr_index = node_args_len - (i + 1)
                    if node_arg[curr_index] is None:
                        removed_indexes.append(curr_index)
                        node_arg.pop(curr_index)
                node.args = (tuple(node_arg),)
                break

    if len(removed_indexes) > 0:
        fx_g.graph.lint()
        fx_g.graph.eliminate_dead_code()
        fx_g.recompile()
    removed_indexes.sort()
    return removed_indexes


def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
    """
    Replace tuple with tuple element in functions that return one-element tuples.
    Returns true if an unwrapping took place, and false otherwise.
    """
    unwrapped_tuple = False
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, tuple):
                if len(node_arg) == 1:
                    node.args = (node_arg[0],)
                    unwrapped_tuple = True
                    break

    if unwrapped_tuple:
        fx_g.graph.lint()
        fx_g.recompile()
    return unwrapped_tuple


def transform_fx(fx_g):
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            if node.target in [
                torch.ops.aten.empty,
            ]:
                # aten.empty should be filled with zeros.
                if node.target in [torch.ops.aten.empty]:
                    with fx_g.graph.inserting_after(node):
                        new_node = fx_g.graph.call_function(
                            torch.ops.aten.zero_,
                            args=(node,),
                        )
                        node.append(new_node)
                        node.replace_all_uses_with(new_node)
                        new_node.args = (node,)

    fx_g.graph.lint()


transform_fx(fx_g)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)

fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()


def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.
    Args:
        gm(fx.GraphModule): The input Fx graph module to be modified
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()


strip_overloads(fx_g)

ts_g = torch.jit.script(fx_g)
module = torch_mlir.compile(
    ts_g,
    [inputs["input_ids"], inputs["attention_mask"]],
    torch_mlir.OutputType.LINALG_ON_TENSORS,
    use_tracing=False,
    verbose=False,
)

bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
    mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.compile()
result_shark = shark_module(
    "forward", [inputs["input_ids"], inputs["attention_mask"]]
)
print("Result PyTorch")
print(res_pytorch)
print("Result SHARK")
print(result_shark)
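
The `_remove_nones` helper in this script walks the output tuple from the last index to the first, so each `pop` leaves the positions of the entries it has not yet visited unchanged. The following sketch shows the same idea on a plain Python list, without `torch.fx`; `drop_none_outputs` is a hypothetical name introduced here for illustration and is not part of SHARK or PyTorch.

```python
from typing import Any, List, Tuple


def drop_none_outputs(outputs: List[Any]) -> Tuple[List[Any], List[int]]:
    """Return the outputs without None plus the sorted indexes that were removed."""
    kept = list(outputs)
    removed = []
    # Iterate backwards: popping index i only shifts entries after i,
    # and those have already been visited.
    for index in range(len(kept) - 1, -1, -1):
        if kept[index] is None:
            removed.append(index)
            kept.pop(index)
    removed.sort()
    return kept, removed


if __name__ == "__main__":
    print(drop_none_outputs(["logits", None, "cache", None]))
```

The removed indexes refer to positions in the original list, which is presumably why the script keeps them around (as `removed_none_indexes`) after pruning the graph outputs.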


@@ -1,6 +1 @@
from apps.stable_diffusion.scripts.txt2img import txt2img_inf
from apps.stable_diffusion.scripts.img2img import img2img_inf
from apps.stable_diffusion.scripts.inpaint import inpaint_inf
from apps.stable_diffusion.scripts.outpaint import outpaint_inf
from apps.stable_diffusion.scripts.upscaler import upscaler_inf
from apps.stable_diffusion.scripts.train_lora_word import lora_train


@@ -2,276 +2,22 @@ import sys
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
    args,
    Image2ImagePipeline,
    StencilPipeline,
    resize_stencil,
    get_schedulers,
    set_init_device_flags,
    utils,
    clear_all,
    save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms to our model constraints:
# both width and height should be at least 384 and a multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
    width, height = image.size
    aspect_ratio = width / height
    min_size = min(width, height)
    if min_size < 384:
        n_size = 384
        if width == min_size:
            width = n_size
            height = n_size / aspect_ratio
        else:
            height = n_size
            width = n_size * aspect_ratio
    width = int(width)
    height = int(height)
    n_width = width // 8
    n_height = height // 8
    n_width *= 8
    n_height *= 8

    new_image = image.resize((n_width, n_height))
    return new_image, n_width, n_height


# Exposed to UI.
def img2img_inf(
    prompt: str,
    negative_prompt: str,
    init_image,
    height: int,
    width: int,
    steps: int,
    strength: float,
    guidance_scale: float,
    seed: int,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    custom_model: str,
    hf_model_id: str,
    precision: str,
    device: str,
    max_length: int,
    use_stencil: str,
    save_metadata_to_json: bool,
    save_metadata_to_png: bool,
    lora_weights: str,
    lora_hf_id: str,
):
    from apps.stable_diffusion.web.ui.utils import (
        get_custom_model_pathfile,
        Config,
    )
    import apps.stable_diffusion.web.utils.global_obj as global_obj

    global schedulers

    args.prompts = [prompt]
    args.negative_prompts = [negative_prompt]
    args.guidance_scale = guidance_scale
    args.seed = seed
    args.steps = steps
    args.strength = strength
    args.scheduler = scheduler
    args.img_path = "not none"

    if init_image is None:
        return None, "An Initial Image is required"
    image = init_image.convert("RGB")

    # set ckpt_loc and hf_model_id.
    types = (
        ".ckpt",
        ".safetensors",
    )  # the tuple of file types
    args.ckpt_loc = ""
    args.hf_model_id = ""
    if custom_model == "None":
        if not hf_model_id:
            return (
                None,
                "Please provide either custom model or huggingface model ID, both must not be empty",
            )
        args.hf_model_id = hf_model_id
    elif ".ckpt" in custom_model or ".safetensors" in custom_model:
        args.ckpt_loc = get_custom_model_pathfile(custom_model)
    else:
        args.hf_model_id = custom_model

    use_lora = ""
    if lora_weights == "None" and not lora_hf_id:
        use_lora = ""
    elif not lora_hf_id:
        use_lora = lora_weights
    else:
        use_lora = lora_hf_id

    args.use_lora = use_lora

    args.save_metadata_to_json = save_metadata_to_json
    args.write_metadata_to_png = save_metadata_to_png

    use_stencil = None if use_stencil == "None" else use_stencil
    args.use_stencil = use_stencil
    if use_stencil is not None:
        args.scheduler = "DDIM"
        args.hf_model_id = "runwayml/stable-diffusion-v1-5"
        image, width, height = resize_stencil(image)
    elif args.scheduler != "PNDM":
        if "Shark" in args.scheduler:
            print(
                f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
            )
            args.scheduler = "PNDM"
        else:
            sys.exit(
                "Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
            )
    cpu_scheduling = not args.scheduler.startswith("Shark")
    args.precision = precision
    dtype = torch.float32 if precision == "fp32" else torch.half
    new_config_obj = Config(
        "img2img",
        args.hf_model_id,
        args.ckpt_loc,
        precision,
        batch_size,
        max_length,
        height,
        width,
        device,
        use_lora=use_lora,
        use_stencil=use_stencil,
    )
    if (
        not global_obj.get_sd_obj()
        or global_obj.get_cfg_obj() != new_config_obj
    ):
        global_obj.clear_cache()
        global_obj.set_cfg_obj(new_config_obj)
        args.batch_size = batch_size
        args.max_length = max_length
        args.height = height
        args.width = width
        args.device = device.split("=>", 1)[1].strip()
        args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
        args.use_tuned = init_use_tuned
        args.import_mlir = init_import_mlir
        set_init_device_flags()
        model_id = (
            args.hf_model_id
            if args.hf_model_id
            else "stabilityai/stable-diffusion-2-1-base"
        )
        schedulers = get_schedulers(model_id)
        scheduler_obj = schedulers[scheduler]

        if use_stencil is not None:
            args.use_tuned = False
            global_obj.set_sd_obj(
                StencilPipeline.from_pretrained(
                    scheduler_obj,
                    args.import_mlir,
                    args.hf_model_id,
                    args.ckpt_loc,
                    args.custom_vae,
                    args.precision,
                    args.max_length,
                    args.batch_size,
                    args.height,
                    args.width,
                    args.use_base_vae,
                    args.use_tuned,
                    low_cpu_mem_usage=args.low_cpu_mem_usage,
                    use_stencil=use_stencil,
                    debug=args.import_debug if args.import_mlir else False,
                    use_lora=use_lora,
                )
            )
        else:
            global_obj.set_sd_obj(
                Image2ImagePipeline.from_pretrained(
                    scheduler_obj,
                    args.import_mlir,
                    args.hf_model_id,
                    args.ckpt_loc,
                    args.custom_vae,
                    args.precision,
                    args.max_length,
                    args.batch_size,
                    args.height,
                    args.width,
                    args.use_base_vae,
                    args.use_tuned,
                    low_cpu_mem_usage=args.low_cpu_mem_usage,
                    debug=args.import_debug if args.import_mlir else False,
                    use_lora=use_lora,
                )
            )

    global_obj.set_schedulers(schedulers[scheduler])

    start_time = time.time()
    global_obj.get_sd_obj().log = ""
    generated_imgs = []
    seeds = []
    img_seed = utils.sanitize_seed(seed)
    extra_info = {"STRENGTH": strength}
    for current_batch in range(batch_count):
        if current_batch > 0:
            img_seed = utils.sanitize_seed(-1)
        out_imgs = global_obj.get_sd_obj().generate_images(
            prompt,
            negative_prompt,
            image,
            batch_size,
            height,
            width,
            steps,
            strength,
            guidance_scale,
            img_seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
            use_stencil=use_stencil,
        )
        save_output_img(out_imgs[0], img_seed, extra_info)
        generated_imgs.extend(out_imgs)
        seeds.append(img_seed)
        global_obj.get_sd_obj().log += "\n"
        yield generated_imgs, global_obj.get_sd_obj().log

    total_time = time.time() - start_time
    text_output = f"prompt={args.prompts}"
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
    text_output += f"\nscheduler={args.scheduler}, device={device}"
    text_output += f"\nsteps={steps}, strength={args.strength}, guidance_scale={guidance_scale}, seed={seeds}"
    text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
    text_output += global_obj.get_sd_obj().log
    text_output += f"\nTotal image generation time: {total_time:.4f}sec"

    yield generated_imgs, text_output

if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -288,16 +34,11 @@ if __name__ == "__main__":
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, args.width, args.height = resize_stencil(image)
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
)
args.scheduler = "PNDM"
else:
sys.exit(
"Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
@@ -324,6 +65,7 @@ if __name__ == "__main__":
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
@@ -342,6 +84,7 @@ if __name__ == "__main__":
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
start_time = time.time()
@@ -377,3 +120,7 @@ if __name__ == "__main__":
extra_info = {"STRENGTH": args.strength}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()
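
The `resize_stencil` helper that this diff moves out of the script enforces two constraints on the stencil input: the shorter side is scaled up to at least 384 pixels with the aspect ratio preserved, and both sides are then rounded down to a multiple of 8. The sketch below reproduces only that size arithmetic on plain integers so it can be checked without Pillow; `stencil_size` and its `floor` and `multiple` parameters are hypothetical names introduced for illustration. It mirrors the version of the function shown in this diff; per the commit message of 0ef6a0e234, the relocated helper caps sizes within [128, 768] instead, which this sketch does not model.

```python
from typing import Tuple


def stencil_size(
    width: int, height: int, floor: int = 384, multiple: int = 8
) -> Tuple[int, int]:
    """Compute the (width, height) the old resize_stencil would resize to."""
    aspect_ratio = width / height
    if min(width, height) < floor:
        # Scale the shorter side up to `floor`, keeping the aspect ratio.
        if width <= height:
            width, height = floor, floor / aspect_ratio
        else:
            width, height = floor * aspect_ratio, floor
    width, height = int(width), int(height)
    # Round both sides down to the nearest multiple of `multiple`.
    return (width // multiple) * multiple, (height // multiple) * multiple


if __name__ == "__main__":
    for size in [(300, 200), (200, 400), (1000, 1003)]:
        print(size, "->", stencil_size(*size))
```

Images that already satisfy the minimum are never scaled, only trimmed to a multiple of 8, so a 1000x1003 input becomes 1000x1000.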


@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -10,196 +11,10 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def inpaint_inf(
    prompt: str,
    negative_prompt: str,
    image_dict,
    height: int,
    width: int,
    inpaint_full_res: bool,
    inpaint_full_res_padding: int,
    steps: int,
    guidance_scale: float,
    seed: int,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    custom_model: str,
    hf_model_id: str,
    precision: str,
    device: str,
    max_length: int,
    save_metadata_to_json: bool,
    save_metadata_to_png: bool,
    lora_weights: str,
    lora_hf_id: str,
):
    from apps.stable_diffusion.web.ui.utils import (
        get_custom_model_pathfile,
        Config,
    )
    import apps.stable_diffusion.web.utils.global_obj as global_obj

    global schedulers

    args.prompts = [prompt]
    args.negative_prompts = [negative_prompt]
    args.guidance_scale = guidance_scale
    args.steps = steps
    args.scheduler = scheduler
    args.img_path = "not none"
    args.mask_path = "not none"

    # set ckpt_loc and hf_model_id.
    types = (
        ".ckpt",
        ".safetensors",
    )  # the tuple of file types
    args.ckpt_loc = ""
    args.hf_model_id = ""
    if custom_model == "None":
        if not hf_model_id:
            return (
                None,
                "Please provide either custom model or huggingface model ID, both must not be empty",
            )
        args.hf_model_id = hf_model_id
    elif ".ckpt" in custom_model or ".safetensors" in custom_model:
        args.ckpt_loc = get_custom_model_pathfile(custom_model)
    else:
        args.hf_model_id = custom_model

    use_lora = ""
    if lora_weights == "None" and not lora_hf_id:
        use_lora = ""
    elif not lora_hf_id:
        use_lora = lora_weights
    else:
        use_lora = lora_hf_id

    args.use_lora = use_lora

    args.save_metadata_to_json = save_metadata_to_json
    args.write_metadata_to_png = save_metadata_to_png

dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"inpaint",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -229,6 +44,7 @@ if __name__ == "__main__":
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
@@ -236,10 +52,10 @@ if __name__ == "__main__":
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
@@ -282,3 +98,7 @@ if __name__ == "__main__":
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()
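The LoRA selection that the removed `inpaint_inf` body performed (a Hugging Face ID takes precedence, and the dropdown value "None" means no LoRA) can be restated as a small helper. This is only a sketch: `resolve_use_lora` is a hypothetical name and is not part of this diff.

```python
def resolve_use_lora(lora_weights: str, lora_hf_id: str) -> str:
    """Return the LoRA source to load, or "" when none was requested."""
    if lora_hf_id:
        # An explicit Hugging Face ID wins over the dropdown selection.
        return lora_hf_id
    # The UI dropdown reports the literal string "None" when nothing is chosen.
    return "" if lora_weights == "None" else lora_weights
```

The same three-way branch appears in the txt2img and outpaint entry points, so a shared helper of this shape would keep them in sync.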


@@ -0,0 +1,19 @@
from apps.stable_diffusion.src import args
from apps.stable_diffusion.scripts import (
img2img,
txt2img,
# inpaint,
# outpaint,
)
if __name__ == "__main__":
if args.app == "txt2img":
txt2img.main()
elif args.app == "img2img":
img2img.main()
# elif args.app == "inpaint":
# inpaint.main()
# elif args.app == "outpaint":
# outpaint.main()
else:
print(f"args.app value is {args.app} but this isn't supported")
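The if/elif chain in the new entry script could also be expressed as a lookup table, which makes enabling inpaint or outpaint later a one-line change. The sketch below is an assumption about how that might look, not what the commit does; `run_txt2img`, `run_img2img`, `APPS`, and `dispatch` are hypothetical stand-ins for the real `txt2img.main` and `img2img.main`.

```python
from typing import Callable, Dict


def run_txt2img() -> str:
    # Hypothetical stand-in for txt2img.main().
    return "txt2img"


def run_img2img() -> str:
    # Hypothetical stand-in for img2img.main().
    return "img2img"


# Map each supported --app value to its entry point.
APPS: Dict[str, Callable[[], str]] = {
    "txt2img": run_txt2img,
    "img2img": run_img2img,
}


def dispatch(app: str) -> str:
    handler = APPS.get(app)
    if handler is None:
        # Mirrors the message printed by the script for unknown values.
        raise ValueError(f"args.app value is {app} but this isn't supported")
    return handler()
```

Raising instead of printing is a design choice of the sketch: it gives callers a non-zero exit status when the app name is wrong.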


@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
@@ -12,203 +13,7 @@ from apps.stable_diffusion.src import (
)
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def outpaint_inf(
prompt: str,
negative_prompt: str,
init_image,
pixels: int,
mask_blur: int,
directions: list,
noise_q: float,
color_variation: float,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"outpaint",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
left = True if "left" in directions else False
right = True if "right" in directions else False
top = True if "up" in directions else False
bottom = True if "down" in directions else False
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
init_image,
pixels,
mask_blur,
left,
right,
top,
bottom,
noise_q,
color_variation,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -243,6 +48,7 @@ if __name__ == "__main__":
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
@@ -307,3 +113,7 @@ if __name__ == "__main__":
}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()
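The model selection in the removed `outpaint_inf` body decides between a Hugging Face model ID and a local checkpoint. A minimal sketch of that decision follows, under two stated assumptions: `resolve_model` and its `path_for` parameter are hypothetical (in the diff the path lookup is `get_custom_model_pathfile`), and the sketch matches on the file suffix where the original tests for a substring.

```python
from typing import Callable, Tuple

CHECKPOINT_TYPES = (".ckpt", ".safetensors")


def resolve_model(
    custom_model: str,
    hf_model_id: str,
    path_for: Callable[[str], str],
) -> Tuple[str, str]:
    """Return (hf_model_id, ckpt_loc); exactly one of the two is non-empty."""
    if custom_model == "None":
        if not hf_model_id:
            raise ValueError(
                "Provide either a custom model or a Hugging Face model ID."
            )
        return hf_model_id, ""
    if custom_model.lower().endswith(CHECKPOINT_TYPES):
        # Local checkpoint: resolve the file name to a full path.
        return "", path_for(custom_model)
    # Anything else is treated as a Hugging Face model ID.
    return custom_model, ""
```

Note that the `types` tuple defined in the original body is never used there; the suffix check above is one way it could have been applied.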


@@ -73,6 +73,7 @@ from apps.stable_diffusion.src import (
set_init_device_flags,
clear_all,
)
from apps.stable_diffusion.src.utils import update_lora_weight
# Setup the dataset
@@ -159,7 +160,19 @@ class LoraDataset(Dataset):
return example
schedulers = None
def torch_device(device):
device_tokens = device.split("=>")
if len(device_tokens) == 1:
device_str = device_tokens[0].strip()
else:
device_str = device_tokens[1].strip()
device_type_tokens = device_str.split("://")
if device_type_tokens[0] == "metal":
device_type_tokens[0] = "vulkan"
if len(device_type_tokens) > 1:
return device_type_tokens[0] + ":" + device_type_tokens[1]
else:
return device_type_tokens[0]
########## Setting up the model ##########
@@ -180,6 +193,7 @@ def lora_train(
max_length: int,
training_images_dir: str,
lora_save_dir: str,
use_lora: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -187,8 +201,6 @@ def lora_train(
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
print(
"Note LoRA training is not compatible with the latest torch-mlir branch"
)
@@ -227,7 +239,8 @@ def lora_train(
args.max_length = max_length
args.height = height
args.width = width
args.device = device
args.device = torch_device(device)
args.use_lora = use_lora
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(
@@ -252,29 +265,33 @@ def lora_train(
unet.to(args.device)
text_encoder.to(args.device)
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[
block_id
]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if use_lora != "":
update_lora_weight(unet, args.use_lora, "unet")
else:
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[
block_id
]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
unet.set_attn_processor(lora_attn_procs)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
class VaeModel(torch.nn.Module):
@@ -671,4 +688,5 @@ if __name__ == "__main__":
args.max_length,
args.training_images_dir,
args.lora_save_dir,
args.use_lora,
)
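To make the normalization done by the new `torch_device` helper concrete, here is a condensed restatement with a few sample inputs. The logic follows the function added in this diff; the device labels passed in are invented for illustration.

```python
def torch_device(device: str) -> str:
    # UI labels look like "<name> => <device>"; keep the part after "=>".
    tokens = device.split("=>")
    device_str = (tokens[0] if len(tokens) == 1 else tokens[1]).strip()
    # "vulkan://0" becomes ["vulkan", "0"]; "cpu" stays ["cpu"].
    parts = device_str.split("://")
    # Metal devices are mapped onto the vulkan name.
    kind = "vulkan" if parts[0] == "metal" else parts[0]
    return f"{kind}:{parts[1]}" if len(parts) > 1 else kind


print(torch_device("AMD Radeon RX 6900 XT => vulkan://0"))  # vulkan:0
print(torch_device("cpu"))  # cpu
print(torch_device("metal://0"))  # vulkan:0
```

Unlike the `device.split("=>", 1)[1]` pattern used in the UI entry points, this form does not raise an `IndexError` when the label has no "=>" separator.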


@@ -0,0 +1,126 @@
import os
from pathlib import Path
from shark_tuner.codegen_tuner import SharkCodegenTuner
from shark_tuner.iree_utils import (
dump_dispatches,
create_context,
export_module_to_mlir_file,
)
from shark_tuner.model_annotation import model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.utils import set_init_device_flags
from apps.stable_diffusion.src.utils.sd_annotation import (
get_device_args,
load_winograd_configs,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
max_len=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,
)
if args.annotation_model == "unet":
mlir_module = sd_model.unet()
model_name = sd_model.model_name["unet"]
elif args.annotation_model == "vae":
mlir_module = sd_model.vae()
model_name = sd_model.model_name["vae"]
else:
raise ValueError(
f"{args.annotation_model} is not supported for tuning."
)
return mlir_module, model_name
def main():
args.use_tuned = False
set_init_device_flags()
mlir_module, model_name = load_mlir_module()
# Get device and device specific arguments
device, device_spec_args = get_device_args()
device_spec = ""
vulkan_target_triple = ""
if device_spec_args:
device_spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
vulkan_target_triple = device_spec
device_spec = device_spec.split("-")[0]
# Add winograd annotation for vulkan device
use_winograd = (
True
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
else False
)
winograd_config = (
load_winograd_configs()
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
else ""
)
with create_context() as ctx:
input_module = model_annotation(
ctx,
input_contents=mlir_module,
config_path=winograd_config,
search_op="conv",
winograd=use_winograd,
)
# Dump model dispatches
generates_dir = Path.home() / "tmp"
if not os.path.exists(generates_dir):
os.makedirs(generates_dir)
dump_mlir = generates_dir / "temp.mlir"
dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches"
export_module_to_mlir_file(input_module, dump_mlir)
dump_dispatches(
dump_mlir,
device,
dispatch_dir,
vulkan_target_triple,
use_winograd=use_winograd,
)
# Tune each dispatch
dtype = "f16" if args.precision == "fp16" else "f32"
config_filename = f"{model_name}_{device_spec}_configs.json"
for f_path in os.listdir(dispatch_dir):
if not f_path.endswith(".mlir"):
continue
model_dir = os.path.join(dispatch_dir, f_path)
tuner = SharkCodegenTuner(
model_dir,
device,
"random",
args.num_iters,
args.tuned_config_dir,
dtype,
args.search_op,
batch_size=1,
config_filename=config_filename,
use_dispatch=True,
vulkan_target_triple=vulkan_target_triple,
)
tuner.tune()
if __name__ == "__main__":
main()
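The tuning script derives two strings from the device flags: a short device spec used in file names, and, for Vulkan, the full target triple. The sketch below isolates that step so it can be checked by hand. `split_device_spec` is a hypothetical name, and the flag strings in the usage lines are illustrative assumptions about what `get_device_args` might return.

```python
from typing import List, Tuple


def split_device_spec(device: str, device_spec_args: List[str]) -> Tuple[str, str]:
    """Return (device_spec, vulkan_target_triple) as the tuning script derives them."""
    device_spec = ""
    vulkan_target_triple = ""
    if device_spec_args:
        # Take the value of the last "--flag=value" argument.
        device_spec = device_spec_args[-1].split("=")[-1].strip()
        if device == "vulkan":
            # Keep the full triple, and shorten the spec to its first component.
            vulkan_target_triple = device_spec
            device_spec = device_spec.split("-")[0]
    return device_spec, vulkan_target_triple


# Illustrative flag values, not taken from the diff.
print(split_device_spec("vulkan", ["--iree-vulkan-target-triple=rdna2-unknown-linux"]))
print(split_device_spec("cuda", ["--iree-hal-cuda-llvm-target-arch=sm_80"]))
```

With these inputs the Vulkan case yields the spec `rdna2` plus the full triple, while the CUDA case keeps `sm_80` and leaves the triple empty.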


@@ -1,4 +1,5 @@
import torch
import transformers
import time
from apps.stable_diffusion.src import (
args,
@@ -11,186 +12,7 @@ from apps.stable_diffusion.src import (
)
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=use_lora,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += (
f"\nsteps={steps}, guidance_scale={guidance_scale}, seed={seeds}"
)
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
# text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -200,7 +22,6 @@ if __name__ == "__main__":
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
use_lora = args.use_lora
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
@@ -216,7 +37,9 @@ if __name__ == "__main__":
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
@@ -256,3 +79,7 @@ if __name__ == "__main__":
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()
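The removed `txt2img_inf` body rebuilt the pipeline only when no pipeline existed yet or when the new `Config` differed from the cached one. A minimal sketch of that caching pattern follows; `Config` here is a simplified, hypothetical stand-in for the web UI's class, and `PipelineCache` is an invented name for what `global_obj` does in the diff.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class Config:
    # Hypothetical subset of the fields compared in the diff.
    mode: str
    model_id: str
    precision: str
    height: int
    width: int
    device: str


class PipelineCache:
    def __init__(self) -> None:
        self._cfg: Optional[Config] = None
        self._obj: Optional[object] = None
        self.builds = 0  # counts how often the expensive build ran

    def get(self, cfg: Config, build: Callable[[Config], object]) -> object:
        # Rebuild only when nothing is cached or any config field changed.
        if self._obj is None or self._cfg != cfg:
            self._obj = build(cfg)
            self._cfg = cfg
            self.builds += 1
        return self._obj
```

Because the dataclass compares field by field, changing only the prompt or seed (which are not part of the config) reuses the compiled pipeline, while changing resolution or device triggers a rebuild.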


@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
@@ -12,187 +13,6 @@ from apps.stable_diffusion.src import (
)
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def upscaler_inf(
prompt: str,
negative_prompt: str,
init_image,
height: int,
width: int,
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB").resize((height, width))
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
args.height = 128
args.width = 128
new_config_obj = Config(
"upscaler",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
args.height,
args.width,
device,
use_lora=None,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_size = batch_size
args.max_length = max_length
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.get_sd_obj().low_res_scheduler = schedulers["DDPM"]
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"NOISE LEVEL": noise_level}
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
low_res_img = image
high_res_img = Image.new("RGB", (height * 4, width * 4))
for i in range(0, width, 128):
for j in range(0, height, 128):
box = (j, i, j + 128, i + 128)
upscaled_image = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
low_res_img.crop(box),
batch_size,
args.height,
args.width,
steps,
noise_level,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
save_output_img(high_res_img, img_seed, extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
if args.clear_all:
clear_all()
@@ -232,7 +52,9 @@ if __name__ == "__main__":
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ddpm_scheduler=schedulers["DDPM"],
ondemand=args.ondemand,
)
start_time = time.time()
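The removed `upscaler_inf` body upscaled the input in 128-pixel tiles and pasted each result at four times its source offset. The coordinate arithmetic can be sketched on its own, without any image library. `tile_boxes` is a hypothetical helper, and the sketch assumes, as the original loop does, that both dimensions are multiples of the tile size.

```python
from typing import Iterator, Tuple

Box = Tuple[int, int, int, int]


def tile_boxes(
    img_w: int, img_h: int, tile: int = 128, scale: int = 4
) -> Iterator[Tuple[Box, Tuple[int, int]]]:
    """Yield (crop_box, paste_xy) pairs covering the image row by row."""
    for top in range(0, img_h, tile):
        for left in range(0, img_w, tile):
            # Crop box is (left, upper, right, lower) in source pixels.
            box = (left, top, left + tile, top + tile)
            # The upscaled tile lands at the same offset, scaled up.
            yield box, (left * scale, top * scale)


for box, paste_xy in tile_boxes(256, 128):
    print(box, paste_xy)
```

For a 256x128 input this produces two tiles, the second cropped from x=128 and pasted at x=512 in the output.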


@@ -25,9 +25,12 @@ datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),


@@ -25,6 +25,7 @@ datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
@@ -43,7 +44,7 @@ hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
a = Analysis(
['scripts/txt2img.py'],
['scripts/main.py'],
pathex=['.'],
binaries=binaries,
datas=datas,


@@ -5,6 +5,7 @@ from apps.stable_diffusion.src.utils import (
get_available_devices,
clear_all,
save_output_img,
resize_stencil,
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,


@@ -1,9 +1,11 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from collections import defaultdict
from pathlib import Path
import torch
import safetensors.torch
import traceback
import subprocess
import sys
import os
from apps.stable_diffusion.src.utils import (
@@ -11,13 +13,13 @@ from apps.stable_diffusion.src.utils import (
get_opt_flags,
base_models,
args,
fetch_or_delete_vmfbs,
preprocessCKPT,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
get_stencil_model_id,
update_lora_weight,
)
@@ -54,29 +56,9 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
return new_shape
# Get the input info for various models i.e. "unet", "clip", "vae", "vae_encode".
def get_input_info(model_info, max_len, width, height, batch_size):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
for inp in model_info[k]:
shape = model_info[k][inp]["shape"]
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, max_len, width, height, batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
def check_compilation(model, model_name):
if not model:
raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
class SharkifyStableDiffusionModel:
@@ -99,7 +81,9 @@ class SharkifyStableDiffusionModel:
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = ""
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -107,11 +91,21 @@ class SharkifyStableDiffusionModel:
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
if not os.path.isfile(weights_path):
subprocess.run(["wget", custom_weights, "-O", weights_path])
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
@@ -145,18 +139,32 @@ class SharkifyStableDiffusionModel:
self.use_lora = use_lora
print(self.model_name)
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
self.sharktank_dir = sharktank_dir
self.generate_vmfb = generate_vmfb
def get_extended_name_for_all_model(self, mask_to_fetch):
self.inputs = dict()
self.model_to_run = ""
if self.custom_weights != "":
self.model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
else:
self.model_to_run = args.hf_model_id
self.custom_vae = self.process_custom_vae()
self.base_model_id = fetch_and_update_base_model_id(self.model_to_run)
if self.base_model_id != "" and args.ckpt_loc != "":
args.hf_model_id = self.base_model_id
self.return_mlir = return_mlir
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
index = 0
for model in sub_model_list:
if mask_to_fetch[index] == False:
index += 1
continue
sub_model = model
model_config = self.model_name
if "vae" == model:
@@ -164,6 +172,8 @@ class SharkifyStableDiffusionModel:
model_config = model_config + get_path_stem(self.custom_vae)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
@@ -176,6 +186,29 @@ class SharkifyStableDiffusionModel:
if not (height % 8 == 0 and height >= 128):
sys.exit("height should be at least 128 and a multiple of 8")
# Get the input info for a model i.e. "unet", "clip", "vae", etc.
def get_input_info_for(self, model_info):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = []
for inp in model_info:
shape = model_info[inp]["shape"]
dtype = dtype_config[model_info[inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, self.max_len, self.width, self.height, self.batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
@@ -192,16 +225,20 @@ class SharkifyStableDiffusionModel:
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = True if self.precision == "fp16" else False
shark_vae_encode = compile_through_fx(
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
shark_vae_encode, vae_encode_mlir = compile_through_fx(
vae_encode,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=self.model_name["vae_encode"],
extended_model_name=self.model_name["vae_encode"],
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae_encode",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae_encode
return shark_vae_encode, vae_encode_mlir
def get_vae(self):
class VaeModel(torch.nn.Module):
@@ -241,53 +278,31 @@ class SharkifyStableDiffusionModel:
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae = compile_through_fx(
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=self.model_name["vae"],
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae
def get_vae_upscaler(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
return x
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
shark_vae = compile_through_fx(
vae,
inputs,
use_tuned=self.use_tuned,
model_name=self.model_name["vae"],
extra_args=get_opt_flags("vae", precision="fp32"),
)
return shark_vae
return shark_vae, vae_mlir
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -295,6 +310,8 @@ class SharkifyStableDiffusionModel:
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
@@ -323,18 +340,22 @@ class SharkifyStableDiffusionModel:
unet = ControlledUnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_unet"])
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
shark_controlled_unet = compile_through_fx(
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
model_name=self.model_name["stencil_unet"],
extended_model_name=self.model_name["stencil_unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self):
class StencilControlNetModel(torch.nn.Module):
@@ -378,16 +399,20 @@ class SharkifyStableDiffusionModel:
inputs = tuple(self.inputs["stencil_adaptor"])
input_mask = [True, True, True, True]
shark_cnet = compile_through_fx(
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
model_name=self.model_name["stencil_adaptor"],
extended_model_name=self.model_name["stencil_adaptor"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_adaptor",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_cnet
return shark_cnet, cnet_mlir
def get_unet(self):
class UnetModel(torch.nn.Module):
@@ -399,7 +424,7 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
self.unet.load_attn_procs(use_lora)
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
@@ -433,10 +458,10 @@ class SharkifyStableDiffusionModel:
save_dir,
exist_ok=True,
)
shark_unet = compile_through_fx(
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
model_name=self.model_name["unet"],
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -444,8 +469,12 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet
return shark_unet, unet_mlir
def get_unet_upscaler(self):
class UnetModel(torch.nn.Module):
@@ -473,26 +502,32 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet = compile_through_fx(
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
model_name=self.model_name["unet"],
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet
return shark_unet, unet_mlir
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
def forward(self, input):
return self.text_encoder(input)[0]
@@ -504,16 +539,20 @@ class SharkifyStableDiffusionModel:
save_dir,
exist_ok=True,
)
shark_clip = compile_through_fx(
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name=self.model_name["clip"],
extended_model_name=self.model_name["clip"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip
return shark_clip, clip_mlir
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
@@ -534,132 +573,109 @@ class SharkifyStableDiffusionModel:
vae_checkpoint = vae_checkpoint["state_dict"]
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
# configuration.
def compile_all(self, base_model_id, need_vae_encode, need_stencil):
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
if self.is_upscaler:
return self.get_clip(), self.get_unet_upscaler(), self.get_vae_upscaler()
compiled_controlnet = None
compiled_controlled_unet = None
compiled_unet = None
if need_stencil:
compiled_controlnet = self.get_control_net()
compiled_controlled_unet = self.get_controlled_unet()
else:
compiled_unet = self.get_unet()
if self.custom_vae != "":
print("Plugging in custom Vae")
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
if need_stencil:
return compiled_clip, compiled_controlled_unet, compiled_vae, compiled_controlnet
if need_vae_encode:
compiled_vae_encode = self.get_vae_encode()
return compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode
return compiled_clip, compiled_unet, compiled_vae
def __call__(self):
# Step 1:
# -- Fetch all vmfbs for the model, if present, else delete the lot.
need_vae_encode, need_stencil = False, False
if not self.is_upscaler and args.img_path is not None:
if self.use_stencil is not None:
need_stencil = True
def compile_unet_variants(self, model):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
return get_unet()
else:
need_vae_encode = True
# `mask_to_fetch` prepares a mask to pick a combination out of :-
# ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
mask_to_fetch = [True, True, False, True, False, False]
if need_vae_encode:
mask_to_fetch = [True, True, False, True, True, False]
elif need_stencil:
mask_to_fetch = [True, False, True, True, False, True]
self.model_name = self.get_extended_name_for_all_model(mask_to_fetch)
vmfbs = fetch_or_delete_vmfbs(self.model_name, self.precision)
if vmfbs[0]:
# -- If all vmfbs are indeed present, we also try and fetch the base
# model configuration for running SD with custom checkpoints.
if self.custom_weights != "":
args.hf_model_id = fetch_and_update_base_model_id(self.custom_weights)
if args.hf_model_id == "":
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
return vmfbs
# Step 2:
# -- If vmfbs weren't found, we try to see if the base model configuration
# for the required SD run is known to us and bypass the retry mechanism.
model_to_run = ""
if self.custom_weights != "":
model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
return self.get_unet()
else:
model_to_run = args.hf_model_id
# For custom Vae user can provide either the repo-id or a checkpoint file,
# and for a checkpoint file we'd need to process it via Diffusers' script.
self.custom_vae = self.process_custom_vae()
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
if base_model_fetched != "":
print("Compiling all the models with the fetched base model configuration.")
if args.ckpt_loc != "":
args.hf_model_id = base_model_fetched
return self.compile_all(base_model_fetched, need_vae_encode, need_stencil)
return self.get_controlled_unet()
# Step 3:
# -- This is the retry mechanism where the base model's configuration is not
# known to us and figure that out by trial and error.
print("Inferring base model configuration.")
for model_id in base_models:
try:
if need_vae_encode:
compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, need_vae_encode, need_stencil)
elif need_stencil:
compiled_clip, compiled_unet, compiled_vae, compiled_controlnet = self.compile_all(model_id, need_vae_encode, need_stencil)
else:
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode, need_stencil)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
if need_vae_encode:
return (
compiled_clip,
compiled_unet,
compiled_vae,
compiled_vae_encode,
)
if need_stencil:
return (
compiled_clip,
compiled_unet,
compiled_vae,
compiled_controlnet,
)
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
)
def vae_encode(self):
try:
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
if self.return_mlir:
return vae_encode_mlir
return compiled_vae_encode
except Exception as e:
sys.exit(e)
def clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(base_models["clip"])
compiled_clip, clip_mlir = self.get_clip()
check_compilation(compiled_clip, "Clip")
if self.return_mlir:
return clip_mlir
return compiled_clip
except Exception as e:
sys.exit(e)
def unet(self):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(self.model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
break
check_compilation(compiled_unet, "Unet")
if self.return_mlir:
return unet_mlir
return compiled_unet
except Exception as e:
sys.exit(e)
def vae(self):
try:
vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
self.inputs["vae"] = self.get_input_info_for(vae_input)
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")
if self.return_mlir:
return vae_mlir
return compiled_vae
except Exception as e:
sys.exit(e)
def controlnet(self):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
check_compilation(compiled_stencil_adaptor, "Stencil")
if self.return_mlir:
return controlnet_mlir
return compiled_stencil_adaptor
except Exception as e:
sys.exit(e)
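
Note on the new `unet()` method above: when no base model id is known, it tries each known input configuration in turn until one compiles. A minimal standalone sketch of that retry pattern follows. `compile_with_fallback`, `candidates`, and `compile_fn` are hypothetical names used only for illustration and are not part of the SHARK code. As a deliberate difference from the diff, the sketch raises if nothing compiles rather than handing a possibly empty result to `check_compilation`.

```python
from typing import Callable, Dict, Tuple


def compile_with_fallback(
    candidates: Dict[str, dict],
    compile_fn: Callable[[dict], object],
    known_id: str = "",
) -> Tuple[str, object]:
    """Return (config_id, compiled) for the first configuration that compiles.

    Hypothetical helper: `candidates` maps a configuration id to its input
    description, and `compile_fn` raises when the description does not fit.
    """
    # A known configuration skips the trial-and-error path entirely.
    if known_id:
        return known_id, compile_fn(candidates[known_id])

    errors = {}
    for config_id, config in candidates.items():
        try:
            compiled = compile_fn(config)
        except Exception as exc:
            # Remember why this candidate failed, then try the next one.
            errors[config_id] = exc
            continue
        return config_id, compiled

    raise RuntimeError(f"no configuration compiled: {errors}")
```

Collecting the per-candidate errors keeps the failure reasons available when every configuration is rejected, which plain `print(e)` calls would scatter across the log.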


@@ -20,6 +20,15 @@ hf_model_variant_map = {
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
bucket_key = "gs://shark_tank/prashant_nod"
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 or args.width != 512 or args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
@@ -41,6 +50,12 @@ def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
# TODO: Get the quantize model from model_db.json
if args.use_quantize == "int8":
bk, mk, flags = get_quantize_model()
return get_shark_model(bk, mk, flags)
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
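
Note on `get_quantize_model()` above: its error message says the int8 model needs height 512, width 512, and max_length 77 all at once. A small sketch of a check that rejects any single mismatch is shown below. `check_int8_constraints` is a hypothetical name, and raising `ValueError` instead of calling `sys.exit` is a choice made here so the check can be tested in isolation.

```python
def check_int8_constraints(height: int, width: int, max_length: int) -> None:
    """Raise ValueError unless all three values match the int8 model's shape.

    Hypothetical helper; the required values are taken from the error
    message in the diff above.
    """
    # `or` rejects the input as soon as any one requirement is violated.
    if height != 512 or width != 512 or max_length != 77:
        raise ValueError(
            "The int8 quantized model currently requires the height and "
            "width to be 512, and max_length to be 77"
        )
```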


@@ -20,16 +20,15 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class Image2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -40,9 +39,30 @@ class Image2ImagePipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except Exception:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_image_latents(
self,
@@ -89,9 +109,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
return latents, timesteps
def encode_image(self, input_image):
self.load_vae_encode()
vae_encode_start = time.time()
latents = self.vae_encode("forward", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
if self.ondemand:
self.unload_vae_encode()
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
return latents
@@ -131,8 +154,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -161,6 +186,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -168,5 +194,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs
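
Note on the `load_vae_encode` / `unload_vae_encode` pair and the `ondemand` flag above: each submodel is loaded right before use and dropped right after when `ondemand` is set. The pattern can be sketched in isolation as follows. `LazyModule` and its `loader` argument are hypothetical and stand in for the pipeline attributes; the sketch also unloads in a `finally` block, which the diff does not do.

```python
from typing import Callable, Optional


class LazyModule:
    """Load a callable module on first use and optionally drop it after each run.

    Hypothetical wrapper illustrating the load/unload pattern only.
    """

    def __init__(self, loader: Callable[[], Callable], ondemand: bool):
        self._loader = loader
        self._ondemand = ondemand
        self._module: Optional[Callable] = None
        self.load_count = 0  # how many times the loader actually ran

    def load(self) -> Callable:
        # Loading is idempotent: an already loaded module is reused.
        if self._module is None:
            self._module = self._loader()
            self.load_count += 1
        return self._module

    def unload(self) -> None:
        # Dropping the reference lets the module's memory be reclaimed.
        self._module = None

    def run(self, *args):
        module = self.load()
        try:
            return module(*args)
        finally:
            # Assumption: unload even when the call fails, to free memory.
            if self._ondemand:
                self.unload()
```

With `ondemand=True` the loader runs once per call, trading load time for a lower peak memory footprint; with `ondemand=False` it runs once in total.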


@@ -19,16 +19,15 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class InpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -39,9 +38,30 @@ class InpaintPipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except Exception:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
@@ -305,9 +325,12 @@ class InpaintPipeline(StableDiffusionPipeline):
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -383,8 +406,10 @@ class InpaintPipeline(StableDiffusionPipeline):
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -428,6 +453,7 @@ class InpaintPipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -435,6 +461,8 @@ class InpaintPipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
if inpaint_full_res:
output_image = self.apply_overlay(


@@ -20,16 +20,15 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
StableDiffusionPipeline,
)
import math
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class OutpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -40,9 +39,30 @@ class OutpaintPipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except Exception:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
@@ -123,9 +143,12 @@ class OutpaintPipeline(StableDiffusionPipeline):
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -384,8 +407,10 @@ class OutpaintPipeline(StableDiffusionPipeline):
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -506,6 +531,7 @@ class OutpaintPipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],


@@ -20,16 +20,16 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class StencilPipeline(StableDiffusionPipeline):
def __init__(
self,
controlnet: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -39,9 +39,22 @@ class StencilPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.controlnet = controlnet
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
def load_controlnet(self):
if self.controlnet is not None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def prepare_latents(
self,
@@ -68,6 +81,113 @@ class StencilPipeline(StableDiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_stencil_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
self.load_controlnet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.ondemand:
self.unload_unet()
self.unload_controlnet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
@@ -108,8 +228,10 @@ class StencilPipeline(StableDiffusionPipeline):
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -134,11 +256,11 @@ class StencilPipeline(StableDiffusionPipeline):
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
controlnet=self.controlnet,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -146,5 +268,7 @@ class StencilPipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs
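
Note on the Unet call in `produce_stencil_latents` above: the thirteen ControlNet outputs are passed as `control[0]` through `control[12]`. A possible way to build the same argument tuple with star-unpacking is sketched below. `build_unet_inputs` is a hypothetical helper, and the length check is an added assumption; whether the compiled Unet accepts `control` in any other form is left open by the TODO in the diff.

```python
from typing import Sequence, Tuple

NUM_CONTROL_OUTPUTS = 13  # control[0] .. control[12] in the call above


def build_unet_inputs(
    latents, timestep, text_embeddings, guidance_scale, control: Sequence
) -> Tuple:
    """Return the flat positional-input tuple for the controlled Unet.

    Hypothetical helper: the four leading inputs are followed by the
    ControlNet outputs in their original order.
    """
    if len(control) != NUM_CONTROL_OUTPUTS:
        raise ValueError(
            f"expected {NUM_CONTROL_OUTPUTS} control outputs, got {len(control)}"
        )
    # Star-unpacking yields the same tuple as listing every element by index.
    return (latents, timestep, text_embeddings, guidance_scale, *control)
```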


@@ -1,5 +1,4 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
@@ -19,15 +18,12 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -39,8 +35,12 @@ class Text2ImagePipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
def prepare_latents(
self,
@@ -110,8 +110,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -128,12 +130,15 @@ class Text2ImagePipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
self.load_vae()
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs
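Note on the `load_vae()` / `if self.ondemand: self.unload_vae()` pairing above: the model is built lazily on first use and, when `ondemand` is set, released again right after the call, presumably so that only one submodel is resident at a time. The following is a sketch of that pattern under stated assumptions, not the pipeline's actual code; `LazyModel`, `make_doubler` and `load_count` are hypothetical, and the `try/finally` is my addition (the diff unloads with a plain `if` after the loop).

```python
class LazyModel:
    """Hypothetical holder that builds a model on first use and can drop it again."""

    def __init__(self, factory, ondemand):
        self._factory = factory    # callable that builds the (expensive) model
        self._ondemand = ondemand  # if True, release the model after each use
        self._model = None
        self.load_count = 0        # how many times the factory has run

    def load(self):
        # Idempotent, like the `if self.vae is not None: return` guard in the diff.
        if self._model is None:
            self._model = self._factory()
            self.load_count += 1
        return self._model

    def unload(self):
        self._model = None

    def run(self, *args):
        model = self.load()
        try:
            return model(*args)
        finally:
            # Release even if the call raised, so memory is not held on error.
            if self._ondemand:
                self.unload()


def make_doubler():
    # Stand-in for an expensive model build or download.
    return lambda x: 2 * x


resident = LazyModel(make_doubler, ondemand=False)
transient = LazyModel(make_doubler, ondemand=True)
results = [resident.run(3), resident.run(4), transient.run(3), transient.run(4)]
```

The trade-off is visible in `load_count`: the resident holder builds its model once, while the on-demand holder rebuilds it on every call in exchange for a lower peak footprint.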


@@ -27,6 +27,7 @@ from apps.stable_diffusion.src.utils import (
end_profiling,
)
from PIL import Image
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def preprocess(image):
@@ -55,10 +56,6 @@ def preprocess(image):
class UpscalerPipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -80,8 +77,12 @@ class UpscalerPipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.low_res_scheduler = low_res_scheduler
def prepare_extra_step_kwargs(self, generator, eta):
@@ -163,6 +164,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
@@ -208,6 +210,8 @@ class UpscalerPipeline(StableDiffusionPipeline):
# )
step_time_sum += step_time
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -251,8 +255,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# 4. Preprocess image
image = preprocess(image).to(dtype)
@@ -299,6 +305,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -306,5 +313,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs
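Note on the step-time accounting in the denoising loop above: each step's wall time is converted to milliseconds, summed, and divided by the number of timesteps to give the "Average step time" log line. A minimal sketch of that bookkeeping follows. Two things here are assumptions rather than what the diff does: `timed_steps` is a hypothetical helper, and it uses an injectable clock defaulting to `time.perf_counter` (the diff calls `time.time()`) and divides by the number of steps that actually ran (the diff divides by `len(total_timesteps)`).

```python
import time


def timed_steps(steps, run_step, clock=time.perf_counter):
    """Run `run_step` for each step; return (results, average milliseconds per step)."""
    results = []
    total_ms = 0.0
    for step in steps:
        start = clock()
        results.append(run_step(step))
        total_ms += (clock() - start) * 1000.0  # seconds -> milliseconds
    completed = len(results)
    avg_ms = total_ms / completed if completed else 0.0
    return results, avg_ms


# Deterministic usage with a fake clock: step 1 takes 0.25 s, step 2 takes 0.5 s.
ticks = iter([0.0, 0.25, 0.25, 0.75])
results, avg_ms = timed_steps([1, 2], lambda s: s * s, clock=lambda: next(ticks))
```

Injecting the clock keeps the example testable; in real use the default clock would be left in place.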


@@ -20,7 +20,6 @@ from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
get_vae,
get_clip,
get_unet,
@@ -30,15 +29,15 @@ from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
import sys
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
class StableDiffusionPipeline:
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -50,14 +49,85 @@ class StableDiffusionPipeline:
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.vae = None
self.text_encoder = None
self.unet = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
try:
self.tokenizer = get_tokenizer()
except:
self.load_unet()
self.unload_unet()
self.tokenizer = get_tokenizer()
def load_clip(self):
if self.text_encoder is not None:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
try:
self.text_encoder = get_clip()
except:
print("download pipeline failed, falling back to import_mlir")
self.text_encoder = self.sd_model.clip()
def unload_clip(self):
del self.text_encoder
self.text_encoder = None
def load_unet(self):
if self.unet is not None:
return
if self.import_mlir or self.use_lora:
self.unet = self.sd_model.unet()
else:
try:
self.unet = get_unet()
except:
print("download pipeline failed, falling back to import_mlir")
self.unet = self.sd_model.unet()
def unload_unet(self):
del self.unet
self.unet = None
def load_vae(self):
if self.vae is not None:
return
if self.import_mlir or self.use_lora:
self.vae = self.sd_model.vae()
else:
try:
self.vae = get_vae()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae = self.sd_model.vae()
def unload_vae(self):
del self.vae
self.vae = None
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -77,12 +147,14 @@ class StableDiffusionPipeline:
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
self.load_clip()
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
@@ -111,109 +183,6 @@ class StableDiffusionPipeline:
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def produce_stencil_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
controlnet=None,
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def produce_img_latents(
self,
latents,
@@ -226,10 +195,12 @@ class StableDiffusionPipeline:
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
@@ -275,6 +246,11 @@ class StableDiffusionPipeline:
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -308,115 +284,556 @@ class StableDiffusionPipeline:
width: int,
use_base_vae: bool,
use_tuned: bool,
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
if (
not import_mlir
and not use_lora
and cls.__name__ == "StencilPipeline"
):
sys.exit("StencilPipeline not supported with SharkTank currently.")
is_inpaint = cls.__name__ in [
"InpaintPipeline",
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
if import_mlir or use_lora:
if not import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
)
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
)
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
clip, unet, vae, vae_encode = mlir_import()
return cls(
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ in ["StencilPipeline"]:
clip, unet, vae, controlnet = mlir_import()
return cls(
controlnet, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ in ["UpscalerPipeline"]:
clip, unet, vae = mlir_import()
return cls(
vae, clip, get_tokenizer(), unet, scheduler, ddpm_scheduler
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
try:
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
return cls(
get_vae_encode(),
get_vae(),
get_clip(),
get_tokenizer(),
get_unet(),
scheduler,
)
if cls.__name__ == "StencilPipeline":
import sys
sd_model = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
use_quantize=use_quantize,
)
sys.exit(
"StencilPipeline not supported with SharkTank currently."
)
if cls.__name__ in ["UpscalerPipeline"]:
return cls(
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
scheduler,
ddpm_scheduler,
sd_model,
import_mlir,
use_lora,
ondemand,
)
except:
print("download pipeline failed, falling back to import_mlir")
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
# #####################################################
# Implements text embeddings with weights from prompts
# https://huggingface.co/AlanB/lpw_stable_diffusion_mod
# #####################################################
def encode_prompts_weight(
self,
prompt,
negative_prompt,
model_max_length,
do_classifier_free_guidance=True,
max_embeddings_multiples=1,
num_images_per_prompt=1,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `1`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
SHARK: max_embeddings_multiples>1 produces a tensor shape error (defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
"""
# SHARK: Save model_max_length, load the clip and init inference time
self.model_max_length = model_max_length
self.load_clip()
clip_inf_start = time.time()
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
clip, unet, vae, vae_encode = mlir_import()
return cls(
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ == "StencilPipeline":
clip, unet, vae, controlnet = mlir_import()
return cls(
controlnet, vae, clip, get_tokenizer(), unet, scheduler
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt
if do_classifier_free_guidance
else None,
max_embeddings_multiples=max_embeddings_multiples,
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()
from typing import List, Optional, Union
import re
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a factor of 1.1 (multiplies the weight by 1/1.1)
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(
pipe: StableDiffusionPipeline, prompt: List[str], max_length: int
):
r"""
Tokenize a list of prompts and return their tokens together with a weight for each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens,
weights,
max_length,
bos,
eos,
no_boseos_middle=True,
chunk_length=77,
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
max_length
if no_boseos_middle
else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
tokens[i] = (
[bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
)
if no_boseos_middle:
weights[i] = (
[1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
)
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j
* (chunk_length - 2) : min(
len(weights[i]), (j + 1) * (chunk_length - 2)
)
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
# text_embedding = pipe.text_encoder(text_input_chunk)[0]
# SHARK: duplicate the text_input as Shark runner expects tokens and neg tokens
formatted_text_input_chunk = torch.cat(
[text_input_chunk, text_input_chunk]
)
text_embedding = pipe.text_encoder(
"forward", (formatted_text_input_chunk,)
)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
else:
# SHARK: duplicate the text_input as Shark runner expects tokens and neg tokens
# Convert the result to tensor
# text_embeddings = pipe.text_encoder(text_input)[0]
formatted_text_input = torch.cat([text_input, text_input])
text_embeddings = pipe.text_encoder(
"forward", (formatted_text_input,)
)[0]
text_embeddings = torch.from_numpy(text_embeddings)[None, :]
return text_embeddings
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
Args:
pipe (`StableDiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`):
The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
If the token sequence spans several text-encoder chunks, whether to drop the starting and
ending tokens at the chunk boundaries in the middle of the sequence.
skip_parsing (`bool`, *optional*, defaults to `False`):
Skip the parsing of brackets.
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, this is forced to True.
"""
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(
max_length, max([len(token) for token in uncond_tokens])
)
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(
uncond_tokens, dtype=torch.long, device="cpu"
)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(
prompt_weights, dtype=torch.float, device="cpu"
)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(
uncond_weights, dtype=torch.float, device="cpu"
)
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None
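Note on the weighting step at the end of `get_weighted_text_embeddings` above: each token's embedding row is multiplied by its prompt weight, and the whole tensor is then rescaled by `previous_mean / current_mean` so that its overall mean is unchanged. A worked sketch of that arithmetic on a tiny example follows, written with plain Python lists for a single prompt (the diff uses torch tensors and takes the mean over the last two axes for each batch item). The helper names `mean` and `apply_weights_preserving_mean` are hypothetical.

```python
def mean(matrix):
    """Mean over every element of a list-of-rows matrix."""
    values = [v for row in matrix for v in row]
    return sum(values) / len(values)


def apply_weights_preserving_mean(embedding, weights):
    """Scale each token row by its weight, then rescale to keep the overall mean."""
    if len(embedding) != len(weights):
        raise ValueError("need exactly one weight per token row")
    previous_mean = mean(embedding)
    weighted = [[v * w for v in row] for row, w in zip(embedding, weights)]
    current_mean = mean(weighted)
    # Like the code in the diff, this divides by the weighted mean and is
    # therefore undefined when that mean is exactly zero.
    scale = previous_mean / current_mean
    return [[v * scale for v in row] for row in weighted]


# Two tokens with two features each; the second token is emphasised 2x.
embedding = [[1.0, 3.0], [2.0, 2.0]]   # overall mean = 2.0
weights = [1.0, 2.0]
result = apply_weights_preserving_mean(embedding, weights)
```

In this example the weighted mean rises from 2.0 to 3.0, so every value is scaled by 2/3. The emphasised token keeps its relative boost over the other token, while the overall mean returns to 2.0.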


@@ -40,6 +40,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
device = args.device.split(":", 1)[0].strip()
model_input = {
"euler": {
@@ -89,19 +90,19 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
def _import(self):
scaling_model = ScalingModel()
self.scaling_model = compile_through_fx(
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
self.step_model = compile_through_fx(
self.step_model, _ = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
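Note on the naming change above: the scheduler's compiled-model names now include a device component, taken from `args.device.split(":", 1)[0].strip()`, so artifacts built for different backends no longer share a name. A small sketch of how such a name is assembled follows. The helper names and the example device strings (`"vulkan://0"`, `"cpu"`) are assumptions for illustration; the formats `args.device` actually accepts are not shown in the diff.

```python
def device_family(device):
    """Return the text before the first ':' in a device string, whitespace stripped."""
    return device.split(":", 1)[0].strip()


def scheduler_model_name(base, batch_size, height, width, device, precision):
    """Build '<base>_<batch>_<height>_<width>_<device>_<precision>'."""
    return (
        f"{base}_{batch_size}_{height}_{width}_{device_family(device)}_"
        + precision
    )


# Assumed example inputs; only the naming scheme itself comes from the diff.
name = scheduler_model_name("euler_step", 1, 512, 512, "vulkan://0", "fp16")
```

A device string without a colon passes through unchanged apart from stripping, so `device_family("cpu")` is simply `"cpu"`.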


@@ -24,7 +24,6 @@ from apps.stable_diffusion.src.utils.utils import (
get_available_devices,
get_opt_flags,
preprocessCKPT,
fetch_or_delete_vmfbs,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
@@ -32,4 +31,7 @@ from apps.stable_diffusion.src.utils.utils import (
get_extended_name,
clear_all,
save_output_img,
get_generation_text_info,
update_lora_weight,
resize_stencil,
)


@@ -1,6 +1,157 @@
{
"stabilityai/stable-diffusion-x4-upscaler": {
"unet": {
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"vae_upscaler": {
"latents" : {
"shape" : [
"1*batch_size",4,"8*height","8*width"
],
"dtype":"f32"
}
}
},
"unet": {
"stabilityai/stable-diffusion-2-1": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-2-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"runwayml/stable-diffusion-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-x4-upscaler": {
"latents": {
"shape": [
"2*batch_size",
@@ -28,141 +179,39 @@
"shape": [2],
"dtype": "i64"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"8*height","8*width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"stabilityai/stable-diffusion-2-1": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
"stencil_adaptor": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
"CompVis/stable-diffusion-v1-4": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stencil_adaptor": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
"stencil_unet": {
"stencil_unet": {
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
@@ -242,143 +291,6 @@
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"stabilityai/stable-diffusion-2-inpainting": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"runwayml/stable-diffusion-inpainting": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
}
}
}
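
The shape entries in the JSON above mix literal integers with symbolic strings such as "2*batch_size", "8*height" and "height/8". Below is a minimal sketch of how such entries could be resolved to concrete dimensions. `resolve_dim` and `resolve_shape` are hypothetical helpers written for illustration, not functions from this repository, and the grammar they accept (a bare name, `k*name`, or `name/k`) is inferred only from the entries visible in this diff.

```python
def resolve_dim(dim, values):
    """Resolve one shape entry to an int using the bindings in `values`."""
    if isinstance(dim, int):
        return dim  # literal dimension, e.g. 4 or 1024
    if "*" in dim:
        # "2*batch_size" -> 2 * values["batch_size"]
        factor, name = dim.split("*", 1)
        return int(factor) * values[name.strip()]
    if "/" in dim:
        # "height/8" -> values["height"] // 8
        name, divisor = dim.split("/", 1)
        return values[name.strip()] // int(divisor)
    return values[dim]  # bare symbol, e.g. "max_len"


def resolve_shape(shape, values):
    """Resolve a whole shape; a bare int (as used for guidance_scale) passes through."""
    if isinstance(shape, int):
        return shape
    return [resolve_dim(d, values) for d in shape]
```

For example, the "embedding" entry of a v1.4 UNet with `batch_size=1` and `max_len=77` would resolve to `[2, 77, 768]` under these assumptions.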


@@ -1,85 +1,19 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/sd_untuned",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
"stablediffusion/untuned":"gs://shark_tank/nightly"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/unet/fp32/length_64/untuned":"unet_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/vae/fp32/length_64/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"anythingv3/v1_4/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v1_4/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v1_4/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
"anythingv3/v1_4/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v1_4/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v1_4/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v1_4/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
"anythingv3/v1_4/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v1_4/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v1_4/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v1_4/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v1_4/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v1_4/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v1_4/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
"analogdiffusion/v1_4/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v1_4/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v1_4/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v1_4/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
"analogdiffusion/v1_4/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v1_4/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v1_4/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v1_4/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v1_4/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v1_4/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v1_4/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v1_4/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v1_4/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v1_4/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v1_4/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v1_4/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v1_4/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v1_4/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v1_4/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v1_4/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v1_4/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v1_4/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan"
}
]


@@ -45,12 +45,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}


@@ -70,24 +70,27 @@ def load_winograd_configs():
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
if not os.path.exists(WORKDIR):
os.mkdir(WORKDIR)
winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
def load_lower_configs():
def load_lower_configs(base_model_id=None):
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
if not base_model_id:
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
@@ -114,7 +117,14 @@ def load_lower_configs():
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
and args.width == 768
):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
@@ -212,7 +222,7 @@ def annotate_with_lower_configs(
return bytecode
def sd_model_annotation(mlir_model, model_name):
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
@@ -220,19 +230,22 @@ def sd_model_annotation(mlir_model, model_name):
winograd_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
tuned_model = mlir_model
else:
use_winograd = False
lowering_config_dir = load_lower_configs()
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
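
The hunk at `load_lower_configs` adds a special case for 768x768 runs of the v2.1 models. The sketch below restates only the branch visible in this diff as a pure function, so the resulting file names can be checked in isolation. `pick_config_name` is a hypothetical name, and the enclosing condition that is cut off above the hunk is deliberately not modelled.

```python
def pick_config_name(annotation_model, version, precision, device, spec, height, width):
    """Mirror the visible else-branch of load_lower_configs (illustrative only)."""
    if not spec or spec in ["rdna3", "sm_80"]:
        if version in ["v2_1", "v2_1base"] and height == 768 and width == 768:
            # Dedicated tuned config for the 768x768 variant.
            return f"{annotation_model}_v2_1_768_{precision}_{device}.json"
        return f"{annotation_model}_{version}_{precision}_{device}.json"
    # Any other target spec gets its own suffix.
    return f"{annotation_model}_{version}_{precision}_{device}_{spec}.json"
```

Note that in this branch the 768 config name drops the `base` suffix, so `v2_1` and `v2_1base` map to the same file.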


@@ -22,6 +22,12 @@ p = argparse.ArgumentParser(
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
)
p.add_argument(
"-p",
"--prompts",
@@ -340,6 +346,21 @@ p.add_argument(
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
)
p.add_argument(
"--ondemand",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
@@ -360,7 +381,7 @@ p.add_argument(
p.add_argument(
"--vulkan_large_heap_block_size",
default="4147483648",
default="2073741824",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
@@ -472,6 +493,13 @@ p.add_argument(
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--web_mode",
type=str,
default="app",
help="any number of: [api, app, webui]. Currently api can't be run with others.",
)
p.add_argument(
@@ -488,6 +516,12 @@ p.add_argument(
help="flag for setting server port",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for enabling rest API",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################
@@ -512,6 +546,31 @@ p.add_argument(
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file",
)
##############################################################################
### SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all",
)
args, unknown = p.parse_known_args()
if args.import_debug:
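
The new `--api` and `--web_mode` flags overlap: the web entry point later in this diff enables the REST API when either `--api` is set or `api` appears in the comma-separated `--web_mode` value. A minimal, standalone sketch of that check follows. `build_parser` and `api_enabled` are hypothetical names, only the two flags are reproduced, and `argparse.BooleanOptionalAction` requires Python 3.9 or newer.

```python
import argparse


def build_parser():
    """Parser with just the two flags relevant to API selection."""
    p = argparse.ArgumentParser()
    p.add_argument("--api", default=False, action=argparse.BooleanOptionalAction)
    p.add_argument("--web_mode", type=str, default="app")
    return p


def api_enabled(argv):
    """Return True if the REST API should be started for these arguments."""
    # parse_known_args tolerates unrelated flags, as the original parser does.
    args, _unknown = build_parser().parse_known_args(argv)
    return args.api or "api" in args.web_mode.split(",")
```

`BooleanOptionalAction` also generates a `--no-api` form, so an explicit `--no-api` keeps the flag false unless `--web_mode` asks for the API.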


@@ -126,14 +126,14 @@ def controlnet_hint_conversion(
stencil_to_model_id_map = {
"canny": "lllyasviel/sd-controlnet-canny",
"depth": "lllyasviel/sd-controlnet-depth",
"canny": "lllyasviel/control_v11p_sd15_canny",
"depth": "lllyasviel/control_v11p_sd15_depth",
"hed": "lllyasviel/sd-controlnet-hed",
"mlsd": "lllyasviel/sd-controlnet-mlsd",
"normal": "lllyasviel/sd-controlnet-normal",
"openpose": "lllyasviel/sd-controlnet-openpose",
"scribble": "lllyasviel/sd-controlnet-scribble",
"seg": "lllyasviel/sd-controlnet-seg",
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
"normal": "lllyasviel/control_v11p_sd15_normalbae",
"openpose": "lllyasviel/control_v11p_sd15_openpose",
"scribble": "lllyasviel/control_v11p_sd15_scribble",
"seg": "lllyasviel/control_v11p_sd15_seg",
}


@@ -3,12 +3,15 @@ import gc
import json
import re
from PIL import PngImagePlugin
from PIL import Image
from datetime import datetime as dt
from csv import DictWriter
from pathlib import Path
import numpy as np
from random import randint
import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
@@ -21,7 +24,7 @@ from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
download_from_original_stable_diffusion_ckpt,
)
@@ -36,6 +39,15 @@ def get_vmfb_path_name(model_name):
return vmfb_path
def _load_vmfb(shark_module, vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
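
`_load_vmfb` folds several model names onto the flag groups used by `get_opt_flags`. The sketch below isolates that mapping so it can be inspected without SHARK installed. `normalize_for_flags` is a hypothetical helper, not part of the repository.

```python
def normalize_for_flags(model, precision):
    """Map a model name and precision to the pair used for flag lookup."""
    if "base_vae" in model or "vae_encode" in model:
        model = "vae"  # VAE variants share the VAE flags
    if "stencil" in model:
        model = "unet"  # stencil models reuse the UNet flags
    if "clip" in model:
        precision = "fp32"  # CLIP is looked up as fp32
    return model, precision
```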
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
vmfb_path = get_vmfb_path_name(model_name)
@@ -78,7 +90,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
mlir_model, device=args.device, mlir_dialect="tm_tensor"
)
return _compile_module(shark_module, model_name, extra_args)
@@ -87,7 +99,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
def compile_through_fx(
model,
inputs,
model_name,
extended_model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
@@ -95,7 +107,20 @@ def compile_through_fx(
debug=False,
generate_vmfb=True,
extra_args=[],
base_model_id=None,
model_name=None,
precision=None,
return_mlir=False,
):
if not return_mlir and model_name is not None:
vmfb_path = get_vmfb_path_name(extended_model_name)
if os.path.isfile(vmfb_path):
shark_module = SharkInference(mlir_module=None, device=args.device)
return (
_load_vmfb(shark_module, vmfb_path, model_name, precision),
None,
)
from shark.parser import shark_args
if "cuda" in args.device:
@@ -110,29 +135,28 @@ def compile_through_fx(
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=model_name,
model_name=extended_model_name,
save_dir=save_dir,
)
if use_tuned:
if "vae" in model_name.split("_")[0]:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
mlir_module = sd_model_annotation(mlir_module, model_name)
if "unet" in model_name.split("_")[0]:
args.annotation_model = "unet"
mlir_module = sd_model_annotation(
mlir_module, extended_model_name, base_model_id
)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
mlir_dialect="tm_tensor",
)
if generate_vmfb:
shark_module = SharkInference(
return (
_compile_module(shark_module, extended_model_name, extra_args),
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
del mlir_module
gc.collect()
return _compile_module(shark_module, model_name, extra_args)
del mlir_module
gc.collect()
@@ -264,8 +288,9 @@ def set_init_device_flags():
if (
args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
@@ -299,6 +324,20 @@ def set_init_device_flags():
]:
args.use_tuned = False
elif (
args.height == 768
and args.width == 768
and (
base_model_id
not in [
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
]
or "rdna3" not in args.iree_vulkan_target_triple
)
):
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
else:
@@ -368,7 +407,7 @@ def get_available_devices():
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
available_devices.append("device => cpu")
return available_devices
@@ -454,7 +493,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
num_in_channels = 9 if is_inpaint else 4
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
@@ -464,44 +503,113 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
print("Loading complete")
def load_vmfb(vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module = SharkInference(mlir_module=None, device=args.device)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
# This utility returns vmfbs of Clip, Unet, Vae and Vae_encode, in case all of them
# are present; deletes them otherwise.
def fetch_or_delete_vmfbs(extended_model_name, precision="fp32"):
vmfb_path = [
get_vmfb_path_name(extended_model_name[model])
for model in extended_model_name
]
number_of_vmfbs = len(vmfb_path)
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = True
compiled_models = [None] * number_of_vmfbs
for i in range(number_of_vmfbs):
all_vmfb_present = all_vmfb_present and vmfb_present[i]
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(number_of_vmfbs):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
def processLoRA(model, use_lora, splitting_prefix):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
model_name = [model for model in extended_model_name.keys()]
for i in range(number_of_vmfbs):
compiled_models[i] = load_vmfb(
vmfb_path[i], model_name[i], precision
state_dict = torch.load(use_lora)
alpha = 0.75
visited = []
# directly update weight in model
process_unet = "te" not in splitting_prefix
for key in state_dict:
if ".alpha" in key or key in visited:
continue
curr_layer = model
if ("text" not in key and process_unet) or (
"text" in key and not process_unet
):
layer_infos = (
key.split(".")[0].split(splitting_prefix)[-1].split("_")
)
return compiled_models
else:
continue
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = (
state_dict[pair_keys[0]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
weight_down = (
state_dict[pair_keys[1]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
curr_layer.weight.data += alpha * torch.mm(
weight_up, weight_down
).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
return model
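
The core of `processLoRA` is the update `weight += alpha * (up @ down)`, applied once per `lora_up`/`lora_down` key pair. The sketch below shows that arithmetic on plain nested lists so it runs without torch. `lora_pair`, `matmul` and `merge_lora` are hypothetical helpers for illustration, and the 4-D squeeze/unsqueeze handling for convolution weights is not modelled.

```python
def lora_pair(key):
    """Return (up_key, down_key) for either half of a LoRA key pair."""
    if "lora_down" in key:
        return key.replace("lora_down", "lora_up"), key
    return key, key.replace("lora_up", "lora_down")


def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, up, down, alpha=0.75):
    """Return weight + alpha * (up @ down) without mutating the inputs."""
    delta = matmul(up, down)
    return [
        [w + alpha * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]
```

The default `alpha=0.75` matches the constant hard-coded in the diff.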
def update_lora_weight_for_unet(unet, use_lora):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
return unet
main_file_name = get_path_stem(use_lora)
if ".bin" in use_lora:
main_file_name += ".bin"
elif ".safetensors" in use_lora:
main_file_name += ".safetensors"
elif ".pt" in use_lora:
main_file_name += ".pt"
else:
sys.exit("Only .bin and .safetensors format for LoRA is supported")
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_")
def update_lora_weight(model, use_lora, model_name):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora)
try:
return processLoRA(model, use_lora, "lora_te_")
except:
return None
# `fetch_and_update_base_model_id` is a resource utility function which
@@ -557,7 +665,9 @@ def clear_all():
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
shutil.rmtree(
os.path.join(home, ".local/shark_tank"), ignore_errors=True
)
elif os.name == "posix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
@@ -629,3 +739,57 @@ def save_output_img(output_img, img_seed, extra_info={}):
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
return text_output
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms to our model constraints:
# both width and height should be in the range [128, 768] and a multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
width, height = image.size
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 128:
n_size = 128
if width == min_size:
width = n_size
height = n_size / aspect_ratio
else:
height = n_size
width = n_size * aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
min_size = min(width, height)
if min_size > 768:
n_size = 768
if width == min_size:
height = n_size
width = n_size * aspect_ratio
else:
width = n_size
height = n_size / aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
new_image = image.resize((n_width, n_height))
return new_image, n_width, n_height
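
The size arithmetic in `resize_stencil` can be followed more easily without the PIL dependency. The sketch below reproduces it on plain integers, assuming the flattened diff nests the second rounding step inside the `> 768` branch. `fit_stencil_size` is a hypothetical name, and the function returns only the target size rather than a resized image.

```python
def fit_stencil_size(width, height):
    """Return (n_width, n_height) following the resize_stencil arithmetic."""
    aspect_ratio = width / height
    # Upscale so the shorter side is at least 128.
    if min(width, height) < 128:
        if width <= height:
            width, height = 128, 128 / aspect_ratio
        else:
            width, height = 128 * aspect_ratio, 128
    width, height = int(width), int(height)
    # Downscale only when even the shorter side exceeds 768;
    # the longer side is then set to 768.
    if min(width, height) > 768:
        if width <= height:
            width, height = 768 * aspect_ratio, 768
        else:
            width, height = 768, 768 / aspect_ratio
        width, height = int(width), int(height)
    # Round both sides down to a multiple of 8.
    return (width // 8) * 8, (height // 8) * 8
```

As written, an image whose shorter side is within range but whose longer side exceeds 768 is only rounded to multiples of 8, not scaled down.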


@@ -1,204 +1,252 @@
from multiprocessing import Process, freeze_support
import os
import sys
import transformers
from apps.stable_diffusion.src import args, clear_all
import apps.stable_diffusion.web.utils.global_obj as global_obj
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
import gradio as gr
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src import args, clear_all
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
)
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create the custom model folder if it doesn't already exist
get_custom_model_path().mkdir(parents=True, exist_ok=True)
if args.clear_all:
clear_all()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
def launch_app(address):
from tkinter import Tk
import webview
tk = Tk()
# size of the window where we show our website
tk.geometry("1280x720")
webview.create_window("SHARK", address)
webview.start(private_mode=False)
if __name__ == "__main__":
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.web_mode.split(","):
from apps.stable_diffusion.web.ui import (
txt2img_api,
img2img_api,
upscaler_api,
inpaint_api,
)
from fastapi import FastAPI, APIRouter
import uvicorn
# init global sd pipeline and config
global_obj._init()
app = FastAPI()
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route(
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
# )
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
app.include_router(APIRouter())
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
sys.exit(0)
import gradio as gr
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
)
return os.path.join(base_path, relative_path)
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create the custom model folder if it doesn't already exist
dir = ["models", "vae", "lora"]
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
img2img_web,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
inpaint_web,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
outpaint_web,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
upscaler_web,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
# init global sd pipeline and config
global_obj.init()
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
img2img_web.render()
with gr.TabItem(label="Inpainting", id=2):
inpaint_web.render()
with gr.TabItem(label="Outpainting", id=3):
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="LoRA Training", id=5):
lora_train_web.render()
register_button_click(
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_gallery,
txt2img_sendto_img2img,
1,
[txt2img_gallery],
[img2img_init_image, tabs],
)
register_button_click(
txt2img_sendto_inpaint,
2,
[txt2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_outpaint,
3,
[txt2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_upscaler,
4,
[txt2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
img2img_web,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
2,
[img2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_outpaint,
3,
[img2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_upscaler,
4,
[img2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
inpaint_web,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
1,
[inpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
inpaint_sendto_outpaint,
3,
[inpaint_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
inpaint_sendto_upscaler,
4,
[inpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
outpaint_web,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
1,
[outpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
outpaint_sendto_inpaint,
2,
[outpaint_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
outpaint_sendto_upscaler,
4,
[outpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
upscaler_web,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
1,
[upscaler_gallery],
[img2img_init_image, tabs],
)
register_button_click(
upscaler_sendto_inpaint,
2,
[upscaler_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
upscaler_sendto_outpaint,
3,
[upscaler_gallery],
[outpaint_init_image, tabs],
lora_train_web,
model_web,
)
# init global sd pipeline and config
global_obj._init()
sd_web.queue()
sd_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
img2img_web.render()
with gr.TabItem(label="Inpainting", id=2):
inpaint_web.render()
with gr.TabItem(label="Outpainting", id=3):
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="Model Manager", id=5):
model_web.render()
with gr.Tabs(visible=False) as experimental_tabs:
with gr.TabItem(label="LoRA Training", id=5):
lora_train_web.render()
register_button_click(
txt2img_sendto_img2img,
1,
[txt2img_gallery],
[img2img_init_image, tabs],
)
register_button_click(
txt2img_sendto_inpaint,
2,
[txt2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_outpaint,
3,
[txt2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_upscaler,
4,
[txt2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
img2img_sendto_inpaint,
2,
[img2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_outpaint,
3,
[img2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_upscaler,
4,
[img2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
inpaint_sendto_img2img,
1,
[inpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
inpaint_sendto_outpaint,
3,
[inpaint_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
inpaint_sendto_upscaler,
4,
[inpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
outpaint_sendto_img2img,
1,
[outpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
outpaint_sendto_inpaint,
2,
[outpaint_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
outpaint_sendto_upscaler,
4,
[outpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
upscaler_sendto_img2img,
1,
[upscaler_gallery],
[img2img_init_image, tabs],
)
register_button_click(
upscaler_sendto_inpaint,
2,
[upscaler_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
upscaler_sendto_outpaint,
3,
[upscaler_gallery],
[outpaint_init_image, tabs],
)
sd_web.queue()
if "app" in args.web_mode.split(","):
t = Process(
target=launch_app, args=[f"http://localhost:{args.server_port}"]
)
t.start()
sd_web.launch(
share=args.share,
inbrowser="webui" in args.web_mode.split(","),
server_name="0.0.0.0",
server_port=args.server_port,
)

@@ -1,4 +1,6 @@
from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_inf,
txt2img_api,
txt2img_web,
txt2img_gallery,
txt2img_sendto_img2img,
@@ -7,6 +9,8 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_api,
img2img_web,
img2img_gallery,
img2img_init_image,
@@ -15,6 +19,8 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_inf,
inpaint_api,
inpaint_web,
inpaint_gallery,
inpaint_init_image,
@@ -23,6 +29,8 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_inf,
outpaint_api,
outpaint_web,
outpaint_gallery,
outpaint_init_image,
@@ -31,6 +39,8 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_inf,
upscaler_api,
upscaler_web,
upscaler_gallery,
upscaler_init_image,
@@ -39,3 +49,4 @@ from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_sendto_outpaint,
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
from apps.stable_diffusion.web.ui.model_manager import model_web

@@ -101,6 +101,9 @@ Procedure to upgrade the dark theme:
}
/* SHARK theme */
body {
background-color: var(--background-fill-primary);
}
/* display in full width for desktop devices */
@media (min-width: 1536px)
@@ -166,18 +169,49 @@ footer {
border-radius: 0 !important;
}
/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */
#gallery .thumbnail-item.thumbnail-lg {
aspect-ratio: unset;
max-height: calc(55vh - (2 * var(--spacing-lg)));
}
@media (min-width: 1921px) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
#gallery .grid-wrap, #gallery .preview{
min-height: calc(768px + 4px + var(--size-14));
max-height: calc(768px + 4px + var(--size-14));
}
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
#gallery .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
}
/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
object-fit: scale-down;
}
/* Navbar images in cover mode*/
#gallery .preview .thumbnail-item img {
object-fit: cover;
}
/* Limit the stable diffusion text output height */
#std_output textarea {
max-height: 215px;
}
/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */
#gallery .wrap.default {
pointer-events: none;
}
/* Import Png info box */
#txt2img_prompt_image .fixed-height {
height: var(--size-32);
#txt2img_prompt_image {
height: var(--size-32) !important;
}
/* Hide "remove buttons" from ui dropdowns */
#custom_model .token-remove.remove-all,
#lora_weights .token-remove.remove-all,
#scheduler .token-remove.remove-all,
#device .token-remove.remove-all,
#stencil_model .token-remove.remove-all {

@@ -1,17 +1,340 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import img2img_inf
from apps.stable_diffusion.src import args
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
StencilPipeline,
resize_stencil,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
import numpy as np
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def img2img_inf(
prompt: str,
negative_prompt: str,
image_dict,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.strength = strength
args.scheduler = scheduler
args.img_path = "not none"
args.ondemand = ondemand
if image_dict is None:
return None, "An Initial Image is required"
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
else:
image = image_dict["image"].convert("RGB")
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
args.precision = precision
dtype = torch.float32 if precision == "fp32" else torch.half
new_config_obj = Config(
"img2img",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=use_stencil,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
args.use_tuned = False
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
else:
global_obj.set_sd_obj(
Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(args.scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"STRENGTH": strength}
text_output = ""
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
batch_size,
height,
width,
steps,
strength,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
# yield generated_imgs, text_output
return generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Img2Img Rest API.
def img2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["image"])
res = img2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["denoising_strength"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
use_stencil=InputData["use_stencil"]
if "use_stencil" in InputData.keys()
else "None",
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Image-to-Image") as img2img_web:
@@ -41,11 +364,19 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
@@ -62,7 +393,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
img2img_init_image = gr.Image(
label="Input Image", type="pil"
label="Input Image",
source="upload",
tool="sketch",
type="pil",
).style(height=300)
with gr.Accordion(label="Stencil Options", open=False):
@@ -73,13 +407,64 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
value="None",
choices=["None", "canny", "openpose", "scribble"],
)
def show_canvas(choice):
if choice == "scribble":
return (
gr.Slider.update(visible=True),
gr.Slider.update(visible=True),
gr.Button.update(visible=True),
)
else:
return (
gr.Slider.update(visible=False),
gr.Slider.update(visible=False),
gr.Button.update(visible=False),
)
def create_canvas(w, h):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
create_button = gr.Button(
label="Start",
value="Open drawing canvas!",
visible=False,
)
create_button.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[img2img_init_image],
)
use_stencil.change(
fn=show_canvas,
inputs=use_stencil,
outputs=[canvas_width, canvas_height, create_button],
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -93,8 +478,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="PNDM",
choices=scheduler_list,
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -143,6 +528,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
step=0.01,
label="Denoising Strength",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
@@ -199,19 +589,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
).style(columns=[2], object_fit="contain")
output_dir = (
args.output_dir if args.output_dir else Path.cwd()
)
output_dir = Path(output_dir, "generated_imgs")
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at {output_dir}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
img2img_sendto_outpaint = gr.Button(
@@ -238,6 +626,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
scheduler,
custom_model,
hf_model_id,
custom_vae,
precision,
device,
max_length,
@@ -246,6 +635,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[img2img_gallery, std_output],
show_progress=args.progress_bar,
@@ -255,5 +645,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
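
Reviewer note: `img2img_api` in this file reads its inputs from a plain dict, requiring `prompt`, `negative_prompt`, `image`, `height`, `width`, `steps`, `denoising_strength`, `cfg_scale` and `seed`, and treating `hf_model_id` and `use_stencil` as optional. A minimal sketch of a client-side request body under those assumptions follows. The helper name `build_img2img_payload` and every default value are hypothetical, chosen only for illustration; they are not part of this change set.

```python
import json


def build_img2img_payload(
    prompt,
    image_b64,
    negative_prompt="",
    height=512,             # illustrative default, not taken from the diff
    width=512,              # illustrative default, not taken from the diff
    steps=50,               # illustrative default, not taken from the diff
    denoising_strength=0.8, # illustrative default, not taken from the diff
    cfg_scale=7.5,          # illustrative default, not taken from the diff
    seed=-1,
    hf_model_id=None,
    use_stencil=None,
):
    """Build a dict with the keys that img2img_api looks up in InputData."""
    payload = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": image_b64,  # base64 string, optionally a data:image/... URI
        "height": height,
        "width": width,
        "steps": steps,
        "denoising_strength": denoising_strength,
        "cfg_scale": cfg_scale,
        "seed": seed,
    }
    # Optional keys are only sent when set, so the server-side fallbacks apply.
    if hf_model_id is not None:
        payload["hf_model_id"] = hf_model_id
    if use_stencil is not None:
        payload["use_stencil"] = use_stencil
    return payload


# Serialize to the JSON body a client would POST to /sdapi/v1/img2img.
body = json.dumps(build_img2img_payload("a red fox", "AAAA")).encode("utf-8")
```

Leaving the optional keys out lets the handler fall back to `stabilityai/stable-diffusion-2-1-base` and to `"None"` for the stencil, which is what the diff does when those keys are absent.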

@@ -1,17 +1,294 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import inpaint_inf
from apps.stable_diffusion.src import args
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def inpaint_inf(
prompt: str,
negative_prompt: str,
image_dict,
height: int,
width: int,
inpaint_full_res: bool,
inpaint_full_res_padding: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.mask_path = "not none"
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"inpaint",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
mask_image,
batch_size,
height,
width,
inpaint_full_res,
inpaint_full_res_padding,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Inpaint Rest API.
def inpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["image"])
mask = decode_base64_to_image(InputData["mask"])
res = inpaint_inf(
InputData["prompt"],
InputData["negative_prompt"],
{"image": init_image, "mask": mask},
InputData["height"],
InputData["width"],
InputData["is_full_res"],
InputData["full_res_padding"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Inpainting") as inpaint_web:
@@ -41,11 +318,19 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
@@ -71,10 +356,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -88,8 +373,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="PNDM",
choices=scheduler_list,
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -145,6 +430,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
@@ -201,19 +491,17 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
).style(columns=[2], object_fit="contain")
output_dir = (
args.output_dir if args.output_dir else Path.cwd()
)
output_dir = Path(output_dir, "generated_imgs")
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at {output_dir}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
inpaint_sendto_outpaint = gr.Button(
@@ -241,6 +529,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
scheduler,
custom_model,
hf_model_id,
custom_vae,
precision,
device,
max_length,
@@ -248,6 +537,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[inpaint_gallery, std_output],
show_progress=args.progress_bar,
@@ -257,5 +547,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)


@@ -9,7 +9,8 @@ from apps.stable_diffusion.web.ui.utils import (
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_txt2img,
get_custom_vae_or_lora_weights,
scheduler_list,
predefined_models,
)
@@ -48,6 +49,20 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
lines=3,
)
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights to initialize weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID to initialize weights",
lines=3,
)
with gr.Group(elem_id="image_dir_box_outer"):
training_images_dir = gr.Textbox(
label="ImageDirectory",
@@ -68,7 +83,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=scheduler_list_txt2img,
choices=scheduler_list,
)
with gr.Row():
height = gr.Slider(
@@ -195,6 +210,9 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
max_length,
training_images_dir,
output_loc,
get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
),
],
outputs=[std_output],
show_progress=args.progress_bar,


@@ -0,0 +1,136 @@
import os
import gradio as gr
import requests
from io import BytesIO
from PIL import Image
from shark.iree_utils._common import run_cmd
def get_hf_list(limit=20):
path = "https://huggingface.co/api/models"
params = {
"search": "stable-diffusion",
"sort": "downloads",
"direction": "-1",
"limit": limit,
"full": "true",
}
response = requests.get(path, params=params)
return response.json()
def get_civit_list(num_of_models=50):
path = f"https://civitai.com/api/v1/models?limit={num_of_models}&types=Checkpoint"
headers = {"Content-Type": "application/json"}
raw_json = requests.get(path, headers=headers).json()
models = list(raw_json.items())[0][1]
safe_models = [
safe_model for safe_model in models if not safe_model["nsfw"]
]
version_id = 0 # Currently just using the first version.
safe_models = [
safe_model
for safe_model in safe_models
if safe_model["modelVersions"][version_id]["files"][0]["metadata"][
"format"
]
== "SafeTensor"
]
first_version_models = []
for model_iter in safe_models:
# The modelVersion would only keep the version name.
if (
model_iter["modelVersions"][version_id]["images"][0]["nsfw"]
!= "None"
):
continue
model_iter["modelVersions"][version_id]["modelName"] = model_iter[
"name"
]
model_iter["modelVersions"][version_id]["rating"] = model_iter[
"stats"
]["rating"]
model_iter["modelVersions"][version_id]["favoriteCount"] = model_iter[
"stats"
]["favoriteCount"]
model_iter["modelVersions"][version_id]["downloadCount"] = model_iter[
"stats"
]["downloadCount"]
first_version_models.append(model_iter["modelVersions"][version_id])
return first_version_models
def get_image_from_model(model_json):
model_id = model_json["modelId"]
image = None
for img_info in model_json["images"]:
if img_info["nsfw"] == "None":
image_url = img_info["url"]
response = requests.get(image_url)
image = BytesIO(response.content)
break
return image
hf_model_list = get_hf_list()
civit_model_list = get_civit_list()
with gr.Blocks() as model_web:
model_source = gr.Radio(
choices=["Hugging Face", "Civitai"],
type="index",
value="Hugging Face",
label="Model Source",
)
with gr.Column(visible=True) as hf_block:
for model in hf_model_list:
with gr.Row():
model_url = gr.Textbox(
label="Model ID:",
value=model["modelId"],
lines=1,
interactive=False,
)
model_info = gr.Textbox(
value=f'Download Count: {model["downloads"]}{os.linesep}Favorite Count: {model["likes"]}',
lines=2,
show_label=False,
interactive=False,
)
with gr.Column(visible=False) as civit_block:
for model in civit_model_list:
with gr.Row():
image = get_image_from_model(model)
if image is None:
continue
model_img = Image.open(image)
gr.Image(
value=model_img,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=300, height=300)
with gr.Column():
gr.Textbox(
label=f'{model["modelName"]}',
value=f'Rating: {model["rating"]}{os.linesep}Favorite Count: {model["favoriteCount"]}{os.linesep}Download Count: {model["downloadCount"]}{os.linesep}File Format: {model["files"][0]["metadata"]["format"]}',
lines=4,
)
gr.Textbox(
label="Download URL:",
value=f'{model["files"][0]["downloadUrl"]}',
lines=1,
)
def update_model_list(model_source):
if model_source:
return gr.update(visible=False), gr.update(visible=True)
else:
return gr.update(visible=True), gr.update(visible=False)
model_source.change(
fn=update_model_list,
inputs=model_source,
outputs=[hf_block, civit_block],
)
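
The filtering in `get_civit_list` can be exercised offline. This is a minimal sketch, assuming the response shape implied by the code above (`nsfw`, `modelVersions`, `files[0]["metadata"]["format"]`, `images[0]["nsfw"]`). The helper name `first_safe_versions` and the sample payload are hypothetical, and no network call is made.

```python
def first_safe_versions(models, version_id=0):
    """Return the chosen version of each non-NSFW, SafeTensor model."""
    kept = []
    for model in models:
        if model.get("nsfw"):
            continue
        versions = model.get("modelVersions") or []
        if len(versions) <= version_id:
            continue  # no such version; skip instead of raising IndexError
        version = versions[version_id]
        files = version.get("files") or []
        images = version.get("images") or []
        if not files or not images:
            continue
        if files[0].get("metadata", {}).get("format") != "SafeTensor":
            continue
        if images[0].get("nsfw") != "None":
            continue
        # Copy the version and attach the parent model's name
        # (the code in the diff mutates the version dict in place instead).
        kept.append({**version, "modelName": model["name"]})
    return kept


# Hypothetical payload for illustration only.
sample = [
    {"name": "a", "nsfw": False, "modelVersions": [
        {"files": [{"metadata": {"format": "SafeTensor"}}],
         "images": [{"nsfw": "None"}]}]},
    {"name": "b", "nsfw": True, "modelVersions": []},
    {"name": "c", "nsfw": False, "modelVersions": [
        {"files": [{"metadata": {"format": "PickleTensor"}}],
         "images": [{"nsfw": "None"}]}]},
]

names = [v["modelName"] for v in first_safe_versions(sample)]
```

Unlike the code in the diff, the sketch skips entries with missing versions, files, or images rather than indexing into them directly.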


@@ -1,17 +1,305 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import outpaint_inf
from apps.stable_diffusion.src import args
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def outpaint_inf(
prompt: str,
negative_prompt: str,
init_image,
pixels: int,
mask_blur: int,
directions: list,
noise_q: float,
color_variation: float,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"outpaint",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
left = True if "left" in directions else False
right = True if "right" in directions else False
top = True if "up" in directions else False
bottom = True if "down" in directions else False
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
init_image,
pixels,
mask_blur,
left,
right,
top,
bottom,
noise_q,
color_variation,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Outpaint Rest API.
def outpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = outpaint_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["pixels"],
InputData["mask_blur"],
InputData["directions"],
InputData["noise_q"],
InputData["color_variation"],
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Outpainting") as outpaint_web:
@@ -41,11 +329,19 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
@@ -68,10 +364,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -85,8 +381,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="PNDM",
choices=scheduler_list,
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -164,6 +460,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
steps = gr.Slider(
1, 100, value=20, step=1, label="Steps"
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
@@ -220,19 +521,17 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
).style(columns=[2], object_fit="contain")
output_dir = (
args.output_dir if args.output_dir else Path.cwd()
)
output_dir = Path(output_dir, "generated_imgs")
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at {output_dir}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -261,6 +560,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
scheduler,
custom_model,
hf_model_id,
custom_vae,
precision,
device,
max_length,
@@ -268,6 +568,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[outpaint_gallery, std_output],
show_progress=args.progress_bar,
@@ -277,5 +578,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
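
The data-URL handling in `decode_base64_to_image` can be tried without PIL. This is a minimal sketch using only the standard library; `strip_data_url_prefix` is a hypothetical helper name that mirrors the `split` chain in that function, and the payload is arbitrary bytes rather than a real image.

```python
import base64


def strip_data_url_prefix(encoding: str) -> str:
    """Drop a leading 'data:image/...;base64,' header if present."""
    if encoding.startswith("data:image/"):
        # Same split chain as decode_base64_to_image.
        encoding = encoding.split(";", 1)[1].split(",", 1)[1]
    return encoding


payload = base64.b64encode(b"not really a png").decode("ascii")
data_url = "data:image/png;base64," + payload

# Both the bare payload and the data URL decode to the same bytes.
decoded_plain = base64.b64decode(strip_data_url_prefix(payload))
decoded_url = base64.b64decode(strip_data_url_prefix(data_url))
```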


@@ -1,17 +1,268 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import txt2img_inf
from apps.stable_diffusion.src import prompt_examples, args
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_txt2img,
scheduler_list,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.utils.png_metadata import import_png_metadata
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
prompt_examples,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Text2Img Rest API.
def txt2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
res = txt2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row(elem_id="ui_title"):
@@ -42,11 +293,20 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"]
+ get_custom_model_files("vae"),
)
with gr.Column(scale=1, min_width=170):
png_info_img = gr.Image(
label="Import PNG info",
@@ -72,10 +332,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -90,7 +350,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=scheduler_list_txt2img,
choices=scheduler_list,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -105,10 +365,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
384,
768,
value=args.height,
step=8,
label="Height",
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
384,
768,
value=args.width,
step=8,
label="Width",
)
precision = gr.Radio(
label="Precision",
@@ -139,6 +407,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
@@ -195,19 +468,17 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
).style(columns=[2], object_fit="contain")
output_dir = (
args.output_dir if args.output_dir else Path.cwd()
)
output_dir = Path(output_dir, "generated_imgs")
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at {output_dir}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -233,6 +504,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
scheduler,
custom_model,
hf_model_id,
custom_vae,
precision,
device,
max_length,
@@ -240,6 +512,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[txt2img_gallery, std_output],
show_progress=args.progress_bar,
@@ -249,17 +522,24 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
)
from apps.stable_diffusion.web.utils.png_metadata import (
import_png_metadata,
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
png_info_img.change(
fn=import_png_metadata,
inputs=[
png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
custom_model,
hf_model_id,
],
outputs=[
png_info_img,


@@ -1,17 +1,297 @@
from pathlib import Path
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import upscaler_inf
from apps.stable_diffusion.src import args
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def upscaler_inf(
prompt: str,
negative_prompt: str,
init_image,
height: int,
width: int,
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB").resize((height, width))
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
args.height = 128
args.width = 128
new_config_obj = Config(
"upscaler",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
args.height,
args.width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_size = batch_size
args.max_length = max_length
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
"DDPM"
)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"NOISE LEVEL": noise_level}
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
low_res_img = image
high_res_img = Image.new("RGB", (height * 4, width * 4))
for i in range(0, width, 128):
for j in range(0, height, 128):
box = (j, i, j + 128, i + 128)
upscaled_image = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
low_res_img.crop(box),
batch_size,
args.height,
args.width,
steps,
noise_level,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
save_output_img(high_res_img, img_seed, extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Upscaler Rest API.
def upscaler_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = upscaler_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["noise_level"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Upscaler") as upscaler_web:
@@ -41,11 +321,19 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
@@ -65,13 +353,28 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standalone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="DDIM",
choices=scheduler_list,
choices=scheduler_list_cpu_only,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -128,6 +431,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
step=1,
label="Noise Level",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
@@ -184,19 +492,17 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
).style(columns=[2], object_fit="contain")
output_dir = (
args.output_dir if args.output_dir else Path.cwd()
)
output_dir = Path(output_dir, "generated_imgs")
std_output = gr.Textbox(
value="Nothing to show.",
value=f"Images will be saved at {output_dir}",
lines=1,
elem_id="std_output",
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -221,11 +527,15 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
scheduler,
custom_model,
hf_model_id,
custom_vae,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[upscaler_gallery, std_output],
show_progress=args.progress_bar,

View File

@@ -5,6 +5,10 @@ import glob
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
@dataclass
@@ -12,6 +16,7 @@ class Config:
mode: str
model_id: str
ckpt_loc: str
custom_vae: str
precision: str
batch_size: int
max_length: int
@@ -20,6 +25,7 @@ class Config:
device: str
use_lora: str
use_stencil: str
ondemand: str
custom_model_filetypes = (
@@ -27,13 +33,7 @@ custom_model_filetypes = (
"*.safetensors",
) # the tuple of file types
scheduler_list = [
"DDIM",
"PNDM",
"DPMSolverMultistep",
"EulerAncestralDiscrete",
]
scheduler_list_txt2img = [
scheduler_list_cpu_only = [
"DDIM",
"PNDM",
"LMSDiscrete",
@@ -41,6 +41,8 @@ scheduler_list_txt2img = [
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
]
scheduler_list = scheduler_list_cpu_only + [
"SharkEulerDiscrete",
]
@@ -70,24 +72,63 @@ def resource_path(relative_path):
return os.path.join(base_path, relative_path)
def get_custom_model_path():
return Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd(), "models")
def get_custom_model_path(model="models"):
# If `--ckpt_dir` is provided, it overrides the hierarchical folder
# structure in WebUI:
# model
# |___lora
# |___vae
if args.ckpt_dir:
return Path(args.ckpt_dir)
match model:
case "models":
return Path(Path.cwd(), "models")
case "vae":
return Path(Path.cwd(), "models/vae")
case "lora":
return Path(Path.cwd(), "models/lora")
case _:
return ""
def get_custom_model_pathfile(custom_model_name):
return os.path.join(get_custom_model_path(), custom_model_name)
def get_custom_model_pathfile(custom_model_name, model="models"):
return os.path.join(get_custom_model_path(model), custom_model_name)
def get_custom_model_files():
def get_custom_model_files(model="models"):
ckpt_files = []
for extn in custom_model_filetypes:
file_types = custom_model_filetypes
if model == "lora":
file_types = custom_model_filetypes + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_custom_model_path(), extn))
for x in glob.glob(
os.path.join(get_custom_model_path(model), extn)
)
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_custom_vae_or_lora_weights(weights, hf_id, model):
use_weight = ""
if weights == "None" and not hf_id:
use_weight = ""
elif not hf_id:
use_weight = get_custom_model_pathfile(weights, model)
else:
use_weight = hf_id
return use_weight
def cancel_sd():
# Wrap in try/except, since gc can delete global_obj.sd_obj while switching models
try:
global_obj.set_sd_status(SD_STATE_CANCEL)
except Exception:
pass
nodlogo_loc = resource_path("logos/nod-logo.png")
available_devices = get_available_devices()

View File

@@ -8,39 +8,68 @@ Also we could avoid memory leak when switching models by clearing the cache.
"""
def init():
global sd_obj
global config_obj
sd_obj = None
config_obj = None
def _init():
global _sd_obj
global _config_obj
global _schedulers
_sd_obj = None
_config_obj = None
_schedulers = None
def set_sd_obj(value):
global sd_obj
sd_obj = value
global _sd_obj
_sd_obj = value
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
global _sd_obj
_sd_obj.status = value
def set_cfg_obj(value):
global config_obj
config_obj = value
global _config_obj
_config_obj = value
def set_schedulers(value):
global sd_obj
sd_obj.scheduler = value
global _schedulers
_schedulers = value
def get_sd_obj():
return sd_obj
global _sd_obj
return _sd_obj
def get_sd_status():
global _sd_obj
return _sd_obj.status
def get_cfg_obj():
return config_obj
global _config_obj
return _config_obj
def get_scheduler(key):
global _schedulers
return _schedulers[key]
def clear_cache():
global sd_obj
global config_obj
del sd_obj
del config_obj
global _sd_obj
global _config_obj
global _schedulers
del _sd_obj
del _config_obj
del _schedulers
gc.collect()
_sd_obj = None
_config_obj = None
_schedulers = None

View File

@@ -1,21 +1,8 @@
import re
from pathlib import Path
from apps.stable_diffusion.web.ui.txt2img_ui import (
png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
custom_model,
hf_model_id,
)
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
scheduler_list_txt2img,
scheduler_list,
predefined_models,
)
@@ -75,7 +62,19 @@ def parse_generation_parameters(x: str):
return res
def import_png_metadata(pil_data):
def import_png_metadata(
pil_data,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
hf_model_id,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
@@ -110,39 +109,44 @@ def import_png_metadata(pil_data):
% metadata["Model"]
)
outputs = {
png_info_img: None,
negative_prompt: metadata["Negative prompt"],
steps: int(metadata["Steps"]),
guidance_scale: float(metadata["CFG scale"]),
seed: int(metadata["Seed"]),
width: float(metadata["Size-1"]),
height: float(metadata["Size-2"]),
}
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
cfg_scale = float(metadata["CFG scale"])
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
outputs[custom_model] = png_custom_model
outputs[hf_model_id] = ""
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
outputs[custom_model] = "None"
outputs[hf_model_id] = png_hf_model_id
custom_model = "None"
hf_model_id = png_hf_model_id
if "Prompt" in metadata:
outputs[prompt] = metadata["Prompt"]
prompt = metadata["Prompt"]
if "Sampler" in metadata:
if metadata["Sampler"] in scheduler_list_txt2img:
outputs[scheduler] = metadata["Sampler"]
if metadata["Sampler"] in scheduler_list:
sampler = metadata["Sampler"]
else:
print(
"Import PNG info: Unable to find a scheduler for %s"
% metadata["Sampler"]
)
return outputs
except Exception as ex:
if pil_data and pil_data.info.get("parameters"):
print("import_png_metadata failed with %s" % ex)
pass
return {
png_info_img: None,
}
return (
None,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
hf_model_id,
)

View File

@@ -2,4 +2,4 @@
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python generate_sharktank.py
python tank/generate_sharktank.py

View File

@@ -87,11 +87,22 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]
counter = 0
for import_opt in import_options:
for model_name in hf_model_names:
if model_name in to_skip:
continue
for use_tune in tuned_options:
if (
model_name == "stabilityai/stable-diffusion-2-1"
and use_tune == tuned_options[0]
):
continue
elif (
model_name == "stabilityai/stable-diffusion-2-1-base"
and use_tune == tuned_options[1]
):
continue
command = (
[
executable, # executable is the python from the venv used to run this
@@ -174,9 +185,21 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
else:
print(command)
print("failed to generate image for this configuration")
if "2_1_base" in model_name:
print("failed a known successful model.")
exit(1)
with open(dumpfile_name, "r+") as f:
output = f.readlines()
print("\n".join(output))
exit(1)
if os.name == "nt":
counter += 1
if counter % 2 == 0:
extra_flags.append(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
else:
if counter != 1:
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)

View File

@@ -2,9 +2,11 @@ def pytest_addoption(parser):
# Attaches SHARK command-line arguments to the pytest machinery.
parser.addoption(
"--benchmark",
action="store_true",
default="False",
help="Pass option to benchmark and write results.csv",
action="store",
type=str,
default=None,
choices=("baseline", "native", "all"),
help="Benchmarks specified engine(s) and writes bench_results.csv.",
)
parser.addoption(
"--onnx_bench",
@@ -40,7 +42,13 @@ def pytest_addoption(parser):
"--update_tank",
action="store_true",
default="False",
help="Update local shark tank with latest artifacts.",
help="Update local shark tank with latest artifacts if model artifact hash mismatched.",
)
parser.addoption(
"--force_update_tank",
action="store_true",
default="False",
help="Force-update local shark tank with artifacts from specified shark_tank URL (defaults to nightly).",
)
parser.addoption(
"--ci_sha",
@@ -51,15 +59,21 @@ def pytest_addoption(parser):
parser.addoption(
"--local_tank_cache",
action="store",
default="",
default=None,
help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
)
parser.addoption(
"--tank_url",
type=str,
default="gs://shark_tank/latest",
default="gs://shark_tank/nightly",
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/nightly",
)
parser.addoption(
"--tank_prefix",
type=str,
default=None,
help="Prefix to gs://shark_tank/ model directories from which to download SHARK tank artifacts. Default is nightly.",
)
parser.addoption(
"--benchmark_dispatches",
default=None,
@@ -70,3 +84,9 @@ def pytest_addoption(parser):
default="./temp_dispatch_benchmarks",
help="Directory in which dispatch benchmarks are saved.",
)
parser.addoption(
"--batchsize",
default=1,
type=int,
help="Batch size for the tested model.",
)
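
A note on the `store_true` options above: several of them pass the string `"False"` as `default`. The sketch below uses plain `argparse` as a stand-in for pytest's option parser (an assumption; pytest's behavior was not verified here) to show that such a string default is truthy when the flag is absent. The helper name `build_parser` is hypothetical.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Mirrors the diff: a string default on a store_true flag.
    parser.add_argument("--update_tank", action="store_true", default="False")
    # A boolean default, for comparison.
    parser.add_argument(
        "--force_update_tank", action="store_true", default=False
    )
    return parser


absent = build_parser().parse_args([])
present = build_parser().parse_args(["--update_tank", "--force_update_tank"])

# With the flag absent, the string default is returned unchanged,
# so a bare truthiness check would treat the flag as set.
print(repr(absent.update_tank), bool(absent.update_tank))
print(repr(absent.force_update_tank))
```

Code that compares the option with `== True` is unaffected by this; only bare truthiness checks would be.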

75
docs/shark_sd_blender.md Normal file
View File

@@ -0,0 +1,75 @@
# Overview
This document is intended to provide a starting point for using SHARK stable diffusion with Blender.
We currently make use of the [AI-Render Plugin](https://github.com/benrugg/AI-Render) to integrate with Blender.
## Setup SHARK and prerequisites:
* Download the latest SHARK SD webui .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow instructions on the [README](https://github.com/nod-ai/SHARK#readme)
* Place the .exe in the directory where you want SHARK to install, then run it from a terminal or PowerShell with the `--api` flag:
```
## Run the .exe in API mode:
.\shark_sd_<date>_<ver>.exe --api
## For example:
.\shark_sd_20230411_671.exe --api --server_port=8082
## From the base directory of a source clone of SHARK:
./setup_venv.ps1
python apps\stable_diffusion\web\index.py --api
```
Your local SD server should start and look something like this:
![image](https://user-images.githubusercontent.com/87458719/231369758-e2c3c45a-eccc-4fe5-a788-4a3bf1ace1d1.png)
* Note: When running in API mode with `--api`, the .exe will not function as a web UI, so the address in the terminal output is only useful for API requests.
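
As a rough illustration of an API request, the sketch below builds an Automatic1111-style img2img request body, which is the request style the AI-Render plugin is configured to use later in this document. The `/sdapi/v1/img2img` route, the host, the exact field set, and the helper names `build_img2img_payload` and `post_json` are assumptions for illustration and are not confirmed by this document; use the address printed in your terminal.

```python
import base64
import json
import urllib.request


def build_img2img_payload(image_bytes: bytes, prompt: str) -> dict:
    """Build an Automatic1111-style img2img request body (assumed fields)."""
    return {
        "prompt": prompt,
        "negative_prompt": "",
        # Images are sent as base64-encoded strings.
        "init_images": [base64.b64encode(image_bytes).decode("ascii")],
        "steps": 20,
        "cfg_scale": 7.5,
        "seed": -1,
        "width": 512,
        "height": 512,
        "denoising_strength": 0.5,
    }


def post_json(url: str, payload: dict) -> dict:
    """POST the payload as JSON and decode the JSON response."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())


payload = build_img2img_payload(b"placeholder image bytes", "a bowl of tangerines")
# Hypothetical call; requires a running server started with --api:
# result = post_json("http://127.0.0.1:8082/sdapi/v1/img2img", payload)
```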
### Install AI Render
- Get AI Render on [Blender Market](https://blendermarket.com/products/ai-render) or [Gumroad](https://airender.gumroad.com/l/ai-render)
- Open Blender, then go to Edit > Preferences > Add-ons > Install and then find the zip file
- We will be using the Automatic1111 SD backend for the AI-Render plugin. Follow the instructions [here](https://github.com/benrugg/AI-Render/wiki/Local-Installation) to set up the local SD backend.
Your AI-Render preferences should be configured as shown; the highlighted part should match your terminal output:
![image](https://user-images.githubusercontent.com/87458719/231390322-59a54a09-520a-4a08-b658-6e37bd63e932.png)
The [AI-Render README](https://github.com/benrugg/AI-Render/blob/main/README.md) has more details on installation and usage, as well as video tutorials.
## Using AI-Render + SHARK in your Blender project
- In the Render Properties tab, in the AI-Render dropdown, enable AI-Render.
![image](https://user-images.githubusercontent.com/87458719/231392843-9bd51744-3ce2-464e-843a-0c4d4c96df0c.png)
- Select an image size (it's usually better to upscale later than go high on the img2img resolution here.)
![image](https://user-images.githubusercontent.com/87458719/231394288-0c4ab8c5-dc30-4dbe-8bc1-7520ded5efe8.png)
- From here, you can enter a prompt and configure img2img Stable Diffusion parameters, and AI-Render will run SHARK SD img2img on the rendered scene.
- AI-Render has useful presets for aesthetic styles, so you should be able to keep your subject prompt simple and focus on creating a decent Blender scene to start from.
![image](https://user-images.githubusercontent.com/87458719/231440729-2fe69586-41cb-4274-9ce7-f6c08def600b.png)
## Examples:
Scene (Input image):
![blender-sample-2](https://user-images.githubusercontent.com/87458719/231450408-0e680086-3e52-4962-a5c1-c703a94d1583.png)
Prompt:
"A bowl of tangerines in front of rocks, masterpiece, oil on canvas, by Georgia O'Keefe, trending on artstation, landscape painting by Caspar David Friedrich"
Negative Prompt (default):
"ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
Example output:
![blender-sample-2_out](https://user-images.githubusercontent.com/87458719/231451145-a0b56897-a7d0-4add-bbed-7e8af21a65df.png)

View File

@@ -6,36 +6,16 @@ from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Diffusers 0.13.1 fails with transformers __init.py errros in BLIP. So remove it for now until we fork it
pix2pix_init = Path(get_python_lib() + "/diffusers/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(get_python_lib() + "/diffusers/pipelines/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(
get_python_lib() + "/diffusers/pipelines/stable_diffusion/__init__.py"
# Temporary workaround for transformers/__init__.py.
path_to_tranformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
for line in fileinput.input(pix2pix_init, inplace=True):
if "StableDiffusionPix2PixZeroPipeline" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
if path_to_tranformers_hook.is_file():
pass
else:
with open(path_to_tranformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")

View File

@@ -1,3 +1,3 @@
[pytest]
addopts = --verbose -p no:warnings
addopts = --verbose -s -p no:warnings
norecursedirs = inference tank/tflite examples benchmarks shark

View File

@@ -2,8 +2,8 @@
--pre
numpy>1.22.4
torchvision
pytorch-triton
torchvision==0.16.0.dev20230322
tabulate
tqdm
@@ -15,8 +15,8 @@ iree-tools-tf
# TensorFlow and JAX.
gin-config
tf-nightly
keras>=2.10
tensorflow>2.11
keras
#tf-models-nightly
#tensorflow-text-nightly
transformers
@@ -33,6 +33,7 @@ lit
pyyaml
python-dateutil
sacremoses
sentencepiece
# web dependencies.
gradio

View File

@@ -16,7 +16,7 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
scipy
ftfy
gradio
@@ -26,6 +26,8 @@ safetensors
opencv-python
scikit-image
pytorch_lightning # for runwayml models
tk
pywebview
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile

View File

@@ -45,7 +45,7 @@ if ($arguments -eq "--force"){
Remove-Item .\shark.venv -Force -Recurse
if (Test-Path .\shark.venv\) {
Write-Host 'could not remove .\shark.venv - please try running ".\setup_venv.ps1 --force" again!'
break
exit 1
}
}
}
@@ -78,12 +78,12 @@ if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is u
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
{
Write-Host "Please install Python 3.11 and try again"
break
exit 34
}
Write-Host "Installing Build Dependencies"
# make sure we really use 3.11 from list, even if it's not the default.
if (!($PyVer.length -ne 0)) {py -3.11 -m venv .\shark.venv\}
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
python -m pip install --upgrade pip

View File

@@ -129,11 +129,11 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu117."
echo "Successfully Installed torch + cu118."
else
echo "Could not install torch + cu117." >&2
echo "Could not install torch + cu118." >&2
fi
fi

View File

@@ -35,8 +35,9 @@ def run_cmd(cmd, debug=False):
stderr=subprocess.PIPE,
check=True,
)
result_str = result.stdout.decode()
return result_str
stdout = result.stdout.decode()
stderr = result.stderr.decode()
return stdout, stderr
except subprocess.CalledProcessError as e:
print(e.output)
sys.exit(f"Exiting program due to error running {cmd}")
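
Because `run_cmd` now returns a `(stdout, stderr)` pair, every caller has to unpack two values. A minimal sketch of that pattern follows; `run_cmd_sketch` is a hypothetical stand-in that takes an argument list rather than the command string the real helper receives.

```python
import subprocess
import sys


def run_cmd_sketch(cmd: list[str]) -> tuple[str, str]:
    """Run a command and return its decoded stdout and stderr."""
    try:
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(e.output)
        sys.exit(f"Exiting program due to error running {cmd}")
    return result.stdout.decode(), result.stderr.decode()


# A child Python process that writes one line to each stream.
child = [
    sys.executable,
    "-c",
    "import sys; print('out'); print('err', file=sys.stderr)",
]
stdout, stderr = run_cmd_sketch(child)

# Callers that only need stdout discard the second element.
stdout_only, _ = run_cmd_sketch(child)
```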

View File

@@ -90,6 +90,7 @@ def build_benchmark_args(
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
# if time_extractor:
# benchmark_cl.append(time_extractor)
benchmark_cl.append("--print_statistics=true")
return benchmark_cl
@@ -129,7 +130,8 @@ def build_benchmark_args_non_tensor_input(
def run_benchmark_module(benchmark_cl):
"""
Run benchmark command, extract result and return iteration/seconds.
Run benchmark command, extract result and return iteration/seconds, host
peak memory, and device peak memory.
# TODO: Add an example of the benchmark command.
Input: benchmark command.
@@ -138,15 +140,22 @@ def run_benchmark_module(benchmark_cl):
assert os.path.exists(
benchmark_path
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
bench_result = run_cmd(" ".join(benchmark_cl))
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
try:
regex_split = re.compile(r"(\d+[.]*\d*)( *)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
match = regex_split.search(bench_stdout)
time_ms = float(match.group(1))
unit = match.group(3)
except AttributeError:
regex_split = re.compile(r"(\d+[.]*\d*)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
match = regex_split.search(bench_stdout)
time_ms = float(match.group(1))
unit = match.group(2)
return 1.0 / (time * 0.001)
iter_per_second = 1.0 / (time_ms * 0.001)
# Extract peak memory.
host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
host_peak_b = int(host_regex.search(bench_stderr).group(1))
device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")
device_peak_b = int(device_regex.search(bench_stderr).group(1))
return iter_per_second, host_peak_b, device_peak_b
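
The peak-memory extraction can be tried in isolation. The sample text below is made up to match the two regular expressions in this change; it is not captured benchmark output, and `parse_peak_memory` is a hypothetical helper. Unlike the code above, the sketch returns `None` instead of raising when a statistics line is missing.

```python
import re
from typing import Optional, Tuple

HOST_REGEX = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
DEVICE_REGEX = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")


def parse_peak_memory(stderr_text: str) -> Tuple[Optional[int], Optional[int]]:
    """Return (host_peak_bytes, device_peak_bytes), or None where not found."""
    host = HOST_REGEX.search(stderr_text)
    device = DEVICE_REGEX.search(stderr_text)
    host_peak = int(host.group(1)) if host else None
    device_peak = int(device.group(1)) if device else None
    return host_peak, device_peak


# Invented sample text shaped to match the patterns above.
sample = (
    "  HOST_LOCAL: 1024B peak / 512B allocated\n"
    "DEVICE_LOCAL:  2048B peak / 0B allocated\n"
)
host_peak_b, device_peak_b = parse_peak_memory(sample)
```

Returning `None` pairs naturally with a formatter that writes an empty field when a value is missing.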

View File

@@ -52,7 +52,7 @@ def get_iree_device_args(device, extra_args=[]):
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg"]:
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
return ["--iree-llvmcpu-target-cpu-features=host"]
elif frontend in ["tensorflow", "tf", "mhlo"]:
return [
@@ -70,6 +70,7 @@ def get_iree_common_args():
return [
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
]
@@ -188,21 +189,23 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
benchmark_bash.write(" ".join(benchmark_cl))
benchmark_bash.close()
benchmark_data = run_benchmark_module(benchmark_cl)
iter_per_second, _, _ = run_benchmark_module(
benchmark_cl
)
benchmark_file = open(
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
)
benchmark_file.write(f"DISPATCH: {d_}\n")
benchmark_file.write(str(benchmark_data) + "\n")
benchmark_file.write(str(iter_per_second) + "\n")
benchmark_file.write(
"SHARK BENCHMARK RESULT: "
+ str(1 / (benchmark_data * 0.001))
+ str(1 / (iter_per_second * 0.001))
+ "\n"
)
benchmark_file.close()
benchmark_runtimes[d_] = 1 / (benchmark_data * 0.001)
benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001)
elif ".mlir" in f_ and "benchmark" not in f_:
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")

View File

@@ -30,11 +30,10 @@ def get_iree_gpu_args():
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
) and (shark_args.enable_tf32 == True):
return [
"--iree-hal-cuda-disable-loop-nounroll-wa",
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
]
else:
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
return []
# Get the default gpu args given the architecture.

View File

@@ -131,7 +131,9 @@ def get_vendor(triple):
return "ARM"
if arch == "m1":
return "Apple"
if arch in ["turing", "ampere"]:
if arch in ["arc", "UHD"]:
return "Intel"
if arch in ["turing", "ampere", "pascal"]:
return "NVIDIA"
if arch == "ardeno":
return "Qualcomm"
@@ -149,7 +151,7 @@ def get_device_type(triple):
return "Unknown"
if arch == "cpu":
return "CPU"
if arch in ["turing", "ampere"]:
if arch in ["turing", "ampere", "arc", "pascal"]:
return "DiscreteGPU"
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]:
if product == "ivega10":
@@ -343,6 +345,37 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "arc":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
cap["subgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = False
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = False
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "cpu":
if product == "swiftshader":
cap["maxComputeSharedMemorySize"] = 16384
@@ -356,6 +389,39 @@ def get_vulkan_target_capabilities(triple):
"ShuffleRelative",
]
elif arch in ["pascal"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1536
cap["maxComputeWorkGroupSize"] = [1536, 1024, 64]
cap["subgroupSize"] = 32
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = True
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = True
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch in ["ampere", "turing"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1024

View File

@@ -22,7 +22,8 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_vulkan_device_name():
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
@@ -106,8 +107,13 @@ def get_vulkan_target_triple(device_name):
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
triple = f"rdna3-w7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna2-unknown-{system_os}"
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"
else:
triple = None
return triple
@@ -139,7 +145,7 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
def get_iree_vulkan_args(extra_args=[]):
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
vulkan_triple_flag = None

View File

@@ -14,8 +14,10 @@
import argparse
import os
import subprocess
parser = argparse.ArgumentParser(description="SHARK runner.")
parser.add_argument(
"--device",
type=str,
@@ -54,7 +56,7 @@ parser.add_argument(
)
parser.add_argument(
"--shark_prefix",
default="latest",
default=None,
help="gs://shark_tank/<this_flag>/model_directories",
)
parser.add_argument(

View File

@@ -21,9 +21,17 @@ from shark.iree_utils.benchmark_utils import (
from shark.parser import shark_args
from datetime import datetime
import time
from typing import Optional
import csv
import os
TF_CPU_DEVICE = "/CPU:0"
TF_GPU_DEVICE = "/GPU:0"
def _bytes_to_mb_str(bytes_: Optional[int]) -> str:
return "" if bytes_ is None else f"{bytes_ / 1e6:.6f}"
class OnnxFusionOptions(object):
def __init__(self):
@@ -70,6 +78,7 @@ class SharkBenchmarkRunner(SharkRunner):
self.vmfb_file = None
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.import_args = {}
SharkRunner.__init__(
self,
mlir_module,
@@ -104,7 +113,6 @@ class SharkBenchmarkRunner(SharkRunner):
def benchmark_torch(self, modelname):
import torch
import torch._dynamo as dynamo
from tank.model_utils import get_torch_model
if self.device == "cuda":
@@ -116,31 +124,54 @@ class SharkBenchmarkRunner(SharkRunner):
torch_device = torch.device(
"cuda:0" if self.device == "cuda" else "cpu"
)
HFmodel, input = get_torch_model(modelname)[:2]
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
frontend_model = HFmodel.model
frontend_model.to(torch_device)
input.to(torch_device)
# frontend_model = torch.compile(frontend_model, mode="max-autotune", backend="inductor")
# TODO: re-enable as soon as pytorch CUDA context issues are resolved
try:
frontend_model = torch.compile(
frontend_model, mode="max-autotune", backend="inductor"
)
except RuntimeError:
frontend_model = HFmodel.model
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(input)
if self.device == "cuda":
torch.cuda.reset_peak_memory_stats()
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(input)
if i == shark_args.num_iterations - 1:
end = time.time()
break
end = time.time()
if self.device == "cuda":
stats = torch.cuda.memory_stats()
device_peak_b = stats["allocated_bytes.all.peak"]
frontend_model.to(torch.device("cpu"))
input.to(torch.device("cpu"))
torch.cuda.empty_cache()
else:
device_peak_b = None
print(
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
if self.device == "cuda":
# Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
torch_device = torch.device("cpu")
return [
f"{shark_args.num_iterations/(end-begin)}",
f"{((end-begin)/shark_args.num_iterations)*1000}",
"", # host_peak_b (CPU usage) is not reported by PyTorch.
_bytes_to_mb_str(device_peak_b),
]
def benchmark_tf(self, modelname):
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
@@ -155,38 +186,55 @@ class SharkBenchmarkRunner(SharkRunner):
from tank.model_utils_tf import get_tf_model
# tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
tf_device = "/CPU:0"
# tf_device = TF_GPU_DEVICE if self.device == "cuda" else TF_CPU_DEVICE
tf_device = TF_CPU_DEVICE
with tf.device(tf_device):
(
model,
input,
) = get_tf_model(
modelname
modelname, self.import_args
)[:2]
frontend_model = model
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(*input)
if tf_device == TF_GPU_DEVICE:
tf.config.experimental.reset_memory_stats(tf_device)
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(*input)
if i == shark_args.num_iterations - 1:
end = time.time()
break
end = time.time()
if tf_device == TF_GPU_DEVICE:
memory_info = tf.config.experimental.get_memory_info(tf_device)
device_peak_b = memory_info["peak"]
else:
# tf.config.experimental does not currently support measuring
# CPU memory usage.
device_peak_b = None
print(
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
return [
f"{shark_args.num_iterations/(end-begin)}",
f"{((end-begin)/shark_args.num_iterations)*1000}",
"", # host_peak_b (CPU usage) is not reported by TensorFlow.
_bytes_to_mb_str(device_peak_b),
]
def benchmark_c(self):
result = run_benchmark_module(self.benchmark_cl)
print(f"Shark-IREE-C benchmark:{result} iter/second")
return [f"{result}", f"{1000/result}"]
iter_per_second, host_peak_b, device_peak_b = run_benchmark_module(
self.benchmark_cl
)
print(f"Shark-IREE-C benchmark:{iter_per_second} iter/second")
return [
f"{iter_per_second}",
f"{1000/iter_per_second}",
_bytes_to_mb_str(host_peak_b),
_bytes_to_mb_str(device_peak_b),
]
def benchmark_python(self, inputs):
input_list = [x for x in inputs]
@@ -196,8 +244,7 @@ class SharkBenchmarkRunner(SharkRunner):
begin = time.time()
for i in range(shark_args.num_iterations):
out = self.run("forward", input_list)
if i == shark_args.num_iterations - 1:
end = time.time()
end = time.time()
print(
f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
@@ -306,11 +353,21 @@ for currently supported models. Exiting benchmark ONNX."
return comp_str
def benchmark_all_csv(
self, inputs: tuple, modelname, dynamic, device_str, frontend
self,
inputs: tuple,
modelname,
dynamic,
device_str,
frontend,
import_args,
mode="native",
):
self.setup_cl(inputs)
self.import_args = import_args
self.mode = mode
field_names = [
"model",
"batch_size",
"engine",
"dialect",
"device",
@@ -324,8 +381,19 @@ for currently supported models. Exiting benchmark ONNX."
"tags",
"notes",
"datetime",
"host_memory_mb",
"device_memory_mb",
"measured_host_memory_mb",
"measured_device_memory_mb",
]
engines = ["frontend", "shark_python", "shark_iree_c"]
# "frontend" must be the first element.
if self.mode == "native":
engines = ["shark_python", "shark_iree_c"]
if self.mode == "baseline":
engines = ["frontend"]
if self.mode == "all":
engines = ["frontend", "shark_python", "shark_iree_c"]
if shark_args.onnx_bench == True:
engines.append("onnxruntime")
@@ -336,75 +404,78 @@ for currently supported models. Exiting benchmark ONNX."
with open("bench_results.csv", mode="a", newline="") as f:
writer = csv.DictWriter(f, fieldnames=field_names)
bench_result = {}
bench_result["model"] = modelname
bench_info = {}
bench_info["model"] = modelname
bench_info["batch_size"] = str(import_args["batch_size"])
bench_info["dialect"] = self.mlir_dialect
bench_info["iterations"] = shark_args.num_iterations
if dynamic == True:
bench_result["shape_type"] = "dynamic"
bench_info["shape_type"] = "dynamic"
else:
bench_result["shape_type"] = "static"
bench_result["device"] = device_str
bench_info["shape_type"] = "static"
bench_info["device"] = device_str
if "fp16" in modelname:
bench_result["data_type"] = "float16"
bench_info["data_type"] = "float16"
else:
bench_result["data_type"] = inputs[0].dtype
bench_info["data_type"] = inputs[0].dtype
for e in engines:
(
bench_result["param_count"],
bench_result["tags"],
bench_result["notes"],
) = ["", "", ""]
engine_result = {}
self.frontend_result = None
if e == "frontend":
bench_result["engine"] = frontend
engine_result["engine"] = frontend
if check_requirements(frontend):
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
engine_result["host_memory_mb"],
engine_result["device_memory_mb"],
) = self.benchmark_frontend(modelname)
self.frontend_result = bench_result["ms/iter"]
bench_result["vs. PyTorch/TF"] = "baseline"
self.frontend_result = engine_result["ms/iter"]
engine_result["vs. PyTorch/TF"] = "baseline"
(
bench_result["param_count"],
bench_result["tags"],
bench_result["notes"],
engine_result["param_count"],
engine_result["tags"],
engine_result["notes"],
) = self.get_metadata(modelname)
else:
self.frontend_result = None
continue
elif e == "shark_python":
bench_result["engine"] = "shark_python"
engine_result["engine"] = "shark_python"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
) = self.benchmark_python(inputs)
bench_result[
engine_result[
"vs. PyTorch/TF"
] = self.compare_bench_results(
self.frontend_result, bench_result["ms/iter"]
self.frontend_result, engine_result["ms/iter"]
)
elif e == "shark_iree_c":
bench_result["engine"] = "shark_iree_c"
engine_result["engine"] = "shark_iree_c"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
engine_result["host_memory_mb"],
engine_result["device_memory_mb"],
) = self.benchmark_c()
bench_result[
engine_result[
"vs. PyTorch/TF"
] = self.compare_bench_results(
self.frontend_result, bench_result["ms/iter"]
self.frontend_result, engine_result["ms/iter"]
)
elif e == "onnxruntime":
bench_result["engine"] = "onnxruntime"
engine_result["engine"] = "onnxruntime"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
) = self.benchmark_onnx(modelname, inputs)
bench_result["dialect"] = self.mlir_dialect
bench_result["iterations"] = shark_args.num_iterations
bench_result["datetime"] = str(datetime.now())
writer.writerow(bench_result)
engine_result["datetime"] = str(datetime.now())
writer.writerow(bench_info | engine_result)
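
The row-building pattern in this file (shared `bench_info` merged with a per-engine `engine_result`, memory reported in MB or left blank) can be sketched in isolation as follows. This is a minimal illustration, not the runner itself: `FIELD_NAMES` is a shortened, hypothetical column list, `render_rows` is a made-up helper name, and the numbers are invented sample values. The dict union operator `|` requires Python 3.9 or newer.

```python
import csv
import io
from typing import Optional


def bytes_to_mb_str(bytes_: Optional[int]) -> str:
    # Mirrors _bytes_to_mb_str from the diff: blank when nothing was measured.
    return "" if bytes_ is None else f"{bytes_ / 1e6:.6f}"


# Hypothetical, shortened column list; the real field_names has more entries.
FIELD_NAMES = ["model", "engine", "ms/iter", "device_memory_mb"]


def render_rows(bench_info: dict, engine_results: list) -> str:
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=FIELD_NAMES, lineterminator="\n")
    writer.writeheader()
    for engine_result in engine_results:
        # Dict union: keys from engine_result win if both dicts define them.
        # Columns missing from the merged dict are written as empty strings.
        writer.writerow(bench_info | engine_result)
    return buf.getvalue()


rows = render_rows(
    {"model": "resnet50"},
    [
        {
            "engine": "torch",
            "ms/iter": "12.5",
            "device_memory_mb": bytes_to_mb_str(2_500_000),
        },
        # No memory figure for this engine, so the column stays blank.
        {"engine": "shark_python", "ms/iter": "8.0"},
    ],
)
print(rows)
```

Building a fresh `engine_result` per engine avoids carrying stale values from one engine's row into the next, which a single reused dict could do.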

View File

@@ -127,33 +127,105 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
and os.path.isfile(os.path.join(model_dir, "hash.npy"))
):
print(f"""Using cached models from {WORKDIR}...""")
print(
f"""Model artifacts for {model_name} found at {WORKDIR}..."""
)
return True
return False
def _internet_connected():
import requests as req
try:
req.get("http://1.1.1.1")
return True
except:
return False
def get_git_revision_short_hash() -> str:
import subprocess
if shark_args.shark_prefix is not None:
prefix_kw = shark_args.shark_prefix
else:
import json
dir_path = os.path.dirname(os.path.realpath(__file__))
src = os.path.join(dir_path, "..", "tank_version.json")
with open(src, "r") as f:
data = json.loads(f.read())
prefix_kw = data["version"]
print(f"Checking for updates from gs://shark_tank/{prefix_kw}")
return prefix_kw
def get_sharktank_prefix():
tank_prefix = ""
if not _internet_connected():
print(
"No internet connection. Using the model already present in the tank."
)
tank_prefix = "none"
else:
desired_prefix = get_git_revision_short_hash()
storage_client_a = storage.Client.create_anonymous_client()
base_bucket_name = "shark_tank"
base_bucket = storage_client_a.bucket(base_bucket_name)
dir_blobs = base_bucket.list_blobs(prefix=f"{desired_prefix}")
for blob in dir_blobs:
dir_blob_name = blob.name.split("/")
if desired_prefix in dir_blob_name[0]:
tank_prefix = dir_blob_name[0]
break
else:
continue
if tank_prefix == "":
print(
f"shark_tank bucket not found matching ({desired_prefix}). Defaulting to nightly."
)
tank_prefix = "nightly"
return tank_prefix
# Downloads the torch model from gs://shark_tank dir.
def download_model(
model_name,
dynamic=False,
tank_url="gs://shark_tank/latest",
tank_url=None,
frontend=None,
tuned=None,
import_args=None,
):
model_name = model_name.replace("/", "_")
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_" + frontend
shark_args.shark_prefix = get_sharktank_prefix()
if import_args["batch_size"] and import_args["batch_size"] != 1:
model_dir_name = (
model_name
+ "_"
+ frontend
+ "_BS"
+ str(import_args["batch_size"])
)
else:
model_dir_name = model_name + "_" + frontend
model_dir = os.path.join(WORKDIR, model_dir_name)
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
if not tank_url:
tank_url = "gs://shark_tank/" + shark_args.shark_prefix
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
if not check_dir_exists(
model_dir_name, frontend=frontend, dynamic=dyn_str
):
print(
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
f"Downloading artifacts for model {model_name} from: {full_gs_url}"
)
download_public_file(full_gs_url, model_dir)
elif shark_args.force_update_tank == True:
print(
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
@@ -179,6 +251,7 @@ def download_model(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
except FileNotFoundError:
print(f"Model artifact hash not found at {model_dir}.")
upstream_hash = None
if local_hash != upstream_hash and shark_args.update_tank == True:
print(f"Updating artifacts for model {model_name}...")
@@ -186,17 +259,28 @@ def download_model(
elif local_hash != upstream_hash:
print(
"Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
"Hash does not match upstream in gs://shark_tank/. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
)
else:
print(
"Local and upstream hashes match. Using cached model artifacts."
)
model_dir = os.path.join(WORKDIR, model_dir_name)
tuned_str = "" if tuned is None else "_" + tuned
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
filename = os.path.join(model_dir, model_name + suffix)
if not os.path.exists(filename):
from tank.generate_sharktank import gen_shark_files
print(
"The model data was not found. Trying to generate artifacts locally."
)
gen_shark_files(model_name, frontend, WORKDIR, import_args)
assert os.path.exists(filename), f"MLIR not found at {filename}"
with open(filename, mode="rb") as f:
mlir_file = f.read()
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
@@ -204,13 +288,3 @@ def download_model(
inputs_tuple = tuple([inputs[key] for key in inputs])
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
return mlir_file, function_name, inputs_tuple, golden_out_tuple
def _internet_connected():
import requests as req
try:
req.get("http://1.1.1.1")
return True
except:
return False
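
The directory naming used by `download_model` (model name, frontend, optional `_BS<n>` suffix, appended to a tank prefix) can be sketched as a pair of pure functions. This is a sketch under assumptions: `tank_dir_name` and `tank_url` are hypothetical helper names, and normalizing `batch_size` with `int()` and tolerating `import_args=None` are choices made here, not behavior confirmed by the diff.

```python
from typing import Optional


def tank_dir_name(
    model_name: str, frontend: str, import_args: Optional[dict] = None
) -> str:
    # Tolerate a missing import_args dict (assumption, see lead-in).
    batch_size = (import_args or {}).get("batch_size")
    base = model_name.replace("/", "_") + "_" + frontend
    # int() makes "1" and 1 behave the same, so a string batch size of "1"
    # does not produce a "_BS1" directory.
    if batch_size and int(batch_size) != 1:
        return f"{base}_BS{batch_size}"
    return base


def tank_url(prefix: str, dir_name: str, base: str = "gs://shark_tank") -> str:
    # Same join style as the diff: strip a trailing slash, then append parts.
    return base.rstrip("/") + "/" + prefix + "/" + dir_name


print(tank_url("nightly", tank_dir_name("resnet50", "tf", {"batch_size": 4})))
```

Keeping the naming in one place means the downloader and the artifact generator cannot disagree about where a given batch size lives.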

View File

@@ -9,8 +9,8 @@ import hashlib
def create_hash(file_name):
with open(file_name, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(2**20):
file_hash = hashlib.blake2b(digest_size=64)
while chunk := f.read(2**10):
file_hash.update(chunk)
return file_hash.hexdigest()
@@ -81,7 +81,7 @@ class SharkImporter:
# NOTE: The default function for torch is "forward" and tf-lite is "main".
def _torch_mlir(self, is_dynamic, tracing_required):
def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
from shark.torch_mlir_utils import get_torch_mlir_module
return get_torch_mlir_module(
@@ -90,6 +90,7 @@ class SharkImporter:
is_dynamic,
tracing_required,
self.return_str,
mlir_type,
)
def _tf_mlir(self, func_name, save_dir="."):
@@ -120,6 +121,7 @@ class SharkImporter:
tracing_required=False,
func_name="forward",
save_dir="./shark_tmp/",
mlir_type="linalg",
):
if self.frontend in ["torch", "pytorch"]:
if self.inputs == None:
@@ -127,7 +129,10 @@ class SharkImporter:
"Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
)
sys.exit(1)
return self._torch_mlir(is_dynamic, tracing_required), func_name
return (
self._torch_mlir(is_dynamic, tracing_required, mlir_type),
func_name,
)
if self.frontend in ["tf", "tensorflow"]:
return self._tf_mlir(func_name, save_dir), func_name
if self.frontend in ["tflite", "tf-lite"]:
@@ -165,8 +170,17 @@ class SharkImporter:
if self.frontend == "torch":
with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
mlir_file.write(mlir_data)
mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
hash_gen_attempts = 2
for i in range(hash_gen_attempts):
try:
mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
except FileNotFoundError as err:
if i < hash_gen_attempts - 1:
continue
else:
raise err
np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
return
def import_debug(
@@ -297,6 +311,7 @@ def transform_fx(fx_g):
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.zeros,
]:
node.kwargs = kwargs_dict
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
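
The chunked BLAKE2b hashing and the bounded retry around it, both touched in this file, can be sketched on their own. This is an illustrative version, assuming the retry is meant to re-raise once the last attempt fails; `create_hash_with_retry` is a hypothetical name and the temporary file stands in for a generated `.mlir` file.

```python
import hashlib
import os
import tempfile


def create_hash(file_name: str, chunk_size: int = 2**10) -> str:
    # Hash the file in fixed-size chunks so large files are never fully in memory.
    file_hash = hashlib.blake2b(digest_size=64)
    with open(file_name, "rb") as f:
        while chunk := f.read(chunk_size):
            file_hash.update(chunk)
    return file_hash.hexdigest()


def create_hash_with_retry(file_name: str, attempts: int = 2) -> str:
    for i in range(attempts):
        try:
            # Return on the first success instead of hashing again.
            return create_hash(file_name)
        except FileNotFoundError:
            # Re-raise only after the final attempt has failed.
            if i == attempts - 1:
                raise
    raise ValueError("attempts must be at least 1")


payload = b"module {}" * 1000
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.mlir")
    with open(path, "wb") as f:
        f.write(payload)
    chunked = create_hash_with_retry(path)

# The chunk size only affects memory use, not the resulting digest.
one_shot = hashlib.blake2b(payload, digest_size=64).hexdigest()
print(chunked == one_shot)
```

A 64-byte digest is BLAKE2b's maximum and default size, so passing `digest_size=64` makes the choice explicit without changing the hash value.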

View File

@@ -19,6 +19,12 @@ import tempfile
from shark.parser import shark_args
import io
mlir_type_mapping_dict = {
"linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
"mhlo": torch_mlir.OutputType.STABLEHLO,
"tosa": torch_mlir.OutputType.TOSA,
}
def get_module_name_for_asm_dump(module):
"""Gets a name suitable for an assembly dump.
@@ -57,6 +63,7 @@ def get_torch_mlir_module(
dynamic: bool,
jit_trace: bool,
return_str: bool = False,
mlir_type: str = "linalg",
):
"""Get the MLIR's linalg-on-tensors module from the torchscript module."""
ignore_traced_shapes = False
@@ -70,10 +77,11 @@ def get_torch_mlir_module(
mlir_module = torch_mlir.compile(
module,
input,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
output_type=mlir_type_mapping_dict[mlir_type],
use_tracing=jit_trace,
ignore_traced_shapes=ignore_traced_shapes,
)
if return_str:
return mlir_module.operation.get_asm()
bytecode_stream = io.BytesIO()
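
The string-to-output-type lookup added in this file can be sketched with an explicit error for unknown keys. This is a sketch only: the `OutputType` enum below is a stand-in so the example runs without `torch_mlir` installed, and `resolve_output_type` is a hypothetical helper, not part of the diff.

```python
from enum import Enum


class OutputType(Enum):
    # Stand-in for torch_mlir.OutputType (assumption, see lead-in).
    LINALG_ON_TENSORS = "linalg-on-tensors"
    STABLEHLO = "stablehlo"
    TOSA = "tosa"


# Same keys as mlir_type_mapping_dict in the diff; note "mhlo" maps to STABLEHLO.
MLIR_TYPE_MAPPING = {
    "linalg": OutputType.LINALG_ON_TENSORS,
    "mhlo": OutputType.STABLEHLO,
    "tosa": OutputType.TOSA,
}


def resolve_output_type(mlir_type: str = "linalg") -> OutputType:
    try:
        return MLIR_TYPE_MAPPING[mlir_type]
    except KeyError:
        allowed = ", ".join(sorted(MLIR_TYPE_MAPPING))
        # A ValueError listing the valid keys is easier to act on than a bare KeyError.
        raise ValueError(
            f"unsupported mlir_type {mlir_type!r}; expected one of: {allowed}"
        ) from None


print(resolve_output_type("mhlo"))
```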

View File

@@ -22,7 +22,7 @@ bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-large-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344",""
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos"
@@ -35,3 +35,13 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"",""
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"",""
1 resnet50 mhlo tf 1e-2 1e-3 default nhcw-nhwc False False False macos

View File

@@ -70,7 +70,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -146,7 +146,6 @@ if __name__ == "__main__":
backend_config = "cuda"
args = [
"--iree-cuda-llvm-target-arch=sm_80",
"--iree-hal-cuda-disable-loop-nounroll-wa",
"--iree-enable-fusion-with-reduction-ops",
]

View File

@@ -91,7 +91,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -86,7 +86,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -26,16 +26,17 @@ from apps.stable_diffusion.src.utils.stable_args import (
def create_hash(file_name):
with open(file_name, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(2**20):
file_hash = hashlib.blake2b(digest_size=64)
while chunk := f.read(2**10):
file_hash.update(chunk)
return file_hash.hexdigest()
def save_torch_model(torch_model_list):
def save_torch_model(torch_model_list, local_tank_cache, import_args):
from tank.model_utils import (
get_hf_model,
get_hf_seq2seq_model,
get_vision_model,
get_hf_img_cls_model,
get_fp16_model,
@@ -58,8 +59,7 @@ def save_torch_model(torch_model_list):
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.use_tuned = False
args.local_tank_cache = WORKDIR
args.local_tank_cache = local_tank_cache
precision_values = ["fp16"]
seq_lengths = [64, 77]
@@ -74,24 +74,41 @@ def save_torch_model(torch_model_list):
width=512,
height=512,
use_base_vae=False,
custom_vae="",
debug=True,
sharktank_dir=WORKDIR,
sharktank_dir=local_tank_cache,
generate_vmfb=False,
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
model, input, _ = get_vision_model(
torch_model_name, import_args
)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name)
model, input, _ = get_hf_model(torch_model_name, import_args)
elif model_type == "hf_seq2seq":
model, input, _ = get_hf_seq2seq_model(
torch_model_name, import_args
)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
model, input, _ = get_hf_img_cls_model(
torch_model_name, import_args
)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name)
model, input, _ = get_fp16_model(torch_model_name, import_args)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
)
if import_args["batch_size"] != 1:
torch_model_dir = os.path.join(
local_tank_cache,
str(torch_model_name)
+ "_torch"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
torch_model_dir = os.path.join(
local_tank_cache, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
@@ -115,13 +132,18 @@ def save_torch_model(torch_model_list):
)
def save_tf_model(tf_model_list):
def save_tf_model(tf_model_list, local_tank_cache, import_args):
from tank.model_utils_tf import (
get_causal_image_model,
get_masked_lm_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
get_tfhf_seq2seq_model,
)
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
@@ -145,16 +167,38 @@ def save_tf_model(tf_model_list):
input = None
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_causal_lm_model(tf_model_name)
if model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
if model_type == "keras":
model, input, _ = get_keras_model(tf_model_name)
if model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name)
model, input, _ = get_masked_lm_model(
tf_model_name, import_args
)
elif model_type == "img":
model, input, _ = get_causal_image_model(
tf_model_name, import_args
)
elif model_type == "keras":
model, input, _ = get_keras_model(tf_model_name, import_args)
elif model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name, import_args)
elif model_type == "tfhf_seq2seq":
model, input, _ = get_tfhf_seq2seq_model(
tf_model_name, import_args
)
elif model_type == "hf_causallm":
model, input, _ = get_causal_lm_model(
tf_model_name, import_args
)
tf_model_name = tf_model_name.replace("/", "_")
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
if import_args["batch_size"] != 1:
tf_model_dir = os.path.join(
local_tank_cache,
str(tf_model_name)
+ "_tf"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
tf_model_dir = os.path.join(
local_tank_cache, str(tf_model_name) + "_tf"
)
os.makedirs(tf_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
@@ -166,13 +210,9 @@ def save_tf_model(tf_model_list):
dir=tf_model_dir,
model_name=tf_model_name,
)
mlir_hash = create_hash(
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
)
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
def save_tflite_model(tflite_model_list):
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
from shark.tflite_utils import TFLitePreprocessor
with open(tflite_model_list) as csvfile:
@@ -184,18 +224,18 @@ def save_tflite_model(tflite_model_list):
print("tflite_model_name", tflite_model_name)
print("tflite_model_link", tflite_model_link)
tflite_model_name_dir = os.path.join(
WORKDIR, str(tflite_model_name) + "_tflite"
local_tank_cache, str(tflite_model_name) + "_tflite"
)
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
# Preprocess to get SharkImporter input args
# Preprocess to get SharkImporter input import_args
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()
# Use SharkImporter to get SharkInference input args
# Use SharkImporter to get SharkInference input import_args
my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
@@ -219,6 +259,71 @@ def save_tflite_model(tflite_model_list):
)
def check_requirements(frontend):
import importlib
has_pkgs = False
if frontend == "torch":
tv_spec = importlib.util.find_spec("torchvision")
has_pkgs = tv_spec is not None
elif frontend in ["tensorflow", "tf"]:
tf_spec = importlib.util.find_spec("tensorflow")
has_pkgs = tf_spec is not None
return has_pkgs
class NoImportException(Exception):
"Raised when requirements are not met for OTF model artifact generation."
pass
def gen_shark_files(modelname, frontend, tank_dir, importer_args):
# If a model's artifacts are requested by shark_downloader but they don't exist in the cloud, we call this function to generate the artifacts on-the-fly.
# TODO: Add TFLite support.
import tempfile
import_args = importer_args
if check_requirements(frontend):
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tf_model_list.csv"
)
custom_model_csv = tempfile.NamedTemporaryFile(
dir=os.path.dirname(__file__),
delete=True,
)
# Create a temporary .csv with only the desired entry.
if frontend == "tf":
with open(tf_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_tf_model(custom_model_csv.name, tank_dir, import_args)
elif frontend == "torch":
with open(torch_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_torch_model(custom_model_csv.name, tank_dir, import_args)
else:
raise NoImportException
# Validates whether the file is present or not.
def is_valid_file(arg):
if not os.path.exists(arg):
@@ -228,7 +333,7 @@ def is_valid_file(arg):
if __name__ == "__main__":
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
# Note, all of these flags are overridden by the import of import_args from stable_args.py, flags are duplicated temporarily to preserve functionality
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--torch_model_csv",
@@ -256,23 +361,26 @@ if __name__ == "__main__":
# )
# parser.add_argument("--upload", type=bool, default=False)
# old_args = parser.parse_args()
# old_import_args = parser.parse_args()
import_args = {
"batch_size": 1,
}
print(import_args)
home = str(Path.home())
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_shark_tank")
torch_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tf_model_list.csv"
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
)
save_torch_model(
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv")
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
WORKDIR,
import_args,
)
save_torch_model(torch_model_csv)
save_tf_model(tf_model_csv)
save_tflite_model(tflite_model_csv)
save_torch_model(torch_model_csv, WORKDIR, import_args)
save_tf_model(tf_model_csv, WORKDIR, import_args)
save_tflite_model(tflite_model_csv, WORKDIR, import_args)
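
The on-the-fly path in `gen_shark_files` selects one model's row from a model list and writes a single-entry CSV for the `save_*_model` functions. A minimal sketch of that step follows, assuming the caller wants an explicit `None` when the model is not listed; `find_model_row` and `write_single_model_csv` are hypothetical names, and the sample rows are invented.

```python
import csv
import os
import tempfile
from typing import Optional


def find_model_row(model_csv: str, modelname: str) -> Optional[list]:
    # Return the first row whose first column matches, or None if absent.
    with open(model_csv, newline="") as src:
        for row in csv.reader(src):
            if row and row[0] == modelname:
                return row
    return None


def write_single_model_csv(row: list, out_path: str) -> None:
    # Header row matches the one written in the diff.
    with open(out_path, mode="w", newline="") as trg:
        writer = csv.writer(trg)
        writer.writerow(["modelname", "src"])
        writer.writerow(row)


with tempfile.TemporaryDirectory() as tmp:
    full_list = os.path.join(tmp, "model_list.csv")
    with open(full_list, "w", newline="") as f:
        csv.writer(f).writerows(
            [["modelname", "src"], ["resnet50", "vision"], ["gpt2", "hf"]]
        )
    found = find_model_row(full_list, "gpt2")
    missing = find_model_row(full_list, "not-a-model")
    single = os.path.join(tmp, "single.csv")
    write_single_model_csv(found, single)
    with open(single, newline="") as f:
        single_rows = list(csv.reader(f))

print(found, missing, single_rows)
```

Returning `None` for an unlisted model lets the caller raise a clear error before any file is written, rather than failing later on an unbound variable.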

View File

@@ -31,4 +31,12 @@ xlm-roberta-base,False,False,-,-,-
facebook/convnext-tiny-224,False,False,-,-,-
efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv"
mnasnet1_0,False,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
bert-large-uncased,True,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
efficientnet_b0,True,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
gpt2,True,False,110M,"nlp;transformer-decoder;auto-regressive","12 layers, 768 hidden units, 12 attention heads"
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
Columns: model_name, use_tracing, dynamic, param_count, tags, notes


@@ -1,5 +1,4 @@
from shark.shark_inference import SharkInference
from shark.parser import shark_args
import torch
import numpy as np
@@ -7,6 +6,8 @@ import sys
torch.manual_seed(0)
BATCH_SIZE = 1
vision_models = [
"alexnet",
"resnet101",
@@ -17,6 +18,8 @@ vision_models = [
"wide_resnet50_2",
"mobilenet_v3_small",
"mnasnet1_0",
"efficientnet_b0",
"efficientnet_b7",
]
hf_img_cls_models = [
"google/vit-base-patch16-224",
@@ -25,17 +28,23 @@ hf_img_cls_models = [
"microsoft/beit-base-patch16-224-pt22k-ft22k",
"nvidia/mit-b0",
]
hf_seq2seq_models = [
"t5-base",
"t5-large",
]
def get_torch_model(modelname):
def get_torch_model(modelname, import_args):
if modelname in vision_models:
return get_vision_model(modelname)
return get_vision_model(modelname, import_args)
elif modelname in hf_img_cls_models:
return get_hf_img_cls_model(modelname)
return get_hf_img_cls_model(modelname, import_args)
elif modelname in hf_seq2seq_models:
return get_hf_seq2seq_model(modelname, import_args)
elif "fp16" in modelname:
return get_fp16_model(modelname)
return get_fp16_model(modelname, import_args)
else:
return get_hf_model(modelname)
return get_hf_model(modelname, import_args)
##################### Hugging Face Image Classification Models ###################################
@@ -78,13 +87,14 @@ class HuggingFaceImageClassification(torch.nn.Module):
return self.model.forward(inputs)[0]
def get_hf_img_cls_model(name):
def get_hf_img_cls_model(name, import_args):
model = HuggingFaceImageClassification(name)
# you can use preprocess_input_image to get the test_input or just random value.
test_input = preprocess_input_image(name)
# test_input = torch.FloatTensor(1, 3, 224, 224).uniform_(-1, 1)
# print("test_input.shape: ", test_input.shape)
# test_input.shape: torch.Size([1, 3, 224, 224])
test_input = test_input.repeat(int(import_args["batch_size"]), 1, 1, 1)
actual_out = model(test_input)
# print("actual_out.shape ", actual_out.shape)
# actual_out.shape torch.Size([1, 1000])
@@ -114,18 +124,58 @@ class HuggingFaceLanguage(torch.nn.Module):
return self.model.forward(tokens)[0]
def get_hf_model(name):
def get_hf_model(name, import_args):
from transformers import (
BertTokenizer,
)
model = HuggingFaceLanguage(name)
# TODO: Currently the test input is set to (1,128)
test_input = torch.randint(2, (1, 128))
test_input = torch.randint(2, (int(import_args["batch_size"]), 128))
actual_out = model(test_input)
return model, test_input, actual_out
##################### Hugging Face Seq2SeqLM Models ###################################
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
T5_MAX_SEQUENCE_LENGTH = 512
class HFSeq2SeqLanguageModel(torch.nn.Module):
def __init__(self, model_name):
super().__init__()
from transformers import AutoTokenizer, T5Model
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenization_kwargs = {
"pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
"padding": True,
"return_tensors": "pt",
}
self.model = T5Model.from_pretrained(model_name, return_dict=True)
def preprocess_input(self, text):
return self.tokenizer(text, **self.tokenization_kwargs)
def forward(self, input_ids, decoder_input_ids):
return self.model.forward(
input_ids, decoder_input_ids=decoder_input_ids
)[0]
def get_hf_seq2seq_model(name, import_args):
m = HFSeq2SeqLanguageModel(name)
encoded_input_ids = m.preprocess_input(
"Studies have been shown that owning a dog is good for you"
).input_ids
decoder_input_ids = m.preprocess_input("Studies show that").input_ids
decoder_input_ids = m.model._shift_right(decoder_input_ids)
test_input = (encoded_input_ids, decoder_input_ids)
actual_out = m.forward(*test_input)
return m, test_input, actual_out
################################################################################
##################### Torch Vision Models ###################################
@@ -141,27 +191,55 @@ class VisionModule(torch.nn.Module):
return self.model.forward(input)
def get_vision_model(torch_model):
def get_vision_model(torch_model, import_args):
import torchvision.models as models
vision_models_dict = {
"alexnet": models.alexnet(weights="DEFAULT"),
"resnet18": models.resnet18(weights="DEFAULT"),
"resnet50": models.resnet50(weights="DEFAULT"),
"resnet50_fp16": models.resnet50(weights="DEFAULT"),
"resnet101": models.resnet101(weights="DEFAULT"),
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
"mobilenet_v3_small": models.mobilenet_v3_small(weights="DEFAULT"),
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
}
if isinstance(torch_model, str):
fp16_model = None
if "fp16" in torch_model:
fp16_model = True
torch_model = vision_models_dict[torch_model]
default_image_size = (224, 224)
modelname = torch_model
if modelname == "alexnet":
torch_model = models.alexnet(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "resnet18":
torch_model = models.resnet18(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "resnet50":
torch_model = models.resnet50(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "resnet50_fp16":
torch_model = models.resnet50(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "resnet101":
torch_model = models.resnet101(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "squeezenet1_0":
torch_model = models.squeezenet1_0(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "wide_resnet50_2":
torch_model = models.wide_resnet50_2(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "mobilenet_v3_small":
torch_model = models.mobilenet_v3_small(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "mnasnet1_0":
torch_model = models.mnasnet1_0(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "efficientnet_b0":
torch_model = models.efficientnet_b0(weights="DEFAULT")
input_image_size = (224, 224)
if modelname == "efficientnet_b7":
torch_model = models.efficientnet_b7(weights="DEFAULT")
input_image_size = (600, 600)
fp16_model = False
if "fp16" in modelname:
fp16_model = True
model = VisionModule(torch_model)
test_input = torch.randn(1, 3, 224, 224)
test_input = torch.randn(
int(import_args["batch_size"]), 3, *input_image_size
)
actual_out = model(test_input)
if fp16_model is not None:
test_input_fp16 = test_input.to(
@@ -202,13 +280,14 @@ class BertHalfPrecisionModel(torch.nn.Module):
return self.model.forward(tokens)[0]
def get_fp16_model(torch_model):
def get_fp16_model(torch_model, import_args):
from transformers import AutoTokenizer
modelname = torch_model.replace("_fp16", "")
model = BertHalfPrecisionModel(modelname)
tokenizer = AutoTokenizer.from_pretrained(modelname)
text = "Replace me by any text you like."
text = [text] * int(import_args["batch_size"])
test_input_fp16 = tokenizer(
text,
truncation=True,

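The hunks above thread `import_args` through every model getter so that the batch dimension of each test input comes from `import_args["batch_size"]`, and the vision getter also picks an input resolution per model. A minimal sketch of that shape logic, without torch, is shown below. The function name `vision_input_shape` and the lookup table are hypothetical; only the sizes and the `int(...)` cast are taken from the diff.

```python
# Per-model input resolution, mirroring the values in the diff above.
DEFAULT_IMAGE_SIZE = (224, 224)
IMAGE_SIZES = {
    "efficientnet_b0": (224, 224),
    "efficientnet_b7": (600, 600),
}


def vision_input_shape(modelname, import_args):
    """Return (batch, channels, height, width) for a vision model's test input."""
    batch = int(import_args["batch_size"])  # cast as the diff does
    height, width = IMAGE_SIZES.get(modelname, DEFAULT_IMAGE_SIZE)
    return (batch, 3, height, width)


print(vision_input_shape("efficientnet_b7", {"batch_size": "4"}))  # (4, 3, 600, 600)
print(vision_input_shape("resnet50", {"batch_size": 1}))  # (1, 3, 224, 224)
```

The sketch looks up sizes only. The removed `vision_models_dict` built every torchvision model when the dictionary was created, whereas the `if` chain that replaces it constructs only the requested one.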

@@ -1,17 +1,19 @@
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
import numpy as np
from transformers import (
AutoModelForSequenceClassification,
BertTokenizer,
TFBertModel,
)
BATCH_SIZE = 1
MAX_SEQUENCE_LENGTH = 128
################################## MHLO/TF models #########################################
# TODO : Generate these lists or fetch model source from tank/tf/tf_model_list.csv
keras_models = ["resnet50", "efficientnet-v2-s"]
keras_models = [
"resnet50",
"efficientnet_b0",
"efficientnet_b7",
"efficientnet-v2-s",
]
maskedlm_models = [
"albert-base-v2",
"bert-base-uncased",
@@ -32,36 +34,61 @@ maskedlm_models = [
"hf-internal-testing/tiny-random-flaubert",
"xlm-roberta",
]
causallm_models = [
"gpt2",
]
tfhf_models = [
"microsoft/MiniLM-L12-H384-uncased",
]
tfhf_seq2seq_models = [
"t5-base",
"t5-large",
]
img_models = [
"google/vit-base-patch16-224",
"facebook/convnext-tiny-224",
]
def get_tf_model(name):
def get_tf_model(name, import_args):
if name in keras_models:
return get_keras_model(name)
return get_keras_model(name, import_args)
elif name in maskedlm_models:
return get_causal_lm_model(name)
return get_masked_lm_model(name, import_args)
elif name in causallm_models:
return get_causal_lm_model(name, import_args)
elif name in tfhf_models:
return get_TFhf_model(name)
return get_TFhf_model(name, import_args)
elif name in img_models:
return get_causal_image_model(name)
return get_causal_image_model(name, import_args)
elif name in tfhf_seq2seq_models:
return get_tfhf_seq2seq_model(name, import_args)
else:
raise Exception(
"TF model not found! Please check that the modelname has been input correctly."
)
##################### Tensorflow Hugging Face LM Models ###################################
##################### Tensorflow Hugging Face Bert Models ###################################
from transformers import (
AutoModelForSequenceClassification,
BertTokenizer,
TFBertModel,
)
BERT_MAX_SEQUENCE_LENGTH = 128
# Create a set of 2-dimensional inputs
tf_bert_input = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
]
@@ -81,27 +108,37 @@ class TFHuggingFaceLanguage(tf.Module):
return self.m.predict(input_ids, attention_mask, token_type_ids)
def get_TFhf_model(name):
def get_TFhf_model(name, import_args):
model = TFHuggingFaceLanguage(name)
tokenizer = BertTokenizer.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased"
)
text = "Replace me by any text you'd like."
text = [text] * BATCH_SIZE
encoded_input = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
)
for key in encoded_input:
encoded_input[key] = tf.expand_dims(
tf.convert_to_tensor(encoded_input[key]), 0
)
test_input = (
encoded_input["input_ids"],
encoded_input["attention_mask"],
encoded_input["token_type_ids"],
max_length=BERT_MAX_SEQUENCE_LENGTH,
)
test_input = [
tf.reshape(
tf.convert_to_tensor(encoded_input["input_ids"], dtype=tf.int32),
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
),
tf.reshape(
tf.convert_to_tensor(
encoded_input["attention_mask"], dtype=tf.int32
),
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
),
tf.reshape(
tf.convert_to_tensor(
encoded_input["token_type_ids"], dtype=tf.int32
),
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
),
]
actual_out = model.forward(*test_input)
return model, test_input, actual_out
@@ -115,34 +152,40 @@ def compare_tensors_tf(tf_tensor, numpy_tensor):
return np.allclose(tf_to_numpy, numpy_tensor, rtol, atol)
##################### Tensorflow Hugging Face Masked LM Models ###################################
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
import tensorflow as tf
# Create a set of input signature.
input_signature_maskedlm = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
]
# For supported models please see here:
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCasualLM
# Tokenizer for language models
def preprocess_input(
model_name, text="This is just used to compile the model"
model_name, max_length, text="This is just used to compile the model"
):
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = [text] * BATCH_SIZE
inputs = tokenizer(
text,
padding="max_length",
return_tensors="tf",
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_length,
)
return inputs
##################### Tensorflow Hugging Face Masked LM Models ###################################
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
MASKED_LM_MAX_SEQUENCE_LENGTH = 128
# Create a set of input signature.
input_signature_maskedlm = [
tf.TensorSpec(
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
]
# For supported models please see here:
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForMaskedLM
class MaskedLM(tf.Module):
def __init__(self, model_name):
super(MaskedLM, self).__init__()
@@ -156,19 +199,143 @@ class MaskedLM(tf.Module):
return self.m.predict(input_ids, attention_mask)
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
def get_masked_lm_model(
hf_name, import_args, text="Hello, this is the default text."
):
model = MaskedLM(hf_name)
encoded_input = preprocess_input(hf_name, text)
encoded_input = preprocess_input(
hf_name, MASKED_LM_MAX_SEQUENCE_LENGTH, text
)
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
actual_out = model.forward(*test_input)
return model, test_input, actual_out
##################### Tensorflow Hugging Face Causal LM Models ###################################
from transformers import AutoConfig, TFAutoModelForCausalLM, TFGPT2Model
CAUSAL_LM_MAX_SEQUENCE_LENGTH = 1024
input_signature_causallm = [
tf.TensorSpec(
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
]
# For supported models please see here:
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCausalLM
# For more background, see:
# https://huggingface.co/blog/tf-xla-generate
class CausalLM(tf.Module):
def __init__(self, model_name):
super(CausalLM, self).__init__()
# Decoder-only models need left padding.
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, padding_side="left", pad_token="</s>"
)
self.tokenization_kwargs = {
"pad_to_multiple_of": CAUSAL_LM_MAX_SEQUENCE_LENGTH,
"padding": True,
"return_tensors": "tf",
}
self.model = TFGPT2Model.from_pretrained(model_name, return_dict=True)
self.model.predict = lambda x, y: self.model(
input_ids=x, attention_mask=y
)[0]
def preprocess_input(self, text):
return self.tokenizer(text, **self.tokenization_kwargs)
@tf.function(input_signature=input_signature_causallm, jit_compile=True)
def forward(self, input_ids, attention_mask):
return self.model.predict(input_ids, attention_mask)
def get_causal_lm_model(
hf_name, import_args, text="Hello, this is the default text."
):
model = CausalLM(hf_name)
batched_text = [text] * BATCH_SIZE
encoded_input = model.preprocess_input(batched_text)
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
actual_out = model.forward(*test_input)
return model, test_input, actual_out
##################### Tensorflow Hugging Face Seq2SeqLM Models ###################################
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
T5_MAX_SEQUENCE_LENGTH = 512
input_signature_t5 = [
tf.TensorSpec(
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
dtype=tf.int32,
name="input_ids",
),
tf.TensorSpec(
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
dtype=tf.int32,
name="attention_mask",
),
]
class TFHFSeq2SeqLanguageModel(tf.Module):
def __init__(self, model_name):
super(TFHFSeq2SeqLanguageModel, self).__init__()
from transformers import (
AutoTokenizer,
AutoConfig,
TFAutoModelForSeq2SeqLM,
TFT5Model,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenization_kwargs = {
"pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
"padding": True,
"return_tensors": "tf",
}
self.model = TFT5Model.from_pretrained(model_name, return_dict=True)
self.model.predict = lambda x, y: self.model(x, decoder_input_ids=y)[0]
def preprocess_input(self, text):
return self.tokenizer(text, **self.tokenization_kwargs)
@tf.function(input_signature=input_signature_t5, jit_compile=True)
def forward(self, input_ids, decoder_input_ids):
return self.model.predict(input_ids, decoder_input_ids)
def get_tfhf_seq2seq_model(name, import_args):
m = TFHFSeq2SeqLanguageModel(name)
text = "Studies have been shown that owning a dog is good for you"
batched_text = [text] * BATCH_SIZE
encoded_input_ids = m.preprocess_input(batched_text).input_ids
text = "Studies show that"
batched_text = [text] * BATCH_SIZE
decoder_input_ids = m.preprocess_input(batched_text).input_ids
decoder_input_ids = m.model._shift_right(decoder_input_ids)
test_input = (encoded_input_ids, decoder_input_ids)
actual_out = m.forward(*test_input)
return m, test_input, actual_out
##################### TensorFlow Keras Resnet Models #########################################################
# Static shape, including batch size (1).
# Can be dynamic once dynamic shape support is ready.
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
RESNET_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
EFFICIENTNET_V2_S_INPUT_SHAPE = [BATCH_SIZE, 384, 384, 3]
EFFICIENTNET_B0_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
EFFICIENTNET_B7_INPUT_SHAPE = [BATCH_SIZE, 600, 600, 3]
class ResNetModule(tf.Module):
@@ -195,25 +362,79 @@ class ResNetModule(tf.Module):
return tf.keras.applications.resnet50.preprocess_input(image)
class EfficientNetModule(tf.Module):
class EfficientNetB0Module(tf.Module):
def __init__(self):
super(EfficientNetModule, self).__init__()
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
super(EfficientNetB0Module, self).__init__()
self.m = tf.keras.applications.efficientnet.EfficientNetB0(
weights="imagenet",
include_top=True,
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
input_shape=tuple(EFFICIENTNET_B0_INPUT_SHAPE[1:]),
)
self.m.predict = lambda x: self.m.call(x, training=False)
@tf.function(
input_signature=[tf.TensorSpec(EFFICIENTNET_INPUT_SHAPE, tf.float32)],
input_signature=[
tf.TensorSpec(EFFICIENTNET_B0_INPUT_SHAPE, tf.float32)
],
jit_compile=True,
)
def forward(self, inputs):
return self.m.predict(inputs)
def input_shape(self):
return EFFICIENTNET_INPUT_SHAPE
return EFFICIENTNET_B0_INPUT_SHAPE
def preprocess_input(self, image):
return tf.keras.applications.efficientnet.preprocess_input(image)
class EfficientNetB7Module(tf.Module):
def __init__(self):
super(EfficientNetB7Module, self).__init__()
self.m = tf.keras.applications.efficientnet.EfficientNetB7(
weights="imagenet",
include_top=True,
input_shape=tuple(EFFICIENTNET_B7_INPUT_SHAPE[1:]),
)
self.m.predict = lambda x: self.m.call(x, training=False)
@tf.function(
input_signature=[
tf.TensorSpec(EFFICIENTNET_B7_INPUT_SHAPE, tf.float32)
],
jit_compile=True,
)
def forward(self, inputs):
return self.m.predict(inputs)
def input_shape(self):
return EFFICIENTNET_B7_INPUT_SHAPE
def preprocess_input(self, image):
return tf.keras.applications.efficientnet.preprocess_input(image)
class EfficientNetV2SModule(tf.Module):
def __init__(self):
super(EfficientNetV2SModule, self).__init__()
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
weights="imagenet",
include_top=True,
input_shape=tuple(EFFICIENTNET_V2_S_INPUT_SHAPE[1:]),
)
self.m.predict = lambda x: self.m.call(x, training=False)
@tf.function(
input_signature=[
tf.TensorSpec(EFFICIENTNET_V2_S_INPUT_SHAPE, tf.float32)
],
jit_compile=True,
)
def forward(self, inputs):
return self.m.predict(inputs)
def input_shape(self):
return EFFICIENTNET_V2_S_INPUT_SHAPE
def preprocess_input(self, image):
return tf.keras.applications.efficientnet_v2.preprocess_input(image)
@@ -224,12 +445,17 @@ def load_image(path_to_image, width, height, channels):
image = tf.image.decode_image(image, channels=channels)
image = tf.image.resize(image, (width, height))
image = image[tf.newaxis, :]
image = tf.tile(image, [BATCH_SIZE, 1, 1, 1])
return image
def get_keras_model(modelname):
def get_keras_model(modelname, import_args):
if modelname == "efficientnet-v2-s":
model = EfficientNetModule()
model = EfficientNetV2SModule()
elif modelname == "efficientnet_b0":
model = EfficientNetB0Module()
elif modelname == "efficientnet_b7":
model = EfficientNetB7Module()
else:
model = ResNetModule()
@@ -256,7 +482,7 @@ import requests
# Create a set of input signature.
input_signature_img_cls = [
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
tf.TensorSpec(shape=[BATCH_SIZE, 3, 224, 224], dtype=tf.float32),
]
@@ -304,11 +530,14 @@ def preprocess_input_image(model_name):
)
# inputs: {'pixel_values': <tf.Tensor: shape=(1, 3, 224, 224), dtype=float32, numpy=array([[[[]]]], dtype=float32)>}
inputs = feature_extractor(images=image, return_tensors="tf")
inputs["pixel_values"] = tf.tile(
inputs["pixel_values"], [BATCH_SIZE, 1, 1, 1]
)
return [inputs[str(*inputs)]]
def get_causal_image_model(hf_name):
def get_causal_image_model(hf_name, import_args):
model = AutoModelImageClassfication(hf_name)
test_input = preprocess_input_image(hf_name)
# TFSequenceClassifierOutput(loss=None, logits=<tf.Tensor: shape=(1, 1000), dtype=float32, numpy=


@@ -4,10 +4,8 @@ from shark.iree_utils._common import (
get_supported_device_list,
)
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
from parameterized import parameterized
from shark.shark_downloader import download_model
from shark.shark_inference import SharkInference
from shark.parser import shark_args
from parameterized import parameterized
import iree.compiler as ireec
import pytest
import unittest
@@ -15,8 +13,8 @@ import numpy as np
import csv
import tempfile
import os
import sys
import shutil
import multiprocessing
def load_csv_and_convert(filename, gen=False):
@@ -48,7 +46,9 @@ def load_csv_and_convert(filename, gen=False):
)
# This is a pytest workaround
if gen:
with open("tank/dict_configs.py", "w+") as out:
with open(
os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
) as out:
out.write("ALL = [\n")
for c in model_configs:
out.write(str(c) + ",\n")
@@ -68,7 +68,9 @@ def get_valid_test_params():
dynamic_list = (True, False)
# TODO: This is soooo ugly, but for some reason creating the dict at runtime
# results in strange pytest failures.
load_csv_and_convert("tank/all_models.csv", True)
load_csv_and_convert(
os.path.join(os.path.dirname(__file__), "all_models.csv"), True
)
from tank.dict_configs import ALL
config_list = ALL
@@ -135,9 +137,12 @@ class SharkModuleTester:
self.config = config
def create_and_check_module(self, dynamic, device):
shark_args.update_tank = self.update_tank
shark_args.force_update_tank = self.force_update_tank
shark_args.shark_prefix = self.shark_tank_prefix
shark_args.local_tank_cache = self.local_tank_cache
shark_args.force_update_tank = self.update_tank
shark_args.dispatch_benchmarks = self.benchmark_dispatches
if self.benchmark_dispatches is not None:
_m = self.config["model_name"].split("/")
_m.extend([self.config["framework"], str(dynamic), device])
@@ -161,17 +166,40 @@ class SharkModuleTester:
if "winograd" in self.config["flags"]:
shark_args.use_winograd = True
model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
tank_url=self.tank_url,
frontend=self.config["framework"],
)
import_config = {
"batch_size": self.batch_size,
}
from shark.shark_downloader import download_model
from shark.shark_inference import SharkInference
from tank.generate_sharktank import NoImportException
dl_gen_attempts = 2
for i in range(dl_gen_attempts):
try:
model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
frontend=self.config["framework"],
import_args=import_config,
)
except NoImportException as err:
pytest.xfail(
reason=f"Artifacts for this model/config must be generated locally. Please make sure {self.config['framework']} is installed."
)
except AssertionError as err:
if i < dl_gen_attempts - 1:
continue
else:
pytest.xfail(
"Generating OTF may require exiting the subprocess for files to be available."
)
break
is_bench = True if self.benchmark is not None else False
shark_module = SharkInference(
model,
device=device,
mlir_dialect=self.config["dialect"],
is_benchmark=self.benchmark,
is_benchmark=is_bench,
)
try:
@@ -185,6 +213,10 @@ class SharkModuleTester:
result = shark_module(func_name, inputs)
golden_out, result = self.postprocess_outputs(golden_out, result)
if self.tf32 == "true":
print("Validating with relaxed tolerances.")
atol = 1e-02
rtol = 1e-03
try:
np.testing.assert_allclose(
golden_out,
@@ -197,23 +229,31 @@ class SharkModuleTester:
self.save_reproducers()
if self.ci == True:
self.upload_repro()
if self.benchmark == True:
self.benchmark_module(shark_module, inputs, dynamic, device)
if self.benchmark is not None:
self.benchmark_module(
shark_module, inputs, dynamic, device, mode=self.benchmark
)
print(msg)
pytest.xfail(
reason=f"Numerics Mismatch: Use -s flag to print stderr during pytests."
)
if self.benchmark == True:
self.benchmark_module(shark_module, inputs, dynamic, device)
if self.benchmark is not None:
self.benchmark_module(
shark_module, inputs, dynamic, device, mode=self.benchmark
)
if self.save_repro == True:
self.save_reproducers()
def benchmark_module(self, shark_module, inputs, dynamic, device):
def benchmark_module(
self, shark_module, inputs, dynamic, device, mode="native"
):
model_config = {
"batch_size": self.batch_size,
}
shark_args.enable_tf32 = self.tf32
if shark_args.enable_tf32 == True:
shark_module.compile()
shark_args.enable_tf32 = False
shark_args.onnx_bench = self.onnx_bench
shark_module.shark_runner.benchmark_all_csv(
@@ -222,6 +262,8 @@ class SharkModuleTester:
dynamic,
device,
self.config["framework"],
import_args=model_config,
mode=mode,
)
def save_reproducers(self):
@@ -271,6 +313,9 @@ class SharkModuleTest(unittest.TestCase):
@parameterized.expand(param_list, name_func=shark_test_name_func)
def test_module(self, dynamic, device, config):
self.module_tester = SharkModuleTester(config)
self.module_tester.batch_size = self.pytestconfig.getoption(
"batchsize"
)
self.module_tester.benchmark = self.pytestconfig.getoption("benchmark")
self.module_tester.save_repro = self.pytestconfig.getoption(
"save_repro"
@@ -290,7 +335,12 @@ class SharkModuleTest(unittest.TestCase):
self.module_tester.update_tank = self.pytestconfig.getoption(
"update_tank"
)
self.module_tester.tank_url = self.pytestconfig.getoption("tank_url")
self.module_tester.force_update_tank = self.pytestconfig.getoption(
"force_update_tank"
)
self.module_tester.shark_tank_prefix = self.pytestconfig.getoption(
"tank_prefix"
)
self.module_tester.benchmark_dispatches = self.pytestconfig.getoption(
"benchmark_dispatches"
)
@@ -307,19 +357,26 @@ class SharkModuleTest(unittest.TestCase):
if config["xfail_vkm"] == "True" and device in ["metal", "vulkan"]:
pytest.xfail(reason=config["xfail_reason"])
if os.name == "nt" and "enabled_windows" not in config["xfail_other"]:
if (
self.pytestconfig.getoption("ci") == True
and os.name == "nt"
and "enabled_windows" not in config["xfail_other"]
):
pytest.xfail(reason="this model skipped on windows")
# Special cases that need to be marked.
if "macos" in config["xfail_other"] and device in [
"metal",
"vulkan",
]:
if get_vulkan_triple_flag() is not None:
if "m1-moltenvk-macos" in get_vulkan_triple_flag():
pytest.xfail(
reason="conv-related issue on MacStudio, returns VK_ERROR_DEVICE_LOST."
)
if (
"macos" in config["xfail_other"]
and device
in [
"metal",
"vulkan",
]
and sys.platform == "darwin"
):
pytest.skip(
reason="conv-related issue on MacStudio, returns VK_ERROR_DEVICE_LOST."
)
if (
config["model_name"]
in [
@@ -342,6 +399,10 @@ class SharkModuleTest(unittest.TestCase):
pytest.xfail(
reason="Numerics issues: https://github.com/nod-ai/SHARK/issues/476"
)
if config["framework"] == "tf" and self.module_tester.batch_size != 1:
pytest.xfail(
reason="Configurable batch sizes temp. unavailable for tensorflow models."
)
safe_name = (
f"{config['model_name']}_{config['framework']}_{dynamic}_{device}"
)
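
The tester in this file now wraps `download_model` in a bounded retry loop: an `AssertionError` on any attempt but the last triggers another attempt, and the last failure is reported through `pytest.xfail`. A minimal sketch of that control flow follows, without pytest or SHARK. `fetch_with_retry` and `flaky_fetch` are hypothetical stand-ins, and the final failure is re-raised here rather than marked as an expected failure.

```python
def fetch_with_retry(fetch, attempts=2):
    """Call `fetch` up to `attempts` times, retrying only on AssertionError."""
    for i in range(attempts):
        try:
            return fetch()
        except AssertionError:
            if i == attempts - 1:
                raise  # the tester in the diff calls pytest.xfail at this point


calls = {"count": 0}


def flaky_fetch():
    # Hypothetical stand-in for download_model: fails once, then succeeds.
    calls["count"] += 1
    assert calls["count"] > 1, "artifacts not on disk yet"
    return "model-artifacts"


result = fetch_with_retry(flaky_fetch)
print(result, calls["count"])
```

Returning from inside the `try` plays the role of the `break` in the diff's loop: a successful attempt ends the loop immediately.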


@@ -19,3 +19,10 @@ facebook/convnext-tiny-224,img
google/vit-base-patch16-224,img
efficientnet-v2-s,keras
bert-large-uncased,hf
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq
efficientnet_b0,keras
efficientnet_b7,keras
gpt2,hf_causallm
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq
Columns: model_name, model_type


@@ -1,4 +1,6 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
@@ -18,4 +20,4 @@ nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encod
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"


@@ -1,4 +1,3 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
prompthero/openjourney,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A

tank_version.json Normal file

@@ -0,0 +1,3 @@
{
"version": "2023-03-31_02d52bb"
}