Compare commits

...

63 Commits

Author SHA1 Message Date
Ean Garvey
308856a947 Touch unet if base cfg needed for SD pipeline init (#1281) 2023-04-05 03:02:29 -05:00
m68k-fr
151b4e142f [SD] Fix encoder error for model_max_length not being 77 (#1278)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-04-04 22:39:29 -07:00
Ean Garvey
e5a69a7c36 pin diffusers to e47459c (#1279) 2023-04-04 18:29:21 -07:00
m68k-fr
450b6cafc4 [SD] Add weight emphasis to prompts encoder (#1276) 2023-04-04 09:47:04 -07:00
Daniel Garvey
237d26baa2 update model db to reflect changes (#1277)
* remove 1/1 tqdm progress bar

* update model_db to reflect changes
2023-04-04 11:46:55 -05:00
Daniel Garvey
67d6ee1104 remove 1/1 tqdm progress bar (#1274) 2023-04-03 22:30:09 -05:00
Ean Garvey
98b069488e Add tank_version.json (#1272) 2023-04-03 18:36:23 -07:00
jinchen62
e0f227643a Fix webui circular import issue (#1271) 2023-04-03 16:00:10 -07:00
jinchen62
a0af3bb0cb xload and unload models (#1242) 2023-04-03 14:42:18 -07:00
powderluv
2cd61a5b96 strip source map (#1270) 2023-04-03 14:41:32 -07:00
Gaurav Shukla
f49d41a807 [SD] Add Stable diffusion text2image rest API (#1265)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-03 12:02:24 -07:00
Ean Garvey
2191fc8952 Separate pytest benchmark modes and fix model updates for SHARK downloader / pytest. (#1264)
* Only xfail windows models in CI

* downloader: make model updates more robust.

* Separate baseline and native benchmarks in pytest.

* Fix native benchmarks

* Fix torchvision model utils.
2023-04-03 08:24:21 -07:00
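A rough illustration of the benchmark split this commit describes; the `--benchmark=native` invocation matches the CI changes further down in this diff, while the `baseline` mode name is an assumption taken from the commit notes:

```shell
# SHARK-native benchmarks only (as in the updated nightly CI below)
pytest --forked --benchmark=native --ci -k cpu

# Framework baseline benchmarks run separately (mode name assumed from the commit notes)
pytest --forked --benchmark=baseline --ci -k cpu
```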
PhaneeshB
aea7796e60 add gradio client to spec 2023-04-03 18:57:19 +05:30
Abhishek Varma
a376619f1e [SD] Improve vmfb caching algo and retry mechanism (#1248)
-- This commit gets rid of the all-or-nothing vmfb caching mechanism
   and improves the retry mechanism by providing lower-level granularity
   for compiling each model unit.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-03-31 09:38:14 -07:00
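A minimal sketch of the per-unit caching and retry idea this commit describes, not the actual implementation; the `compile_unit` callable and file layout are hypothetical:

```python
import os

def get_or_compile_vmfb(unit_name, cache_dir, compile_unit, max_attempts=2):
    """Reuse a cached .vmfb for one model unit; recompile only that unit on failure."""
    path = os.path.join(cache_dir, f"{unit_name}.vmfb")
    if os.path.exists(path):
        return path
    for attempt in range(1, max_attempts + 1):
        try:
            compile_unit(unit_name, path)  # hypothetical per-unit compile step
            return path
        except RuntimeError as err:
            print(f"{unit_name}: compile attempt {attempt} failed: {err}")
    raise RuntimeError(f"Could not compile {unit_name} after {max_attempts} attempts")
```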
powderluv
02d52bb626 Add Intel ARC A770 target triple (#1263)
This just enables the plumbing. It generates black images.
2023-03-29 14:49:05 -07:00
Abhishek Varma
3b63645f79 [SD] Fix custom model path for WebUI (#1260)
-- This commit fixes custom model path for WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-29 09:48:11 -07:00
Ean Garvey
d6f740b998 allow pytest to retry getting model artifacts + disable autotuning for pytorch benchmarks (#1257)
* Adds a few xfails to enable macOS builder

* Convert string batch sizes to ints where needed.

* allow pytest to retry getting model artifacts

* Reduce attempts and add assert msg.
2023-03-28 23:38:45 -05:00
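A generic sketch of the bounded-retry pattern for fetching model artifacts that this commit describes; `fetch_artifact` is a hypothetical stand-in for the downloader call, and the assert message mirrors the last bullet:

```python
def download_with_retries(fetch_artifact, model_name, attempts=2):
    """Retry fetching a model artifact a limited number of times before failing the test."""
    last_err = None
    for _ in range(attempts):
        try:
            return fetch_artifact(model_name)
        except Exception as err:  # e.g. transient network or storage errors
            last_err = err
    assert False, f"Failed to fetch artifacts for {model_name}: {last_err}"
```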
Daniel Garvey
594c6b8ea2 fix ckpt dir (#1258) 2023-03-28 14:31:01 -07:00
Ean Garvey
96b1560da5 Make batch size configurable via pytest and fix sharktank generation. (#1227)
* Fix sharktank generation and add batch_size pytest option for torch.

* Disable torch dynamo until py3.11 supported

* Compile torchmodel without dynamo if torch.compile fails

* Use release versions of TF/Keras for importer.

* Pin torchvision and remove debug prints.

* Remove duplicates from torch model list.

* Update generate_sharktank.py

* xfail a few models that fail sharktank generation/ numerics
2023-03-28 14:33:39 -05:00
Abhishek Varma
0ef6a0e234 [SD] Fix Stencil scribble crash by updating image resize (#1255)
-- This commit updates the Stencil resize feature to cap image sizes
   within [128, 768], as supported by the SD pipeline.
-- This solves the issue of scribble crashing on larger images.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-28 10:13:11 -07:00
Gaurav Shukla
641d535f44 [SD] Fix device path issue for cpu (#1256)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-28 10:09:49 -07:00
Daniel Garvey
5bb7846227 single entry point exe for all cli apps (#1158)
usage:
add --app="img2img" (or "inpaint" "outpaint" "txt2img")
2023-03-28 11:15:21 -05:00
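For example, the same txt2img run from the README (updated further down in this diff) now goes through the shared entry point:

```shell
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```

Swapping `--app=txt2img` for `img2img`, `inpaint`, or `outpaint` selects the other apps.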
yzhang93
8f84258fb8 Fix check for use_tuned conditions (#1252) 2023-03-27 11:21:25 -07:00
Ean Garvey
7619e76bbd Disable and xfail some models that fail validation/compilation. (#1251)
* Rollback T5 models for torch as the inputs give some issues that aren't trivial to resolve
* xfail efficientnet-b0 on torch+cuda -- see CUDA requesting shared memory size larger than allowed size openxla/iree#12771
2023-03-27 12:42:53 -05:00
Daniel Garvey
9267eadbfa disable openjourney gen for nightly (#1249) 2023-03-27 11:55:34 -05:00
Phaneesh Barwaria
431132b8ee Fix img2img mode switch (#1247)
* add updated scheduler value in global config

* clear scheduler global variable with others
2023-03-27 07:01:22 -07:00
cstueckrath
fb35e13e7a fix Python version detection bug (#1246)
* fix Python version detection bug

* Update setup_venv.ps1
2023-03-27 07:00:40 -07:00
yzhang93
17a67897d1 Add SD v2.1 768x768 tuned model (#1244)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-24 10:39:15 -07:00
Gaurav Shukla
da449b73aa [SD] Disable lora training tab for now (#1241)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-24 09:16:24 -07:00
Kyle Herndon
0b0526699a Fix incorrect device argument initialization for LoRA training by extracting the device type and number and formatting it for pytorch (#1237)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
2023-03-24 01:10:50 -07:00
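A sketch of the device-string handling this commit describes, matching the `device_str` parsing that appears in the LoRA training script later in this diff; the example UI labels are illustrative:

```python
def to_torch_device(ui_device: str) -> str:
    """Turn a UI device label like 'Device Name => vulkan://0' into 'vulkan:0'."""
    parts = ui_device.split("=>", 1)[1].strip().split("://")
    return parts[0] + ":" + parts[1] if len(parts) > 1 else parts[0]

# to_torch_device("AMD Radeon => vulkan://0")  -> "vulkan:0"
# to_torch_device("Intel i9 => cpu")           -> "cpu"
```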
Boian Petkantchin
4fac46f7bb In models testing fix paths to be relative to the script dir not cwd (#1128)
authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-03-22 15:26:52 -05:00
Daniel Garvey
49925950f1 fix false positives (#1193) 2023-03-22 15:25:39 -05:00
Thomas
807947c0c8 Remove deprecated cli option iree-hal-cuda-disable-loop-nounroll-wa (#1235) 2023-03-22 12:05:15 -05:00
Abhishek Varma
593428bda4 [SD] Fix for transformers/__init__.py issue in PyInstaller (#1233)
-- This commit fixes the transformers/__init__.py issue in PyInstaller.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-22 08:43:53 -07:00
Abhishek Varma
cede9b4fec [SD] Fix custom_vae as a required parameter in inpaint (#1232) 2023-03-22 04:30:17 -07:00
Prashant Kumar
c2360303f0 Add the int8 quantized model. 2023-03-22 16:28:13 +05:30
jinchen62
420366c1b8 Move schedulers to global obj (#1225) 2023-03-21 22:40:43 -07:00
Ean Garvey
d31bae488c Set iree-input-type to tm_tensor for SD (#1228) 2023-03-21 19:07:31 -07:00
Kyle Herndon
c23fcf3748 Fix incorrect device argument initialization for LoRA training (#1231)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-21 19:07:18 -07:00
jinchen62
7dbbb1726a Fix SD obj not defined if fail to get models from pretrained (#1222) 2023-03-21 07:55:17 -07:00
Abhishek Varma
8b8cc7fd33 [SD] Update LoRA inference to handle various checkpoints (#1215) 2023-03-21 06:52:20 -07:00
Ean Garvey
e3c96a2b9d Move sentencepiece to importer requirements. (#1218) 2023-03-21 00:39:57 -05:00
Ean Garvey
5e3f50647d Set --vulkan_large_heap_block_size default to 2gb. (#1220) 2023-03-20 21:07:09 -07:00
gpetters94
7899e1803a Add fix for attention slicing fp16 (#1217) 2023-03-20 19:11:29 -07:00
mariecwhite
d105246b9c Fix t5 models 2023-03-21 10:39:59 +11:00
mariecwhite
90c958bca2 Add T5-base and T5-large Torch and TF Models (#1116) 2023-03-20 17:32:50 -05:00
mariecwhite
f99903e023 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:22:05 +11:00
mariecwhite
c6f44ef1b3 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:14:45 +11:00
mariecwhite
8dcd4d5aeb Make batch size configurable 2023-03-20 18:03:17 -04:00
Phoenix Meadowlark
d319f4684e Add peak memory reporting for IREE, TF and PyTorch (#1216) 2023-03-20 15:40:49 -05:00
Ean Garvey
54d7b6d83e Generate model artifacts in pytests if they don't exist in the cloud. (#1121)
* Add gen_shark_files fn to shark_downloader for OTF artifact generation

* add generate_sharktank as a tank/ python module.

* Fix some paths in tank generation.
2023-03-20 12:13:19 -05:00
m68k-fr
4a622532e5 [Web] Stop images (#1212) 2023-03-19 14:37:30 -07:00
cstueckrath
650b2ada58 add pytorch_lightning to requirements (#1211)
* add pytorch_lightning to requirements

this will additionally add lightning-utilities and torchmetrics

* Update shark_sd.spec

* Update shark_sd_cli.spec
2023-03-19 12:29:54 -07:00
m68k-fr
f87f8949f3 [Web] CSS fix for gradio V3.22.1 (#1210) 2023-03-19 06:13:59 -07:00
m68k-fr
7dc9bf8148 [Web] Move "stop Batch" button to "Advanced Options" toggle (#1209) 2023-03-18 20:54:42 -07:00
Kyle Herndon
ba48ff8d25 Implement LoRA training and UI for training and UI for inference in img2img, inpaint, outpaint (#1200)
txt2img inference UI is already committed.

Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-17 12:54:56 -07:00
Gaurav Shukla
638840925c [SD] Add support for larger size upscaling (#1204)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-17 10:20:48 -07:00
m68k-fr
b661656c03 [Web] Fix custom model path for upscaler (#1199) 2023-03-16 15:57:23 -07:00
Gaurav Shukla
0225434389 [SD] Add sendTo Upscaler
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
Gaurav Shukla
7ffe20b1c2 [SD] Release memory used by upscaler when not in use
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
Gaurav Shukla
d8f0c4655d [SD] Add Upscaler web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
Gaurav Shukla
7e8d3ec0df [SD] Add upscaler pipeline
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
jinchen62
9c08eec565 Clear memory cache when switching model and mode (#1194) 2023-03-15 22:18:26 -07:00
74 changed files with 5384 additions and 1695 deletions

View File

@@ -112,7 +112,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -120,9 +120,9 @@ jobs:
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
@@ -145,17 +145,19 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest -k vulkan -s
pytest -k vulkan -s --ci
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
python build_tools/stable_diffusion_testing.py --device=vulkan

View File

@@ -114,12 +114,12 @@ source shark.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
```shell
python3.11 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple Vulkan devices, you can address them with `--device=vulkan://1`, etc.
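For instance, using the same example prompt:

```shell
# Run on CPU instead of Vulkan
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=cpu --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"

# Address the second Vulkan device on a multi-GPU system
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan://1 --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```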

View File

@@ -1,4 +1,5 @@
from apps.stable_diffusion.scripts.txt2img import txt2img_inf
from apps.stable_diffusion.scripts.img2img import img2img_inf
from apps.stable_diffusion.scripts.inpaint import inpaint_inf
from apps.stable_diffusion.scripts.outpaint import outpaint_inf
from apps.stable_diffusion.scripts.upscaler import upscaler_inf
from apps.stable_diffusion.scripts.train_lora_word import lora_train

View File

@@ -2,7 +2,7 @@ import sys
import torch
import time
from PIL import Image
from dataclasses import dataclass
import transformers
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -13,25 +13,9 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
use_stencil: str
img2img_obj = None
config_obj = None
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -40,15 +24,15 @@ init_import_mlir = args.import_mlir
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model constraints :-
# Both width and height should be > 384 and multiple of 8.
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
width, height = image.size
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 384:
n_size = 384
if min_size < 128:
n_size = 128
if width == min_size:
width = n_size
height = n_size / aspect_ratio
@@ -61,6 +45,22 @@ def resize_stencil(image: Image.Image):
n_height = height // 8
n_width *= 8
n_height *= 8
min_size = min(width, height)
if min_size > 768:
n_size = 768
if width == min_size:
height = n_size
width = n_size * aspect_ratio
else:
width = n_size
height = n_size / aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
new_image = image.resize((n_width, n_height))
return new_image, n_width, n_height
@@ -87,12 +87,19 @@ def img2img_inf(
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import get_custom_model_pathfile
global img2img_obj
global config_obj
global schedulers
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
@@ -102,16 +109,13 @@ def img2img_inf(
args.strength = strength
args.scheduler = scheduler
args.img_path = "not none"
args.ondemand = ondemand
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB")
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -126,6 +130,10 @@ def img2img_inf(
else:
args.hf_model_id = custom_model
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -149,6 +157,7 @@ def img2img_inf(
args.precision = precision
dtype = torch.float32 if precision == "fp32" else torch.half
new_config_obj = Config(
"img2img",
args.hf_model_id,
args.ckpt_loc,
precision,
@@ -157,10 +166,16 @@ def img2img_inf(
height,
width,
device,
use_stencil,
use_lora=args.use_lora,
use_stencil=use_stencil,
)
if not img2img_obj or config_obj != new_config_obj:
config_obj = new_config_obj
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -175,58 +190,67 @@ def img2img_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
args.use_tuned = False
img2img_obj = StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
global_obj.set_sd_obj(
Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
img2img_obj.scheduler = schedulers[scheduler]
global_obj.set_sd_scheduler(args.scheduler)
start_time = time.time()
img2img_obj.log = ""
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"STRENGTH": strength}
text_output = ""
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = img2img_obj.generate_images(
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
@@ -243,26 +267,23 @@ def img2img_inf(
cpu_scheduling,
use_stencil=use_stencil,
)
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
img2img_obj.log += "\n"
yield generated_imgs, generated_imgs[0], img2img_obj.log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, strength={args.strength}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += img2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -314,6 +335,8 @@ if __name__ == "__main__":
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
@@ -331,6 +354,8 @@ if __name__ == "__main__":
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
start_time = time.time()
@@ -366,3 +391,7 @@ if __name__ == "__main__":
extra_info = {"STRENGTH": args.strength}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,7 @@
import torch
import time
from PIL import Image
from dataclasses import dataclass
import transformers
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -11,24 +11,9 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
inpaint_obj = None
config_obj = None
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -57,12 +42,19 @@ def inpaint_inf(
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import get_custom_model_pathfile
global inpaint_obj
global config_obj
global schedulers
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
@@ -71,12 +63,9 @@ def inpaint_inf(
args.scheduler = scheduler
args.img_path = "not none"
args.mask_path = "not none"
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -91,12 +80,17 @@ def inpaint_inf(
else:
args.hf_model_id = custom_model
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"inpaint",
args.hf_model_id,
args.ckpt_loc,
precision,
@@ -105,10 +99,17 @@ def inpaint_inf(
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
)
if not inpaint_obj or config_obj != new_config_obj:
config_obj = new_config_obj
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -123,38 +124,43 @@ def inpaint_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
inpaint_obj = InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
inpaint_obj.scheduler = schedulers[scheduler]
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
inpaint_obj.log = ""
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = inpaint_obj.generate_images(
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
@@ -172,26 +178,23 @@ def inpaint_inf(
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
inpaint_obj.log += "\n"
yield generated_imgs, generated_imgs[0], inpaint_obj.log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += inpaint_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -221,6 +224,7 @@ if __name__ == "__main__":
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
@@ -228,9 +232,10 @@ if __name__ == "__main__":
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
@@ -273,3 +278,7 @@ if __name__ == "__main__":
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,19 @@
from apps.stable_diffusion.src import args
from apps.stable_diffusion.scripts import (
img2img,
txt2img,
# inpaint,
# outpaint,
)
if __name__ == "__main__":
if args.app == "txt2img":
txt2img.main()
elif args.app == "img2img":
img2img.main()
# elif args.app == "inpaint":
# inpaint.main()
# elif args.app == "outpaint":
# outpaint.main()
else:
print(f"args.app value is {args.app} but this isn't supported")

View File

@@ -1,7 +1,7 @@
import torch
import time
from PIL import Image
from dataclasses import dataclass
import transformers
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
@@ -11,24 +11,9 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
outpaint_obj = None
config_obj = None
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -60,12 +45,19 @@ def outpaint_inf(
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import get_custom_model_pathfile
global outpaint_obj
global config_obj
global schedulers
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
@@ -73,12 +65,9 @@ def outpaint_inf(
args.steps = steps
args.scheduler = scheduler
args.img_path = "not none"
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -93,12 +82,17 @@ def outpaint_inf(
else:
args.hf_model_id = custom_model
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"outpaint",
args.hf_model_id,
args.ckpt_loc,
precision,
@@ -107,10 +101,17 @@ def outpaint_inf(
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
)
if not outpaint_obj or config_obj != new_config_obj:
config_obj = new_config_obj
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -125,27 +126,31 @@ def outpaint_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
outpaint_obj = OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
outpaint_obj.scheduler = schedulers[scheduler]
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
outpaint_obj.log = ""
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
@@ -155,10 +160,11 @@ def outpaint_inf(
top = True if "up" in directions else False
bottom = True if "down" in directions else False
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = outpaint_obj.generate_images(
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
init_image,
@@ -181,26 +187,23 @@ def outpaint_inf(
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
outpaint_obj.log += "\n"
yield generated_imgs, generated_imgs[0], outpaint_obj.log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += outpaint_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -234,6 +237,8 @@ if __name__ == "__main__":
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
@@ -298,3 +303,7 @@ if __name__ == "__main__":
}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,674 @@
# Install the required libs
# pip install -U git+https://github.com/huggingface/diffusers.git
# pip install accelerate transformers ftfy
# HuggingFace Token
# YOUR_TOKEN = "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"
# Import required libraries
import itertools
import math
import os
from typing import List
import random
import torch_mlir
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
import logging
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from shark.shark_inference import SharkInference
torch._dynamo.config.verbose = True
from diffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
)
from io import BytesIO
from dataclasses import dataclass
from apps.stable_diffusion.src import (
args,
get_schedulers,
set_init_device_flags,
clear_all,
)
# Setup the dataset
class LoraDataset(Dataset):
def __init__(
self,
data_root,
tokenizer,
size=512,
repeats=100,
interpolation="bicubic",
set="train",
prompt="myloraprompt",
center_crop=False,
):
self.data_root = data_root
self.tokenizer = tokenizer
self.size = size
self.center_crop = center_crop
self.prompt = prompt
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
self.num_images = len(self.image_paths)
self._length = self.num_images
if set == "train":
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
example["input_ids"] = self.tokenizer(
self.prompt,
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids[0]
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = np.array(image).astype(np.uint8)
image = (image / 127.5 - 1.0).astype(np.float32)
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
return example
########## Setting up the model ##########
def lora_train(
prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
training_images_dir: str,
lora_save_dir: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
print(
"Note LoRA training is not compatible with the latest torch-mlir branch"
)
print(
"To run LoRA training you'll need to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
)
torch.manual_seed(seed)
args.prompts = [prompt]
args.steps = steps
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = custom_model
else:
args.hf_model_id = custom_model
args.training_images_dir = training_images_dir
args.lora_save_dir = lora_save_dir
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
device_str = device.split("=>", 1)[1].strip().split("://")
if len(device_str) > 1:
device_str = device_str[0] + ":" + device_str[1]
else:
device_str = device_str[0]
args.device = device_str
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(
args.hf_model_id, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(
args.hf_model_id, subfolder="unet"
)
def freeze_params(params):
for param in params:
param.requires_grad = False
# Freeze everything but LoRA
freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
# Move vae and unet to device
vae.to(args.device)
unet.to(args.device)
text_encoder.to(args.device)
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[
block_id
]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = vae
def forward(self, input):
x = self.vae.encode(input, return_dict=False)[0]
return x
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = unet
def forward(self, x, y, z):
return self.unet.forward(x, y, z, return_dict=False)[0]
shark_vae = VaeModel()
shark_unet = UnetModel()
####### Creating our training data ########
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id,
subfolder="tokenizer",
)
# Let's create the Dataset and Dataloader
train_dataset = LoraDataset(
data_root=args.training_images_dir,
tokenizer=tokenizer,
size=vae.sample_size,
prompt=args.prompts[0],
repeats=100,
center_crop=False,
set="train",
)
def create_dataloader(train_batch_size=1):
return torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, shuffle=True
)
# Create noise_scheduler for training
noise_scheduler = DDPMScheduler.from_config(
args.hf_model_id, subfolder="scheduler"
)
######## Training ###########
# Define hyperparameters for our training. If you are not happy with your results,
# you can tune the `learning_rate` and the `max_train_steps`
# Setting up all training args
hyperparameters = {
"learning_rate": 5e-04,
"scale_lr": True,
"max_train_steps": steps,
"train_batch_size": batch_size,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"mixed_precision": "fp16",
"seed": 42,
"output_dir": "sd-concept-output",
}
# creating output directory
cwd = os.getcwd()
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
while not os.path.exists(str(out_dir)):
try:
os.mkdir(out_dir)
except OSError as error:
print("Output directory not created")
###### Torch-MLIR Compilation ######
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
@make_simple_dynamo_backend
def refbackend_torchdynamo_backend(
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
):
# handling usage of empty tensor without initializing
transform_fx(fx_graph)
fx_graph.recompile()
if _returns_nothing(fx_graph):
return fx_graph
removed_none_indexes = _remove_nones(fx_graph)
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
mlir_module = torch_mlir.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors"
)
bytecode_stream = BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
def compiled_callable(*inputs):
inputs = [x.numpy() for x in inputs]
result = shark_module("forward", inputs)
if was_unwrapped:
result = [
result,
]
if not isinstance(result, list):
result = torch.from_numpy(result)
else:
result = tuple(torch.from_numpy(x) for x in result)
result = list(result)
for removed_index in removed_none_indexes:
result.insert(removed_index, None)
result = tuple(result)
return result
return compiled_callable
def predictions(torch_func, jit_func, batchA, batchB):
res = jit_func(batchA.numpy(), batchB.numpy())
if res is not None:
# prediction = torch.from_numpy(res)
prediction = res
else:
prediction = None
return prediction
logger = logging.getLogger(__name__)
train_batch_size = hyperparameters["train_batch_size"]
gradient_accumulation_steps = hyperparameters[
"gradient_accumulation_steps"
]
learning_rate = hyperparameters["learning_rate"]
if hyperparameters["scale_lr"]:
learning_rate = (
learning_rate
* gradient_accumulation_steps
* train_batch_size
# * accelerator.num_processes
)
# Initialize the optimizer
optimizer = torch.optim.AdamW(
lora_layers.parameters(), # only optimize the LoRA attention layers
lr=learning_rate,
)
# Training function
def train_func(batch_pixel_values, batch_input_ids):
# Convert images to latent space
latents = shark_vae(batch_pixel_values).sample().detach()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.num_train_timesteps,
(bsz,),
device=latents.device,
).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch_input_ids)[0]
# Predict the noise residual
noise_pred = shark_unet(
noisy_latents,
timesteps,
encoder_hidden_states,
)
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
)
loss = (
F.mse_loss(noise_pred, target, reduction="none")
.mean([1, 2, 3])
.mean()
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return loss
def training_function():
max_train_steps = hyperparameters["max_train_steps"]
output_dir = hyperparameters["output_dir"]
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
train_dataloader = create_dataloader(train_batch_size)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
num_train_epochs = math.ceil(
max_train_steps / num_update_steps_per_epoch
)
# Train!
total_batch_size = (
train_batch_size
* gradient_accumulation_steps
# train_batch_size * accelerator.num_processes * gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(
f" Instantaneous batch size per device = {train_batch_size}"
)
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(
# range(max_train_steps), disable=not accelerator.is_local_main_process
range(max_train_steps)
)
progress_bar.set_description("Steps")
global_step = 0
params__ = [
i for i in text_encoder.get_input_embeddings().parameters()
]
for epoch in range(num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
dynamo_callable = dynamo.optimize(
refbackend_torchdynamo_backend
)(train_func)
lam_func = lambda x, y: dynamo_callable(
torch.from_numpy(x), torch.from_numpy(y)
)
loss = predictions(
train_func,
lam_func,
batch["pixel_values"],
batch["input_ids"],
)
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
if global_step >= max_train_steps:
break
training_function()
# Save the lora weights
unet.save_attn_procs(args.lora_save_dir)
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
if param.grad is not None:
del param.grad # free some memory
torch.cuda.empty_cache()
if __name__ == "__main__":
if args.clear_all:
clear_all()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
if len(args.prompts) != 1:
print("Need exactly one prompt for the LoRA word")
lora_train(
args.prompts[0],
args.height,
args.width,
args.training_steps,
args.guidance_scale,
args.seed,
args.batch_count,
args.batch_size,
args.scheduler,
"None",
args.hf_model_id,
args.precision,
args.device,
args.max_length,
args.training_images_dir,
args.lora_save_dir,
)

View File

@@ -1,6 +1,6 @@
import torch
import transformers
import time
from dataclasses import dataclass
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
@@ -12,191 +12,7 @@ from apps.stable_diffusion.src import (
)
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
use_lora: str
txt2img_obj = None
config_obj = None
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import get_custom_model_pathfile
global txt2img_obj
global config_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora,
)
if not txt2img_obj or config_obj != new_config_obj:
config_obj = new_config_obj
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
)
txt2img_obj.scheduler = schedulers[scheduler]
start_time = time.time()
txt2img_obj.log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = txt2img_obj.generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
txt2img_obj.log += "\n"
yield generated_imgs, generated_imgs[0], txt2img_obj.log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += (
f"\nsteps={steps}, guidance_scale={guidance_scale}, seed={seeds}"
)
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
# text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -206,7 +22,6 @@ if __name__ == "__main__":
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
use_lora = args.use_lora
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
@@ -222,7 +37,9 @@ if __name__ == "__main__":
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
for current_batch in range(args.batch_count):
@@ -262,3 +79,7 @@ if __name__ == "__main__":
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,277 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# Exposed to UI.
def upscaler_inf(
prompt: str,
negative_prompt: str,
init_image,
height: int,
width: int,
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
if init_image is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB").resize((height, width))
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
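# Scheduler names prefixed with "Shark" are the compiled on-device schedulers; all others step on the CPU via diffusers.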
args.height = 128
args.width = 128
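# The upscaler always works on fixed 128x128 tiles of the input (each upscaled 4x below), so height/width are pinned here regardless of the UI values.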
new_config_obj = Config(
"upscaler",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
args.height,
args.width,
device,
use_lora=args.use_lora,
use_stencil=None,
)
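# The Config tuple acts as a cache key: the pipeline below is rebuilt only when one of these settings changes (or when no pipeline exists yet).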
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_size = batch_size
args.max_length = max_length
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
"DDPM"
)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"NOISE LEVEL": noise_level}
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
low_res_img = image
high_res_img = Image.new("RGB", (height * 4, width * 4))
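# Tile the low-res image into 128x128 crops; each crop is upscaled 4x and pasted back into the output canvas at (j*4, i*4).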
for i in range(0, width, 128):
for j in range(0, height, 128):
box = (j, i, j + 128, i + 128)
upscaled_image = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
low_res_img.crop(box),
batch_size,
args.height,
args.width,
steps,
noise_level,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
save_output_img(high_res_img, img_seed, extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output
if __name__ == "__main__":
if args.clear_all:
clear_all()
if args.img_path is None:
print("Flag --img_path is required.")
exit()
# Once the models are uploaded, this should default to False.
args.import_mlir = True
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
image = (
Image.open(args.img_path)
.convert("RGB")
.resize((args.height, args.width))
)
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model
upscaler_obj = UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ddpm_scheduler=schedulers["DDPM"],
ondemand=args.ondemand,
)
start_time = time.time()
generated_imgs = upscaler_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.batch_size,
args.height,
args.width,
args.steps,
args.noise_level,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, noise_level={args.noise_level}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += upscaler_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
extra_info = {"NOISE LEVEL": args.noise_level}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)


@@ -21,9 +21,11 @@ datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')


@@ -22,8 +22,10 @@ datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
@@ -42,7 +44,7 @@ hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
a = Analysis(
['scripts/txt2img.py'],
['scripts/main.py'],
pathex=['.'],
binaries=binaries,
datas=datas,


@@ -12,5 +12,6 @@ from apps.stable_diffusion.src.pipelines import (
InpaintPipeline,
OutpaintPipeline,
StencilPipeline,
UpscalerPipeline,
)
from apps.stable_diffusion.src.schedulers import get_schedulers


@@ -11,13 +11,14 @@ from apps.stable_diffusion.src.utils import (
get_opt_flags,
base_models,
args,
fetch_or_delete_vmfbs,
fetch_vmfb,
preprocessCKPT,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
get_stencil_model_id,
update_lora_weight,
)
@@ -54,29 +55,9 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
return new_shape
# Get the input info for various models i.e. "unet", "clip", "vae", "vae_encode".
def get_input_info(model_info, max_len, width, height, batch_size):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
for inp in model_info[k]:
shape = model_info[k][inp]["shape"]
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, max_len, width, height, batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
def check_compilation(model, model_name):
if not model:
raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
class SharkifyStableDiffusionModel:
@@ -97,8 +78,10 @@ class SharkifyStableDiffusionModel:
sharktank_dir: str = "",
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = ""
use_lora: str = "",
use_quantize: str = None,
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -106,6 +89,7 @@ class SharkifyStableDiffusionModel:
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
@@ -137,24 +121,38 @@ class SharkifyStableDiffusionModel:
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
print(self.model_name)
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
self.sharktank_dir = sharktank_dir
self.generate_vmfb = generate_vmfb
def get_extended_name_for_all_model(self, mask_to_fetch):
self.inputs = dict()
self.model_to_run = ""
if self.custom_weights != "":
self.model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
else:
self.model_to_run = args.hf_model_id
self.custom_vae = self.process_custom_vae()
self.base_model_id = fetch_and_update_base_model_id(self.model_to_run)
if self.base_model_id != "" and args.ckpt_loc != "":
args.hf_model_id = self.base_model_id
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
index = 0
for model in sub_model_list:
if mask_to_fetch[index] == False:
index += 1
continue
sub_model = model
model_config = self.model_name
if "vae" == model:
@@ -169,11 +167,34 @@ class SharkifyStableDiffusionModel:
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 384):
sys.exit("width should be greater than 384 and multiple of 8")
if not (height % 8 == 0 and height >= 384):
sys.exit("height should be greater than 384 and multiple of 8")
if not (width % 8 == 0 and width >= 128):
sys.exit("width must be at least 128 and a multiple of 8")
if not (height % 8 == 0 and height >= 128):
sys.exit("height must be at least 128 and a multiple of 8")
# Get the input info for a model i.e. "unet", "clip", "vae", etc.
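# Illustrative example (not an actual model_db entry): a spec like {"shape": [2, 4, "height", "width"], "dtype": "f32"} becomes a torch.randn tensor once replace_shape_str substitutes the concrete max_len/width/height/batch_size values.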
def get_input_info_for(self, model_info):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = []
for inp in model_info:
shape = model_info[inp]["shape"]
dtype = dtype_config[model_info[inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, self.max_len, self.width, self.height, self.batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
@@ -198,6 +219,7 @@ class SharkifyStableDiffusionModel:
use_tuned=self.use_tuned,
model_name=self.model_name["vae_encode"],
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_vae_encode
@@ -253,13 +275,14 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_vae
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -267,6 +290,8 @@ class SharkifyStableDiffusionModel:
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
@@ -295,7 +320,7 @@ class SharkifyStableDiffusionModel:
unet = ControlledUnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_unet"])
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
shark_controlled_unet = compile_through_fx(
unet,
@@ -305,6 +330,7 @@ class SharkifyStableDiffusionModel:
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_controlled_unet
@@ -358,6 +384,7 @@ class SharkifyStableDiffusionModel:
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_cnet
@@ -371,7 +398,7 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
self.unet.load_attn_procs(use_lora)
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
@@ -416,18 +443,59 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_unet
def get_unet_upscaler(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(self, latent, timestep, text_embedding, noise_level):
unet_out = self.unet.forward(
latent,
timestep,
text_embedding,
noise_level,
return_dict=False,
)[0]
return unet_out
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet = compile_through_fx(
unet,
inputs,
model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_unet
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
def forward(self, input):
return self.text_encoder(input)[0]
@@ -447,6 +515,7 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
)
return shark_clip
@@ -469,128 +538,120 @@ class SharkifyStableDiffusionModel:
vae_checkpoint = vae_checkpoint["state_dict"]
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
# configuration.
def compile_all(self, base_model_id, need_vae_encode, need_stencil):
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
compiled_controlnet = None
compiled_controlled_unet = None
compiled_unet = None
if need_stencil:
compiled_controlnet = self.get_control_net()
compiled_controlled_unet = self.get_controlled_unet()
else:
compiled_unet = self.get_unet()
if self.custom_vae != "":
print("Plugging in custom Vae")
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
if need_stencil:
return compiled_clip, compiled_controlled_unet, compiled_vae, compiled_controlnet
if need_vae_encode:
compiled_vae_encode = self.get_vae_encode()
return compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode
return compiled_clip, compiled_unet, compiled_vae
def __call__(self):
# Step 1:
# -- Fetch all vmfbs for the model, if present, else delete the lot.
need_vae_encode, need_stencil = False, False
if args.img_path is not None:
if self.use_stencil is not None:
need_stencil = True
def compile_unet_variants(self, model):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
return get_unet()
else:
need_vae_encode = True
# `mask_to_fetch` prepares a mask to pick a combination out of:
# ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
mask_to_fetch = [True, True, False, True, False, False]
if need_vae_encode:
mask_to_fetch = [True, True, False, True, True, False]
elif need_stencil:
mask_to_fetch = [True, False, True, True, False, True]
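# Mask positions follow [clip, unet, stencil_unet, vae, vae_encode, stencil_adaptor]: img2img additionally needs vae_encode, while stencil runs swap the plain unet for stencil_unet plus stencil_adaptor.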
self.model_name = self.get_extended_name_for_all_model(mask_to_fetch)
vmfbs = fetch_or_delete_vmfbs(self.model_name, self.precision)
if vmfbs[0]:
# -- If all vmfbs are indeed present, we also try and fetch the base
# model configuration for running SD with custom checkpoints.
if self.custom_weights != "":
args.hf_model_id = fetch_and_update_base_model_id(self.custom_weights)
if args.hf_model_id == "":
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
return vmfbs
# Step 2:
# -- If vmfbs weren't found, we try to see if the base model configuration
# for the required SD run is known to us and bypass the retry mechanism.
model_to_run = ""
if self.custom_weights != "":
model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights, self.is_inpaint)
return self.get_unet()
else:
model_to_run = args.hf_model_id
# For custom Vae user can provide either the repo-id or a checkpoint file,
# and for a checkpoint file we'd need to process it via Diffusers' script.
self.custom_vae = self.process_custom_vae()
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
if base_model_fetched != "":
print("Compiling all the models with the fetched base model configuration.")
if args.ckpt_loc != "":
args.hf_model_id = base_model_fetched
return self.compile_all(base_model_fetched, need_vae_encode, need_stencil)
return self.get_controlled_unet()
# Step 3:
# -- This is the retry mechanism where the base model's configuration is not
# known to us, so we figure it out by trial and error.
print("Inferring base model configuration.")
for model_id in base_models:
try:
if need_vae_encode:
compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, need_vae_encode, need_stencil)
elif need_stencil:
compiled_clip, compiled_unet, compiled_vae, compiled_controlnet = self.compile_all(model_id, need_vae_encode, need_stencil)
else:
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode, need_stencil)
except Exception as e:
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
if need_vae_encode:
return (
compiled_clip,
compiled_unet,
compiled_vae,
compiled_vae_encode,
)
if need_stencil:
return (
compiled_clip,
compiled_unet,
compiled_vae,
compiled_controlnet,
)
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
)
def vae_encode(self):
# Fetch vmfb for the model if present
vmfb = fetch_vmfb("vae_encode", self.model_name["vae_encode"], self.precision)
if vmfb:
return vmfb
try:
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
compiled_vae_encode = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
return compiled_vae_encode
except Exception as e:
sys.exit(e)
def clip(self):
vmfb = fetch_vmfb("clip", self.model_name["clip"], self.precision)
if vmfb:
return vmfb
try:
self.inputs["clip"] = self.get_input_info_for(base_models["clip"])
compiled_clip = self.get_clip()
check_compilation(compiled_clip, "Clip")
return compiled_clip
except Exception as e:
sys.exit(e)
def unet(self):
model = "stencil_unet" if self.use_stencil is not None else "unet"
vmfb = fetch_vmfb(model, self.model_name[model], self.precision)
if vmfb:
return vmfb
try:
compiled_unet = None
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet = self.compile_unet_variants(model)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
try:
compiled_unet = self.compile_unet_variants(model)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(self.model_to_run, model_id)
# main.py bases its choice of tokenizer and scheduler on `args.hf_model_id`. Since we no
# longer maintain a 1:1 mapping between variants and base models and instead rely on the
# retry mechanism to find the input configuration, we also record the inferred base model
# id in `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
break
check_compilation(compiled_unet, "Unet")
return compiled_unet
except Exception as e:
sys.exit(e)
def vae(self):
vmfb = fetch_vmfb("vae", self.model_name["vae"], self.precision)
if vmfb:
return vmfb
try:
vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
self.inputs["vae"] = self.get_input_info_for(vae_input)
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")
return compiled_vae
except Exception as e:
sys.exit(e)
def controlnet(self):
vmfb = fetch_vmfb("stencil_adaptor", self.model_name["stencil_adaptor"], self.precision)
if vmfb:
return vmfb
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
compiled_stencil_adaptor = self.get_control_net()
check_compilation(compiled_stencil_adaptor, "Stencil")
return compiled_stencil_adaptor
except Exception as e:
sys.exit(e)


@@ -20,6 +20,15 @@ hf_model_variant_map = {
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
bucket_key = "gs://shark_tank/prashant_nod"
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 or args.width != 512 or args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
@@ -41,6 +50,12 @@ def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
# TODO: Get the quantize model from model_db.json
if args.use_quantize == "int8":
bk, mk, flags = get_quantize_model()
return get_shark_model(bk, mk, flags)
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"


@@ -13,3 +13,6 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpain
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import (
StencilPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_upscaler import (
UpscalerPipeline,
)


@@ -20,16 +20,15 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class Image2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -40,9 +39,30 @@ class Image2ImagePipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_image_latents(
self,
@@ -89,9 +109,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
return latents, timesteps
def encode_image(self, input_image):
self.load_vae_encode()
vae_encode_start = time.time()
latents = self.vae_encode("forward", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
if self.ondemand:
self.unload_vae_encode()
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
return latents
@@ -131,8 +154,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -161,6 +186,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -168,5 +194,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs


@@ -19,16 +19,15 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class InpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -39,9 +38,30 @@ class InpaintPipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
@@ -305,9 +325,12 @@ class InpaintPipeline(StableDiffusionPipeline):
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -383,8 +406,10 @@ class InpaintPipeline(StableDiffusionPipeline):
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -428,6 +453,7 @@ class InpaintPipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -435,6 +461,8 @@ class InpaintPipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
if inpaint_full_res:
output_image = self.apply_overlay(


@@ -20,16 +20,15 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
StableDiffusionPipeline,
)
import math
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class OutpaintPipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -40,9 +39,30 @@ class OutpaintPipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def prepare_latents(
self,
@@ -123,9 +143,12 @@ class OutpaintPipeline(StableDiffusionPipeline):
)
mask = mask.to(dtype)
self.load_vae_encode()
masked_image = masked_image.to(dtype)
masked_image_latents = self.vae_encode("forward", (masked_image,))
masked_image_latents = torch.from_numpy(masked_image_latents)
if self.ondemand:
self.unload_vae_encode()
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -384,8 +407,10 @@ class OutpaintPipeline(StableDiffusionPipeline):
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -506,6 +531,7 @@ class OutpaintPipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -513,6 +539,8 @@ class OutpaintPipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
res_img = all_imgs[0].resize(
(image_to_process.width, image_to_process.height)


@@ -20,16 +20,16 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class StencilPipeline(StableDiffusionPipeline):
def __init__(
self,
controlnet: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -39,9 +39,22 @@ class StencilPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.controlnet = controlnet
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
def load_controlnet(self):
if self.controlnet is not None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def prepare_latents(
self,
@@ -68,6 +81,113 @@ class StencilPipeline(StableDiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_stencil_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
self.load_controlnet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.ondemand:
self.unload_unet()
self.unload_controlnet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
@@ -108,8 +228,10 @@ class StencilPipeline(StableDiffusionPipeline):
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -134,11 +256,11 @@ class StencilPipeline(StableDiffusionPipeline):
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
controlnet=self.controlnet,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
@@ -146,5 +268,7 @@ class StencilPipeline(StableDiffusionPipeline):
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs


@@ -1,5 +1,4 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
@@ -19,15 +18,12 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -39,8 +35,12 @@ class Text2ImagePipeline(StableDiffusionPipeline):
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
def prepare_latents(
self,
@@ -110,8 +110,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
@@ -128,12 +130,15 @@ class Text2ImagePipeline(StableDiffusionPipeline):
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
self.load_vae()
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs


@@ -0,0 +1,319 @@
import inspect
import torch
import time
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from PIL import Image
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, Image.Image):
image = [image]
if isinstance(image[0], Image.Image):
w, h = image[0].size
w, h = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
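# Rescale pixel values from [0, 1] to [-1, 1], the range the diffusion UNet/VAE expect.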
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
class UpscalerPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
low_res_scheduler: Union[
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.low_res_scheduler = low_res_scheduler
def prepare_extra_step_kwargs(self, generator, eta):
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
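# 1/0.08333 undoes the latent scaling factor of the x4-upscaler VAE (the base SD VAE uses 0.18215 instead).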
latents = 1 / 0.08333 * (latents.float())
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
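# Note: the x4 upscaler keeps its latents at the same spatial resolution as the (noised) low-res input, so height/width are used here directly rather than divided by 8.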
latents = torch.randn(
(
batch_size,
4,
height,
width,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def produce_img_latents(
self,
latents,
image,
text_embeddings,
guidance_scale,
noise_level,
total_timesteps,
dtype,
cpu_scheduling,
extra_step_kwargs,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = torch.cat([latent_model_input, image], dim=1)
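# The noised low-res image is concatenated channel-wise with the latents as extra conditioning for the upscaler UNet.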
timestep = torch.tensor([t]).to(dtype).detach().numpy()
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if cpu_scheduling:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
else:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
noise_level,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the initial latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
)
# 4. Preprocess image
image = preprocess(image).to(dtype)
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long)
noise = torch.randn(
image.shape,
generator=generator,
).to(dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
image = torch.cat([image] * 2)
noise_level = torch.cat([noise_level] * image.shape[0])
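# Image and noise_level are duplicated to match the unconditional/text-conditional halves of the classifier-free guidance batch.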
height, width = image.shape[2:]
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
eta = 0.0
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# guidance scale as a float32 tensor.
# guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
image=image,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
noise_level=noise_level,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
extra_step_kwargs=extra_step_kwargs,
)
# Img latents -> PIL images
all_imgs = []
self.load_vae()
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs


@@ -7,6 +7,7 @@ import time
from typing import Union
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
@@ -19,7 +20,6 @@ from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
get_vae,
get_clip,
get_unet,
@@ -29,15 +29,15 @@ from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
import sys
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
class StableDiffusionPipeline:
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -49,14 +49,85 @@ class StableDiffusionPipeline:
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.vae = None
self.text_encoder = None
self.unet = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
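# With ondemand set, each sub-model (clip/unet/vae) is loaded just before use and unloaded right after, trading latency for lower peak memory.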
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
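# Building the UNet first lets SharkifyStableDiffusionModel infer the base model configuration (updating args.hf_model_id), after which get_tokenizer() can pick the right tokenizer.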
try:
self.tokenizer = get_tokenizer()
except:
self.load_unet()
self.unload_unet()
self.tokenizer = get_tokenizer()
def load_clip(self):
if self.text_encoder is not None:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
try:
self.text_encoder = get_clip()
except:
print("download pipeline failed, falling back to import_mlir")
self.text_encoder = self.sd_model.clip()
def unload_clip(self):
del self.text_encoder
self.text_encoder = None
def load_unet(self):
if self.unet is not None:
return
if self.import_mlir or self.use_lora:
self.unet = self.sd_model.unet()
else:
try:
self.unet = get_unet()
except:
print("download pipeline failed, falling back to import_mlir")
self.unet = self.sd_model.unet()
def unload_unet(self):
del self.unet
self.unet = None
def load_vae(self):
if self.vae is not None:
return
if self.import_mlir or self.use_lora:
self.vae = self.sd_model.vae()
else:
try:
self.vae = get_vae()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae = self.sd_model.vae()
def unload_vae(self):
del self.vae
self.vae = None
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -76,12 +147,13 @@ class StableDiffusionPipeline:
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
self.load_clip()
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
# self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
@@ -110,109 +182,6 @@ class StableDiffusionPipeline:
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def produce_stencil_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
controlnet=None,
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def produce_img_latents(
self,
latents,
@@ -225,10 +194,12 @@ class StableDiffusionPipeline:
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
@@ -274,6 +245,11 @@ class StableDiffusionPipeline:
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -307,101 +283,555 @@ class StableDiffusionPipeline:
width: int,
use_base_vae: bool,
use_tuned: bool,
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
if (
not import_mlir
and not use_lora
and cls.__name__ == "StencilPipeline"
):
sys.exit("StencilPipeline not supported with SharkTank currently.")
is_inpaint = cls.__name__ in [
"InpaintPipeline",
"OutpaintPipeline",
]
if import_mlir:
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
use_stencil=use_stencil,
use_lora=use_lora,
)
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
clip, unet, vae, vae_encode = mlir_import()
return cls(
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ in ["StencilPipeline"]:
clip, unet, vae, controlnet = mlir_import()
return cls(
controlnet, vae, clip, get_tokenizer(), unet, scheduler
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
try:
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
return cls(
get_vae_encode(),
get_vae(),
get_clip(),
get_tokenizer(),
get_unet(),
scheduler,
)
if cls.__name__ == "StencilPipeline":
import sys
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
sys.exit(
"StencilPipeline not supported with SharkTank currently."
)
sd_model = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
use_quantize=use_quantize,
)
if cls.__name__ in ["UpscalerPipeline"]:
return cls(
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
scheduler,
ddpm_scheduler,
sd_model,
import_mlir,
use_lora,
ondemand,
)
except:
print("download pipeline failed, falling back to import_mlir")
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
custom_vae,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
is_inpaint=is_inpaint,
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
# #####################################################
# Implements text embeddings with weights from prompts
# https://huggingface.co/AlanB/lpw_stable_diffusion_mod
# #####################################################
def encode_prompts_weight(
self,
prompt,
negative_prompt,
model_max_length,
do_classifier_free_guidance=True,
max_embeddings_multiples=1,
num_images_per_prompt=1,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `1`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
SHARK: max_embeddings_multiples>1 produces a tensor shape error (defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
"""
# SHARK: Save model_max_length, load the clip and init inference time
self.model_max_length = model_max_length
self.load_clip()
clip_inf_start = time.time()
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
if cls.__name__ in [
"Image2ImagePipeline",
"InpaintPipeline",
"OutpaintPipeline",
]:
clip, unet, vae, vae_encode = mlir_import()
return cls(
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
)
if cls.__name__ == "StencilPipeline":
clip, unet, vae, controlnet = mlir_import()
return cls(
controlnet, vae, clip, get_tokenizer(), unet, scheduler
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt
if do_classifier_free_guidance
else None,
max_embeddings_multiples=max_embeddings_multiples,
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
# self.unload_clip()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()
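# --- Editor's illustrative sketch (not part of the diff) ---
# A minimal example of how the weighted prompt encoder above might be invoked,
# assuming `pipe` is an already-constructed StableDiffusionPipeline from this
# module and the CLIP tokenizer uses its usual 77-token context window:
#
#     text_embeddings = pipe.encode_prompts_weight(
#         prompt="a (very beautiful) landscape",
#         negative_prompt="blurry, low quality",
#         model_max_length=77,
#     )
#     # `text_embeddings` is a numpy array holding the unconditional and
#     # conditional embeddings stacked along the batch dimension.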
from typing import List, Optional, Union
import re
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
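# --- Editor's illustrative sketch (not part of the diff) ---
# Bracket multipliers compose multiplicatively. For the hypothetical prompt
# "a ((cat)) on a [mat]", the parser above yields approximately:
#
#     parse_prompt_attention("a ((cat)) on a [mat]")
#     # [['a ', 1.0], ['cat', 1.21], [' on a ', 1.0], ['mat', 0.909]]
#
# i.e. 'cat' gets 1.1 * 1.1 and 'mat' gets 1 / 1.1.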
def get_prompts_with_weights(
pipe: StableDiffusionPipeline, prompt: List[str], max_length: int
):
r"""
Tokenize a list of prompts and return the tokens of each prompt along with per-token weights.
No padding, starting, or ending tokens are included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens,
weights,
max_length,
bos,
eos,
no_boseos_middle=True,
chunk_length=77,
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
max_length
if no_boseos_middle
else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
tokens[i] = (
[bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
)
if no_boseos_middle:
weights[i] = (
[1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
)
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j
* (chunk_length - 2) : min(
len(weights[i]), (j + 1) * (chunk_length - 2)
)
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
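# --- Editor's illustrative sketch (not part of the diff) ---
# Padding behaviour of the helper above for a single short prompt, assuming
# chunk_length = max_length = 77: a 3-token prompt becomes
# [bos, t0, t1, t2, eos, eos, ..., eos] (length 77) and its weights become
# [1.0, w0, w1, w2, 1.0, 1.0, ..., 1.0] (also length 77).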
def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the token sequence spans more than one chunk of the text encoder's capacity,
it is split into chunks and each chunk is sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
# text_embedding = pipe.text_encoder(text_input_chunk)[0]
# SHARK: duplicate the text_input as the Shark runner expects tokens and neg tokens
formatted_text_input_chunk = torch.cat(
[text_input_chunk, text_input_chunk]
)
text_embedding = pipe.text_encoder(
"forward", (formatted_text_input_chunk,)
)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
else:
# SHARK: duplicate the text_input as the Shark runner expects tokens and neg tokens
# Convert the result to tensor
# text_embeddings = pipe.text_encoder(text_input)[0]
formatted_text_input = torch.cat([text_input, text_input])
text_embeddings = pipe.text_encoder(
"forward", (formatted_text_input,)
)[0]
text_embeddings = torch.from_numpy(text_embeddings)[None, :]
return text_embeddings
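# --- Editor's illustrative sketch (not part of the diff) ---
# Chunking arithmetic for the function above, assuming the usual CLIP
# chunk_length of 77: a prompt padded to (77 - 2) * 2 + 2 = 152 tokens gives
# max_embeddings_multiples = (152 - 2) // (77 - 2) = 2, so two 77-token chunks
# are encoded (each chunk's first and last positions overwritten with the
# global BOS/EOS tokens) and their outputs are concatenated along the
# sequence dimension.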
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
r"""
Prompts can be assigned local weights using brackets. For example, the
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
and the embedding tokens corresponding to those words get multiplied by a constant, 1.1.
Also, to regularize the embedding, the weighted embedding is rescaled to preserve the original mean.
Args:
pipe (`StableDiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`):
The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
If the token sequence spans multiple chunks of the text encoder's capacity, whether to keep the starting and
ending tokens in each of the middle chunks.
skip_parsing (`bool`, *optional*, defaults to `False`):
Skip the parsing of brackets.
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True.
"""
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(
max_length, max([len(token) for token in uncond_tokens])
)
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(
uncond_tokens, dtype=torch.long, device="cpu"
)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(
prompt_weights, dtype=torch.float, device="cpu"
)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(
uncond_weights, dtype=torch.float, device="cpu"
)
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None
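# --- Editor's illustrative sketch (not part of the diff) ---
# The mean-preserving rescale used above, shown on a toy tensor. After the
# per-token weights are applied, the embedding is scaled back so its overall
# mean is unchanged; only the relative emphasis between tokens shifts:
#
#     emb = torch.ones(1, 4, 2)                 # hypothetical (batch, tokens, dim)
#     w = torch.tensor([[1.0, 1.2, 1.0, 0.9]])
#     prev_mean = emb.float().mean(axis=[-2, -1])
#     emb = emb * w.unsqueeze(-1)
#     curr_mean = emb.float().mean(axis=[-2, -1])
#     emb = emb * (prev_mean / curr_mean).unsqueeze(-1).unsqueeze(-1)
#     # emb.mean() is back to 1.0, but the second token stays relatively
#     # stronger than the fourth.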

View File

@@ -1,6 +1,7 @@
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
@@ -19,6 +20,10 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",

View File

@@ -24,7 +24,7 @@ from apps.stable_diffusion.src.utils.utils import (
get_available_devices,
get_opt_flags,
preprocessCKPT,
fetch_or_delete_vmfbs,
fetch_vmfb,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
@@ -32,4 +32,6 @@ from apps.stable_diffusion.src.utils.utils import (
get_extended_name,
clear_all,
save_output_img,
get_generation_text_info,
update_lora_weight,
)

View File

@@ -1,6 +1,41 @@
{
"stabilityai/stable-diffusion-2-1": {
"unet": {
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"vae_upscaler": {
"latents" : {
"shape" : [
"1*batch_size",4,"8*height","8*width"
],
"dtype":"f32"
}
}
},
"unet": {
"stabilityai/stable-diffusion-2-1": {
"latents": {
"shape": [
"1*batch_size",
@@ -29,34 +64,7 @@
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"CompVis/stable-diffusion-v1-4": {
"unet": {
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
@@ -85,11 +93,40 @@
"dtype": "f32"
}
},
"stencil_adaptor": {
"stabilityai/stable-diffusion-2-inpainting": {
"latents": {
"shape": [
"1*batch_size",
4,
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"runwayml/stable-diffusion-inpainting": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
@@ -109,12 +146,72 @@
],
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"stencil_unet": {
"stabilityai/stable-diffusion-x4-upscaler": {
"latents": {
"shape": [
"2*batch_size",
7,
"8*height",
"8*width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"noise_level": {
"shape": [2],
"dtype": "i64"
}
}
},
"stencil_adaptor": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
"stencil_unet": {
"CompVis/stable-diffusion-v1-4": {
"latents": {
"shape": [
"1*batch_size",
@@ -194,143 +291,6 @@
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"stabilityai/stable-diffusion-2-inpainting": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"runwayml/stable-diffusion-inpainting": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
9,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [
"1*batch_size",3,"8*height","8*width"
],
"dtype":"f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
}
}

View File

@@ -1,85 +1,19 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/sd_untuned",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
"stablediffusion/untuned":"gs://shark_tank/nightly"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/unet/fp32/length_64/untuned":"unet_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/vae/fp32/length_64/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"anythingv3/v1_4/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v1_4/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v1_4/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
"anythingv3/v1_4/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v1_4/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v1_4/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v1_4/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
"anythingv3/v1_4/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v1_4/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v1_4/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v1_4/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v1_4/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v1_4/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v1_4/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
"analogdiffusion/v1_4/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v1_4/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v1_4/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v1_4/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
"analogdiffusion/v1_4/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v1_4/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v1_4/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v1_4/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v1_4/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v1_4/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v1_4/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v1_4/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v1_4/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v1_4/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v1_4/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v1_4/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v1_4/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v1_4/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v1_4/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v1_4/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v1_4/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v1_4/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan"
}
]

View File

@@ -76,18 +76,19 @@ def load_winograd_configs():
return winograd_config_dir
def load_lower_configs():
def load_lower_configs(base_model_id=None):
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
if not base_model_id:
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
@@ -114,7 +115,14 @@ def load_lower_configs():
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
and args.width == 768
):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
@@ -212,7 +220,7 @@ def annotate_with_lower_configs(
return bytecode
def sd_model_annotation(mlir_model, model_name):
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
@@ -220,7 +228,7 @@ def sd_model_annotation(mlir_model, model_name):
winograd_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
)
@@ -232,7 +240,7 @@ def sd_model_annotation(mlir_model, model_name):
)
else:
use_winograd = False
lowering_config_dir = load_lower_configs()
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)

View File

@@ -22,6 +22,12 @@ p = argparse.ArgumentParser(
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
)
p.add_argument(
"-p",
"--prompts",
@@ -69,7 +75,7 @@ p.add_argument(
"--height",
type=int,
default=512,
choices=range(384, 769, 8),
choices=range(128, 769, 8),
help="the height of the output image.",
)
@@ -77,7 +83,7 @@ p.add_argument(
"--width",
type=int,
default=512,
choices=range(384, 769, 8),
choices=range(128, 769, 8),
help="the width of the output image.",
)
@@ -88,6 +94,13 @@ p.add_argument(
help="the value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="the value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
@@ -102,6 +115,31 @@ p.add_argument(
help="the strength of change applied on the given input image for img2img",
)
##############################################################################
### Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The no. of steps to train",
)
##############################################################################
### Inpainting and Outpainting Params
##############################################################################
@@ -308,6 +346,21 @@ p.add_argument(
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
)
p.add_argument(
"--ondemand",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
@@ -328,7 +381,7 @@ p.add_argument(
p.add_argument(
"--vulkan_large_heap_block_size",
default="4147483648",
default="2073741824",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
@@ -456,6 +509,12 @@ p.add_argument(
help="flag for setting server port",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for enabling rest API",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################

View File

@@ -9,6 +9,8 @@ from pathlib import Path
import numpy as np
from random import randint
import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
@@ -21,7 +23,7 @@ from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
download_from_original_stable_diffusion_ckpt,
)
@@ -78,7 +80,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
mlir_model, device=args.device, mlir_dialect="tm_tensor"
)
return _compile_module(shark_module, model_name, extra_args)
@@ -95,6 +97,7 @@ def compile_through_fx(
debug=False,
generate_vmfb=True,
extra_args=[],
base_model_id=None,
):
from shark.parser import shark_args
@@ -116,19 +119,21 @@ def compile_through_fx(
if use_tuned:
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
mlir_module = sd_model_annotation(mlir_module, model_name)
mlir_module = sd_model_annotation(
mlir_module, model_name, base_model_id
)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
mlir_dialect="tm_tensor",
)
if generate_vmfb:
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
mlir_dialect="tm_tensor",
)
del mlir_module
gc.collect()
@@ -264,8 +269,9 @@ def set_init_device_flags():
if (
args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
@@ -299,6 +305,20 @@ def set_init_device_flags():
]:
args.use_tuned = False
elif (
args.height == 768
and args.width == 768
and (
base_model_id
not in [
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
]
or "rdna3" not in args.iree_vulkan_target_triple
)
):
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
else:
@@ -368,7 +388,7 @@ def get_available_devices():
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
available_devices.append("device => cpu")
return available_devices
@@ -454,7 +474,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
num_in_channels = 9 if is_inpaint else 4
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
@@ -464,6 +484,115 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
print("Loading complete")
def processLoRA(model, use_lora, splitting_prefix):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
state_dict = torch.load(use_lora)
alpha = 0.75
visited = []
# directly update weight in model
process_unet = "te" not in splitting_prefix
for key in state_dict:
if ".alpha" in key or key in visited:
continue
curr_layer = model
if ("text" not in key and process_unet) or (
"text" in key and not process_unet
):
layer_infos = (
key.split(".")[0].split(splitting_prefix)[-1].split("_")
)
else:
continue
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = (
state_dict[pair_keys[0]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
weight_down = (
state_dict[pair_keys[1]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
curr_layer.weight.data += alpha * torch.mm(
weight_up, weight_down
).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
return model
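# --- Editor's illustrative sketch (not part of the diff) ---
# The core of the merge above is folding each LoRA pair directly into the base
# weight: W += alpha * (up @ down). A toy example with hypothetical shapes:
#
#     import torch
#     base = torch.zeros(320, 320)          # stands in for curr_layer.weight.data
#     up = torch.randn(320, 4)              # lora_up, rank 4
#     down = torch.randn(4, 320)            # lora_down, rank 4
#     base += 0.75 * torch.mm(up, down)     # alpha = 0.75, as hard-coded above
#
# For conv layers the 1x1 spatial dims are squeezed out before the matmul and
# unsqueezed back afterwards, matching the 4-D branch in processLoRA.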
def update_lora_weight_for_unet(unet, use_lora):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume it is a HF ID pointing to standalone LoRA weights.
unet.load_attn_procs(use_lora)
return unet
main_file_name = get_path_stem(use_lora)
if ".bin" in use_lora:
main_file_name += ".bin"
elif ".safetensors" in use_lora:
main_file_name += ".safetensors"
elif ".pt" in use_lora:
main_file_name += ".pt"
else:
sys.exit("Only .bin and .safetensors format for LoRA is supported")
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_")
def update_lora_weight(model, use_lora, model_name):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora)
try:
return processLoRA(model, use_lora, "lora_te_")
except:
return None
def load_vmfb(vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
@@ -474,34 +603,14 @@ def load_vmfb(vmfb_path, model, precision):
return shark_module
# This utility returns vmfbs of Clip, Unet, Vae and Vae_encode, in case all of them
# are present; deletes them otherwise.
def fetch_or_delete_vmfbs(extended_model_name, precision="fp32"):
vmfb_path = [
get_vmfb_path_name(extended_model_name[model])
for model in extended_model_name
]
number_of_vmfbs = len(vmfb_path)
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = True
compiled_models = [None] * number_of_vmfbs
for i in range(number_of_vmfbs):
all_vmfb_present = all_vmfb_present and vmfb_present[i]
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(number_of_vmfbs):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
else:
model_name = [model for model in extended_model_name.keys()]
for i in range(number_of_vmfbs):
compiled_models[i] = load_vmfb(
vmfb_path[i], model_name[i], precision
)
return compiled_models
# This utility returns the vmfb of a sub-model of the SD pipeline, if present.
def fetch_vmfb(model, extended_model_name, precision="fp32"):
vmfb_path = get_vmfb_path_name(extended_model_name)
vmfb_present = os.path.isfile(vmfb_path)
compiled_model = (
load_vmfb(vmfb_path, model, precision) if vmfb_present else None
)
return compiled_model
# `fetch_and_update_base_model_id` is a resource utility function which
@@ -629,3 +738,14 @@ def save_output_img(output_img, img_seed, extra_info={}):
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
return text_output

View File

@@ -1,143 +1,215 @@
import os
import sys
import transformers
from apps.stable_diffusion.src import args, clear_all
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
import gradio as gr
from apps.stable_diffusion.src import args, clear_all
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
)
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create the custom model folder if it doesn't already exist
get_custom_model_path().mkdir(parents=True, exist_ok=True)
if args.clear_all:
clear_all()
if __name__ == "__main__":
if args.api:
from apps.stable_diffusion.web.ui import txt2img_inf
from fastapi import FastAPI, APIRouter
import uvicorn
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
app = FastAPI()
app.add_api_route("/sdapi/txt2img", txt2img_inf, methods=["post"])
app.include_router(APIRouter())
uvicorn.run(app, host="0.0.0.0", port=args.server_port)
sys.exit(0)
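# --- Editor's illustrative sketch (not part of the diff) ---
# A hedged example of how the REST endpoint registered above might be called
# from a client. The payload schema depends on the signature of txt2img_inf,
# so the "prompt" field and the port value here are assumptions:
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8080/sdapi/txt2img",   # port = whatever --server_port was set to
#         json={"prompt": "a photo of an astronaut riding a horse"},
#     )
#     print(resp.status_code)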
import gradio as gr
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
)
return os.path.join(base_path, relative_path)
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create the custom model folder if it doesn't already exist
dir = ["models", "vae", "lora"]
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
img2img_web,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
inpaint_web,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
outpaint_web,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
img2img_web.render()
with gr.TabItem(label="Inpainting", id=2):
inpaint_web.render()
with gr.TabItem(label="Outpainting", id=3):
outpaint_web.render()
register_button_click(
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_gallery,
txt2img_sendto_img2img,
1,
[txt2img_gallery],
[img2img_init_image, tabs],
)
register_button_click(
txt2img_sendto_inpaint,
2,
[txt2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_outpaint,
3,
[txt2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_upscaler,
img2img_web,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
2,
[img2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_outpaint,
3,
[img2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_upscaler,
inpaint_web,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
1,
[inpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
inpaint_sendto_outpaint,
3,
[inpaint_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
inpaint_sendto_upscaler,
outpaint_web,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
1,
[outpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
outpaint_sendto_inpaint,
2,
[outpaint_gallery],
[inpaint_init_image, tabs],
outpaint_sendto_upscaler,
upscaler_web,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
)
# init global sd pipeline and config
global_obj._init()
sd_web.queue()
sd_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
img2img_web.render()
with gr.TabItem(label="Inpainting", id=2):
inpaint_web.render()
with gr.TabItem(label="Outpainting", id=3):
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.Tabs(visible=False) as experimental_tabs:
with gr.TabItem(label="LoRA Training", id=5):
lora_train_web.render()
register_button_click(
txt2img_sendto_img2img,
1,
[txt2img_gallery],
[img2img_init_image, tabs],
)
register_button_click(
txt2img_sendto_inpaint,
2,
[txt2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_outpaint,
3,
[txt2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
txt2img_sendto_upscaler,
4,
[txt2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
img2img_sendto_inpaint,
2,
[img2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_outpaint,
3,
[img2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
img2img_sendto_upscaler,
4,
[img2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
inpaint_sendto_img2img,
1,
[inpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
inpaint_sendto_outpaint,
3,
[inpaint_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
inpaint_sendto_upscaler,
4,
[inpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
outpaint_sendto_img2img,
1,
[outpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
outpaint_sendto_inpaint,
2,
[outpaint_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
outpaint_sendto_upscaler,
4,
[outpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
upscaler_sendto_img2img,
1,
[upscaler_gallery],
[img2img_init_image, tabs],
)
register_button_click(
upscaler_sendto_inpaint,
2,
[upscaler_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
upscaler_sendto_outpaint,
3,
[upscaler_gallery],
[outpaint_init_image, tabs],
)
sd_web.queue()
sd_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

View File

@@ -1,9 +1,11 @@
from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_inf,
txt2img_web,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_web,
@@ -11,6 +13,7 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_init_image,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_web,
@@ -18,6 +21,7 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_init_image,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_web,
@@ -25,4 +29,14 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_init_image,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_web,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web

View File

@@ -7,93 +7,100 @@ Procedure to upgrade the dark theme:
*/
:root {
--color-accent-soft: var(--neutral-700);
--color-background-primary: var(--neutral-950);
--color-background-secondary: var(--neutral-900);
--color-border-accent: var(--neutral-600);
--color-border-primary: var(--neutral-700);
--text-color-code-background: var(--neutral-800);
--text-color-link-active: var(--secondary-500);
--text-color-link: var(--secondary-500);
--text-color-link-hover: var(--secondary-400);
--text-color-link-visited: var(--secondary-600);
--text-color-subdued: var(--neutral-400);
--body-background-color: var(--color-background-primary);
--body-background-fill: var(--background-fill-primary);
--body-text-color: var(--neutral-100);
--color-accent-soft: var(--neutral-700);
--background-fill-primary: var(--neutral-950);
--background-fill-secondary: var(--neutral-900);
--border-color-accent: var(--neutral-600);
--border-color-primary: var(--neutral-700);
--link-text-color-active: var(--secondary-500);
--link-text-color: var(--secondary-500);
--link-text-color-hover: var(--secondary-400);
--link-text-color-visited: var(--secondary-600);
--body-text-color-subdued: var(--neutral-400);
--shadow-spread: 1px;
--block-background: var(--neutral-800);
--block-border-color: var(--color-border-primary);
--block-border-width: 1px;
--block-info-color: var(--text-color-subdued);
--block-label-background: var(--color-background-secondary);
--block-label-border-color: var(--color-border-primary);
--block-label-border-width: 1px;
--block-label-color: var(--neutral-200);
--block-shadow: none;
--block-title-background: none;
--block-title-border-color: none;
--block-title-border-width: 0px;
--block-title-color: var(--neutral-200);
--panel-background: var(--color-background-secondary);
--panel-border-color: var(--color-border-primary);
--checkbox-background: var(--neutral-800);
--checkbox-background-focus: var(--checkbox-background);
--checkbox-background-hover: var(--checkbox-background);
--checkbox-background-selected: var(--secondary-600);
--block-background-fill: var(--neutral-800);
--block-border-color: var(--border-color-primary);
--block_border_width: None;
--block-info-text-color: var(--body-text-color-subdued);
--block-label-background-fill: var(--background-fill-secondary);
--block-label-border-color: var(--border-color-primary);
--block_label_border_width: None;
--block-label-text-color: var(--neutral-200);
--block_shadow: None;
--block_title_background_fill: None;
--block_title_border_color: None;
--block_title_border_width: None;
--block-title-text-color: var(--neutral-200);
--panel-background-fill: var(--background-fill-secondary);
--panel-border-color: var(--border-color-primary);
--panel_border_width: None;
--checkbox-background-color: var(--neutral-800);
--checkbox-background-color-focus: var(--checkbox-background-color);
--checkbox-background-color-hover: var(--checkbox-background-color);
--checkbox-background-color-selected: var(--secondary-600);
--checkbox-border-color: var(--neutral-700);
--checkbox-border-color-focus: var(--secondary-500);
--checkbox-border-color-hover: var(--neutral-600);
--checkbox-border-color-selected: var(--secondary-600);
--checkbox-label-background: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-selected: var(--checkbox-label-background);
--checkbox-label-border-color: var(--color-border-primary);
--checkbox-label-border-color-hover: var(--color-border-primary);
--checkbox-text-color: var(--body-text-color);
--checkbox-text-color-selected: var(--checkbox-text-color);
--error-background: var(--color-background-primary);
--error-border-color: var(--color-border-primary);
--error-border-width: var(--error-border-width);
--error-color: #ef4444;
--input-background: var(--neutral-800);
--input-background-focus: var(--secondary-600);
--input-background-hover: var(--input-background);
--input-border-color: var(--color-border-primary);
--checkbox-border-width: var(--input-border-width);
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
--checkbox-label-border-color: var(--border-color-primary);
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
--checkbox-label-border-width: var(--input-border-width);
--checkbox-label-text-color: var(--body-text-color);
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
--error-background-fill: var(--background-fill-primary);
--error-border-color: var(--border-color-primary);
--error_border_width: None;
--error-text-color: #ef4444;
--input-background-fill: var(--neutral-800);
--input-background-fill-focus: var(--secondary-600);
--input-background-fill-hover: var(--input-background-fill);
--input-border-color: var(--border-color-primary);
--input-border-color-focus: var(--neutral-700);
--input-border-color-hover: var(--color-border-primary);
--input-border-color-hover: var(--input-border-color);
--input_border_width: None;
--input-placeholder-color: var(--neutral-500);
--input-shadow: var(--input-shadow);
--input_shadow: None;
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
--loader-color: var(--color-accent);
--stat-color-background: linear-gradient(to right, var(--primary-400), var(--primary-600));
--loader_color: None;
--slider_color: None;
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
--table-border-color: var(--neutral-700);
--table-even-background: var(--neutral-950);
--table-odd-background: var(--neutral-900);
--table-even-background-fill: var(--neutral-950);
--table-odd-background-fill: var(--neutral-900);
--table-row-focus: var(--color-accent-soft);
--button-cancel-background: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-border-width: var(--input-border-width);
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-cancel-border-color: #dc2626;
--button-cancel-border-color-hover: var(--button-cancel-border-color);
--button-cancel-text-color: white;
--button-cancel-text-color-hover: var(--button-cancel-text-color);
--button-primary-background: linear-gradient(to bottom right, var(--primary-600), var(--primary-700));
--button-primary-background-hover: linear-gradient(to bottom right, var(--primary-600), var(--primary-600));
--button-primary-border-color: var(--primary-600);
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
--button-primary-border-color: var(--primary-500);
--button-primary-border-color-hover: var(--button-primary-border-color);
--button-primary-text-color: white;
--button-primary-text-color-hover: var(--button-primary-text-color);
--button-secondary-background: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-border-color: var(--neutral-600);
--button-secondary-border-color-hover: var(--button-secondary-border-color);
--button-secondary-text-color: white;
--button-secondary-text-color-hover: var(--button-secondary-text-color);
--block-border-width: 1px;
--block-label-border-width: 1px;
--form-gap-width: 1px;
--error-border-width: 1px;
--input-border-width: 1px;
}
/* SHARK theme */
body {
background-color: var(--color-background-primary);
}
/* display in full width for desktop devices */
@media (min-width: 1536px)
@@ -131,7 +138,7 @@ body {
}
#prompt_box textarea, #negative_prompt_box textarea {
background-color: var(--color-background-primary) !important;
background-color: var(--background-fill-primary) !important;
}
#prompt_examples {
@@ -143,7 +150,6 @@ body {
}
#ui_body {
background-color: var(--color-background-secondary) !important;
padding: var(--size-2) !important;
border-radius: 0.5em !important;
}
@@ -172,6 +178,7 @@ footer {
/* Hide "remove buttons" from ui dropdowns */
#custom_model .token-remove.remove-all,
#lora_weights .token-remove.remove-all,
#scheduler .token-remove.remove-all,
#device .token-remove.remove-all,
#stencil_model .token-remove.remove-all {

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list,
predefined_models,
cancel_sd,
)
@@ -73,6 +74,21 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
value="None",
choices=["None", "canny", "openpose", "scribble"],
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -128,22 +144,29 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
step=0.01,
label="Denoising Strength",
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -153,6 +176,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -174,8 +198,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -201,6 +223,9 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
img2img_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
img2img_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=img2img_inf,
@@ -225,6 +250,9 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
use_stencil,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[img2img_gallery, std_output],
show_progress=args.progress_bar,
@@ -233,6 +261,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
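The same "Stop Batch" wiring is repeated across the img2img, inpaint, outpaint, LoRA-training, txt2img and upscaler tabs: the generate events are kept as handles and passed to the stop button's `cancels` list, while the stop handler asks the pipeline to stop. A minimal, self-contained sketch of the pattern (assuming gradio 3.x, which supports the `cancels=` argument on event listeners); component names and the flag-based stand-in for `cancel_sd` are illustrative only.

```python
# Minimal sketch of the "Stop Batch" wiring (assumes gradio 3.x, which supports
# the cancels= argument on event listeners). Names are illustrative.
import gradio as gr

STATE = {"cancel": False}  # stand-in for the pipeline's global status object

def fake_generate(prompt):
    # Stand-in for the long-running txt2img/img2img generation call.
    STATE["cancel"] = False
    return f"generated for: {prompt}"

def request_cancel():
    # Stand-in for cancel_sd(): only flips a flag; the real pipeline checks its
    # status between batch items and stops early.
    STATE["cancel"] = True

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Output")
    generate = gr.Button("Generate Image(s)")
    stop_batch = gr.Button("Stop Batch")

    # Keep the event handles so they can be cancelled later.
    prompt_submit = prompt.submit(fn=fake_generate, inputs=prompt, outputs=out)
    generate_click = generate.click(fn=fake_generate, inputs=prompt, outputs=out)

    # Cancel queued/running events *and* ask the pipeline to stop, mirroring
    # stop_batch.click(fn=cancel_sd, cancels=[...]) above.
    stop_batch.click(fn=request_cancel, cancels=[prompt_submit, generate_click])

if __name__ == "__main__":
    demo.launch()
```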

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list,
predefined_paint_models,
cancel_sd,
)
@@ -68,6 +69,21 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
type="pil",
).style(height=350)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -130,22 +146,29 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -155,6 +178,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -176,8 +200,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -203,6 +225,9 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
inpaint_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
inpaint_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=inpaint_inf,
@@ -227,6 +252,9 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[inpaint_gallery, std_output],
show_progress=args.progress_bar,
@@ -235,6 +263,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -0,0 +1,205 @@
from pathlib import Path
import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import lora_train
from apps.stable_diffusion.src import prompt_examples, args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list_txt2img,
predefined_models,
)
with gr.Blocks(title="Lora Training") as lora_train_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Group(elem_id="image_dir_box_outer"):
training_images_dir = gr.Textbox(
label="ImageDirectory",
value=args.training_images_dir,
lines=1,
elem_id="prompt_box",
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=scheduler_list_txt2img,
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1,
2000,
value=args.training_steps,
step=1,
label="Training Steps",
)
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
train_lora = gr.Button("Train LoRA")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
std_output = gr.Textbox(
value="Nothing to show.",
lines=1,
show_label=False,
)
lora_save_dir = (
args.lora_save_dir if args.lora_save_dir else Path.cwd()
)
lora_save_dir = Path(lora_save_dir, "lora")
output_loc = gr.Textbox(
label="Saving Lora at",
value=lora_save_dir,
)
kwargs = dict(
fn=lora_train,
inputs=[
prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
training_images_dir,
output_loc,
],
outputs=[std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
train_click = train_lora.click(**kwargs)
stop_batch.click(fn=None, cancels=[prompt_submit, train_click])

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list,
predefined_paint_models,
cancel_sd,
)
@@ -65,6 +66,21 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -149,22 +165,29 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
steps = gr.Slider(
1, 100, value=20, step=1, label="Steps"
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -174,6 +197,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -195,8 +219,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -220,6 +242,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row():
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
outpaint_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=outpaint_inf,
@@ -247,6 +272,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[outpaint_gallery, std_output],
show_progress=args.progress_bar,
@@ -255,6 +283,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -1,9 +1,9 @@
from pathlib import Path
import os
import torch
import time
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import txt2img_inf
from apps.stable_diffusion.src import prompt_examples, args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -11,7 +11,187 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list_txt2img,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
prompt_examples,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img",
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
return generated_imgs, text_output
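The rebuilt `txt2img_inf` keys the compiled pipeline on a `Config` value object and only rebuilds when some field differs from the cached one. A reduced sketch of that caching logic, with a hypothetical `load_pipeline` standing in for `Text2ImagePipeline.from_pretrained`; the dataclass fields mirror the `Config` added to web/ui/utils.py.

```python
# Reduced sketch of the config-keyed pipeline cache used above. load_pipeline()
# is a hypothetical stand-in for the real compile/load step.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Config:
    mode: str
    model_id: str
    ckpt_loc: str
    precision: str
    batch_size: int
    max_length: int
    height: int
    width: int
    device: str
    use_lora: str
    use_stencil: Optional[str]

_pipeline = None
_cached_cfg: Optional[Config] = None

def load_pipeline(cfg: Config):
    print(f"(re)compiling pipeline for {cfg.model_id or cfg.ckpt_loc} on {cfg.device}")
    return object()  # placeholder for a compiled SD pipeline

def get_pipeline(cfg: Config):
    """Reuse the cached pipeline unless any Config field changed (dataclass __eq__)."""
    global _pipeline, _cached_cfg
    if _pipeline is None or _cached_cfg != cfg:
        _pipeline = load_pipeline(cfg)
        _cached_cfg = cfg
    return _pipeline

cfg = Config("txt2img", "stabilityai/stable-diffusion-2-1-base", "", "fp16",
             1, 64, 512, 512, "vulkan", "", None)
get_pipeline(cfg)  # compiles once
get_pipeline(cfg)  # cache hit: an equal Config reuses the existing pipeline
```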
with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row(elem_id="ui_title"):
@@ -69,15 +249,13 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
lines=1,
elem_id="negative_prompt_box",
)
with gr.Accordion(
label="Lora based inference option", open=False
):
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -107,10 +285,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
with gr.Row():
height = gr.Slider(
384, 768, value=args.height, step=8, label="Height"
384,
768,
value=args.height,
step=8,
label="Height",
)
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
384,
768,
value=args.width,
step=8,
label="Width",
)
precision = gr.Radio(
label="Precision",
@@ -141,23 +327,31 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -179,8 +373,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
@@ -215,6 +407,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
txt2img_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
txt2img_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
kwargs = dict(
fn=txt2img_inf,
@@ -238,6 +433,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[txt2img_gallery, std_output],
show_progress=args.progress_bar,
@@ -246,8 +442,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
from apps.stable_diffusion.web.utils.png_metadata import (

View File

@@ -0,0 +1,262 @@
from pathlib import Path
import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import upscaler_inf
from apps.stable_diffusion.src import args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_upscaler_models,
)
with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
choices=["None"]
+ get_custom_model_files()
+ predefined_upscaler_models,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
elem_id="negative_prompt_box",
)
upscaler_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="DDIM",
choices=scheduler_list,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
128,
512,
value=args.height,
step=128,
label="Height",
)
width = gr.Slider(
128,
512,
value=args.width,
step=128,
label="Width",
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
noise_level = gr.Slider(
0,
100,
value=args.noise_level,
step=1,
label="Noise Level",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
upscaler_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2])
std_output = gr.Textbox(
value="Nothing to show.",
lines=1,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
with gr.Row():
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
upscaler_sendto_outpaint = gr.Button(
value="SendTo Outpaint"
)
kwargs = dict(
fn=upscaler_inf,
inputs=[
prompt,
negative_prompt,
upscaler_init_image,
height,
width,
steps,
noise_level,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
],
outputs=[upscaler_gallery, std_output],
show_progress=args.progress_bar,
)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
)

View File

@@ -4,6 +4,27 @@ from apps.stable_diffusion.src import get_available_devices
import glob
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
@dataclass
class Config:
mode: str
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
use_lora: str
use_stencil: str
custom_model_filetypes = (
"*.ckpt",
@@ -35,10 +56,14 @@ predefined_models = [
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]
predefined_paint_models = [
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
]
predefined_upscaler_models = [
"stabilityai/stable-diffusion-x4-upscaler",
]
def resource_path(relative_path):
@@ -49,24 +74,63 @@ def resource_path(relative_path):
return os.path.join(base_path, relative_path)
def get_custom_model_path():
return Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd(), "models")
def get_custom_model_path(model="models"):
# If `--ckpt_dir` is provided, it overrides the hierarchical folder
# structure used by the WebUI:
# model
# |___lora
# |___vae
if args.ckpt_dir:
return Path(args.ckpt_dir)
match model:
case "models":
return Path(Path.cwd(), "models")
case "vae":
return Path(Path.cwd(), "models/vae")
case "lora":
return Path(Path.cwd(), "models/lora")
case _:
return ""
def get_custom_model_pathfile(custom_model_name):
return os.path.join(get_custom_model_path(), custom_model_name)
def get_custom_model_pathfile(custom_model_name, model="models"):
return os.path.join(get_custom_model_path(model), custom_model_name)
def get_custom_model_files():
def get_custom_model_files(model="models"):
ckpt_files = []
for extn in custom_model_filetypes:
file_types = custom_model_filetypes
if model == "lora":
file_types = custom_model_filetypes + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_custom_model_path(), extn))
for x in glob.glob(
os.path.join(get_custom_model_path(model), extn)
)
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_custom_vae_or_lora_weights(weights, hf_id, model):
use_weight = ""
if weights == "None" and not hf_id:
use_weight = ""
elif not hf_id:
use_weight = get_custom_model_pathfile(weights, model)
else:
use_weight = hf_id
return use_weight
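A small executable summary of the precedence above: a non-empty free-text HuggingFace ID wins over the dropdown, and "None" with an empty ID means no LoRA/VAE at all. The file name below is hypothetical.

```python
# Executable summary of get_custom_vae_or_lora_weights(); "my_style.safetensors"
# is a hypothetical file in the models/lora folder.
def pick_lora(dropdown, hf_id, lora_dir="models/lora"):
    if dropdown == "None" and not hf_id:
        return ""                          # no LoRA/VAE applied
    if not hf_id:
        return f"{lora_dir}/{dropdown}"    # local file from the dropdown
    return hf_id                           # non-empty HF ID always wins

assert pick_lora("None", "") == ""
assert pick_lora("my_style.safetensors", "") == "models/lora/my_style.safetensors"
assert pick_lora("None", "sayakpaul/sd-model-finetuned-lora-t4") == (
    "sayakpaul/sd-model-finetuned-lora-t4"
)
```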
def cancel_sd():
# Wrapped in try/except because gc can delete global_obj's sd_obj while switching models
try:
global_obj.set_sd_status(SD_STATE_CANCEL)
except Exception:
pass
nodlogo_loc = resource_path("logos/nod-logo.png")
available_devices = get_available_devices()

View File

@@ -0,0 +1,71 @@
import gc
"""
The global objects include the SD pipeline, its config, and the schedulers.
Keeping them module-global avoids creating extra pipeline objects when switching modes,
and clearing the cache avoids memory leaks when switching models.
"""
def _init():
global _sd_obj
global _config_obj
global _schedulers
_sd_obj = None
_config_obj = None
_schedulers = None
def set_sd_obj(value):
global _sd_obj
_sd_obj = value
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
global _sd_obj
_sd_obj.status = value
def set_cfg_obj(value):
global _config_obj
_config_obj = value
def set_schedulers(value):
global _schedulers
_schedulers = value
def get_sd_obj():
return _sd_obj
def get_sd_status():
return _sd_obj.status
def get_cfg_obj():
return _config_obj
def get_scheduler(key):
return _schedulers[key]
def clear_cache():
global _sd_obj
global _config_obj
global _schedulers
del _sd_obj
del _config_obj
del _schedulers
gc.collect()
_sd_obj = None
_config_obj = None
_schedulers = None
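A minimal self-contained sketch of how this global-object pattern is used: module-level singletons initialised once, shared by every UI tab, and dropped with `gc.collect()` when the configuration changes. The string "cancel" stands in for the real `SD_STATE_CANCEL` constant, whose value is not shown in the diff.

```python
# Minimal sketch of the global-object pattern above; "cancel" is a stand-in
# for SD_STATE_CANCEL and DummyPipeline for the SHARK SD pipeline object.
import gc

_sd_obj = None
_schedulers = None

def init():
    global _sd_obj, _schedulers
    _sd_obj, _schedulers = None, None

def set_sd_obj(value):
    global _sd_obj
    _sd_obj = value

def set_schedulers(value):
    global _schedulers
    _schedulers = value

def set_sd_status(value):
    _sd_obj.status = value

def clear_cache():
    global _sd_obj, _schedulers
    del _sd_obj, _schedulers
    gc.collect()
    _sd_obj, _schedulers = None, None

class DummyPipeline:  # stand-in for the SHARK SD pipeline object
    status = "running"

init()
set_sd_obj(DummyPipeline())
set_schedulers({"DDIM": object()})
set_sd_status("cancel")        # what cancel_sd() does, via SD_STATE_CANCEL
print(_sd_obj.status)          # -> "cancel"
clear_cache()                  # free the pipeline before building a new one
```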

View File

@@ -2,4 +2,4 @@
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python generate_sharktank.py
python tank/generate_sharktank.py

View File

@@ -87,11 +87,22 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]
counter = 0
for import_opt in import_options:
for model_name in hf_model_names:
if model_name in to_skip:
continue
for use_tune in tuned_options:
if (
model_name == "stabilityai/stable-diffusion-2-1"
and use_tune == tuned_options[0]
):
continue
elif (
model_name == "stabilityai/stable-diffusion-2-1-base"
and use_tune == tuned_options[1]
):
continue
command = (
[
executable, # executable is the python from the venv used to run this
@@ -174,9 +185,21 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
else:
print(command)
print("failed to generate image for this configuration")
if "2_1_base" in model_name:
print("failed a known successful model.")
exit(1)
with open(dumpfile_name, "r+") as f:
output = f.readlines()
print("\n".join(output))
exit(1)
if os.name == "nt":
counter += 1
if counter % 2 == 0:
extra_flags.append(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
else:
if counter != 1:
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)

View File

@@ -2,9 +2,11 @@ def pytest_addoption(parser):
# Attaches SHARK command-line arguments to the pytest machinery.
parser.addoption(
"--benchmark",
action="store_true",
default="False",
help="Pass option to benchmark and write results.csv",
action="store",
type=str,
default=None,
choices=("baseline", "native", "all"),
help="Benchmarks specified engine(s) and writes bench_results.csv.",
)
parser.addoption(
"--onnx_bench",
@@ -40,7 +42,13 @@ def pytest_addoption(parser):
"--update_tank",
action="store_true",
default="False",
help="Update local shark tank with latest artifacts.",
help="Update local shark tank with latest artifacts if model artifact hash mismatched.",
)
parser.addoption(
"--force_update_tank",
action="store_true",
default="False",
help="Force-update local shark tank with artifacts from specified shark_tank URL (defaults to nightly).",
)
parser.addoption(
"--ci_sha",
@@ -51,15 +59,21 @@ def pytest_addoption(parser):
parser.addoption(
"--local_tank_cache",
action="store",
default="",
default=None,
help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
)
parser.addoption(
"--tank_url",
type=str,
default="gs://shark_tank/latest",
default="gs://shark_tank/nightly",
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
)
parser.addoption(
"--tank_prefix",
type=str,
default=None,
help="Prefix to gs://shark_tank/ model directories from which to download SHARK tank artifacts. Default is nightly.",
)
parser.addoption(
"--benchmark_dispatches",
default=None,
@@ -70,3 +84,9 @@ def pytest_addoption(parser):
default="./temp_dispatch_benchmarks",
help="Directory in which dispatch benchmarks are saved.",
)
parser.addoption(
"--batchsize",
default=1,
type=int,
help="Batch size for the tested model.",
)
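These options are consumed through pytest's config object, e.g. when running `pytest --benchmark=native --batchsize=4`. A minimal sketch of a conftest fixture that surfaces them to tests; fixture and test names are illustrative.

```python
# Sketch: surfacing the new CLI options to tests as fixtures (names illustrative).
import pytest

@pytest.fixture(scope="session")
def benchmark_mode(request):
    # None when --benchmark is not passed; otherwise "baseline", "native" or "all".
    return request.config.getoption("--benchmark")

@pytest.fixture(scope="session")
def batch_size(request):
    return request.config.getoption("--batchsize")  # int, defaults to 1

def test_example(benchmark_mode, batch_size):
    if benchmark_mode is None:
        pytest.skip("pass --benchmark {baseline,native,all} to run benchmarks")
    assert batch_size >= 1
```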

View File

@@ -6,36 +6,16 @@ from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Diffusers 0.13.1 fails with transformers __init__.py errors in BLIP, so remove it for now until we fork it
pix2pix_init = Path(get_python_lib() + "/diffusers/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(get_python_lib() + "/diffusers/pipelines/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(
get_python_lib() + "/diffusers/pipelines/stable_diffusion/__init__.py"
# Temporary workaround for transformers/__init__.py.
path_to_tranformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
for line in fileinput.input(pix2pix_init, inplace=True):
if "StableDiffusionPix2PixZeroPipeline" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
if path_to_tranformers_hook.is_file():
pass
else:
with open(path_to_tranformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")

View File

@@ -1,3 +1,3 @@
[pytest]
addopts = --verbose -p no:warnings
addopts = --verbose -s -p no:warnings
norecursedirs = inference tank/tflite examples benchmarks shark

View File

@@ -2,8 +2,8 @@
--pre
numpy>1.22.4
torchvision
pytorch-triton
torchvision==0.16.0.dev20230322
tabulate
tqdm
@@ -15,8 +15,8 @@ iree-tools-tf
# TensorFlow and JAX.
gin-config
tf-nightly
keras>=2.10
tensorflow>2.11
keras
#tf-models-nightly
#tensorflow-text-nightly
transformers
@@ -33,6 +33,7 @@ lit
pyyaml
python-dateutil
sacremoses
sentencepiece
# web dependencies.
gradio

View File

@@ -16,7 +16,7 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
scipy
ftfy
gradio
@@ -25,6 +25,7 @@ omegaconf
safetensors
opencv-python
scikit-image
pytorch_lightning # for runwayml models
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile

View File

@@ -45,7 +45,7 @@ if ($arguments -eq "--force"){
Remove-Item .\shark.venv -Force -Recurse
if (Test-Path .\shark.venv\) {
Write-Host 'could not remove .\shark.venv - please try running ".\setup_venv.ps1 --force" again!'
break
exit 1
}
}
}
@@ -78,12 +78,12 @@ if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is u
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
{
Write-Host "Please install Python 3.11 and try again"
break
exit 34
}
Write-Host "Installing Build Dependencies"
# make sure we really use 3.11 from list, even if it's not the default.
if (!($PyVer.length -ne 0)) {py -3.11 -m venv .\shark.venv\}
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
python -m pip install --upgrade pip

View File

@@ -129,11 +129,11 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu117."
echo "Successfully Installed torch + cu118."
else
echo "Could not install torch + cu117." >&2
echo "Could not install torch + cu118." >&2
fi
fi

View File

@@ -35,8 +35,9 @@ def run_cmd(cmd, debug=False):
stderr=subprocess.PIPE,
check=True,
)
result_str = result.stdout.decode()
return result_str
stdout = result.stdout.decode()
stderr = result.stderr.decode()
return stdout, stderr
except subprocess.CalledProcessError as e:
print(e.output)
sys.exit(f"Exiting program due to error running {cmd}")
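Because `run_cmd` now returns a `(stdout, stderr)` pair, every caller must unpack two values. A stand-alone sketch of the new contract; the subprocess options used here are assumed to match the original helper.

```python
# Stand-alone sketch of the new run_cmd contract: callers unpack (stdout, stderr).
import subprocess
import sys

def run_cmd(cmd, debug=False):
    """Run `cmd` in a shell and return its decoded (stdout, stderr)."""
    if debug:
        print("Running command:", cmd)
    try:
        result = subprocess.run(
            cmd,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
        return result.stdout.decode(), result.stderr.decode()
    except subprocess.CalledProcessError as e:
        print(e.output)
        sys.exit(f"Exiting program due to error running {cmd}")

stdout, stderr = run_cmd("echo hello")  # note the two-value unpacking
print(stdout.strip())                   # -> hello
```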

View File

@@ -90,6 +90,7 @@ def build_benchmark_args(
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
# if time_extractor:
# benchmark_cl.append(time_extractor)
benchmark_cl.append(f"--print_statistics=true")
return benchmark_cl
@@ -129,7 +130,8 @@ def build_benchmark_args_non_tensor_input(
def run_benchmark_module(benchmark_cl):
"""
Run benchmark command, extract result and return iteration/seconds.
Run benchmark command, extract result and return iteration/seconds, host
peak memory, and device peak memory.
# TODO: Add an example of the benchmark command.
Input: benchmark command.
@@ -138,15 +140,22 @@ def run_benchmark_module(benchmark_cl):
assert os.path.exists(
benchmark_path
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
bench_result = run_cmd(" ".join(benchmark_cl))
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
try:
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
match = regex_split.search(bench_stdout)
time_ms = float(match.group(1))
unit = match.group(3)
except AttributeError:
regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
match = regex_split.search(bench_stdout)
time_ms = float(match.group(1))
unit = match.group(2)
return 1.0 / (time * 0.001)
iter_per_second = 1.0 / (time_ms * 0.001)
# Extract peak memory.
host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
host_peak_b = int(host_regex.search(bench_stderr).group(1))
device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")
device_peak_b = int(device_regex.search(bench_stderr).group(1))
return iter_per_second, host_peak_b, device_peak_b
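The two peak-memory figures are scraped from iree-benchmark-module's stderr, which is only populated when `--print_statistics=true` is passed (added above). A self-contained sketch of the same extraction over an illustrative statistics dump; the exact formatting of the real output may differ between IREE releases.

```python
# Sketch of the peak-memory extraction added above, run over an illustrative
# allocator-statistics dump (real text comes from iree-benchmark-module stderr).
import re

bench_stderr = """
[[ allocator statistics ]]
  HOST_LOCAL:   1234567B peak
  DEVICE_LOCAL: 987654321B peak
"""

host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")

host_peak_b = int(host_regex.search(bench_stderr).group(1))
device_peak_b = int(device_regex.search(bench_stderr).group(1))

print(f"host peak:   {host_peak_b / 1e6:.6f} MB")    # same scaling as _bytes_to_mb_str
print(f"device peak: {device_peak_b / 1e6:.6f} MB")
```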

View File

@@ -52,7 +52,7 @@ def get_iree_device_args(device, extra_args=[]):
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg"]:
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
return ["--iree-llvmcpu-target-cpu-features=host"]
elif frontend in ["tensorflow", "tf", "mhlo"]:
return [
@@ -70,6 +70,7 @@ def get_iree_common_args():
return [
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
]
@@ -188,21 +189,23 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
benchmark_bash.write(" ".join(benchmark_cl))
benchmark_bash.close()
benchmark_data = run_benchmark_module(benchmark_cl)
iter_per_second, _, _ = run_benchmark_module(
benchmark_cl
)
benchmark_file = open(
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
)
benchmark_file.write(f"DISPATCH: {d_}\n")
benchmark_file.write(str(benchmark_data) + "\n")
benchmark_file.write(str(iter_per_second) + "\n")
benchmark_file.write(
"SHARK BENCHMARK RESULT: "
+ str(1 / (benchmark_data * 0.001))
+ str(1 / (iter_per_second * 0.001))
+ "\n"
)
benchmark_file.close()
benchmark_runtimes[d_] = 1 / (benchmark_data * 0.001)
benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001)
elif ".mlir" in f_ and "benchmark" not in f_:
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")

View File

@@ -30,11 +30,10 @@ def get_iree_gpu_args():
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
) and (shark_args.enable_tf32 == True):
return [
"--iree-hal-cuda-disable-loop-nounroll-wa",
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
]
else:
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
return []
# Get the default gpu args given the architecture.

View File

@@ -131,6 +131,8 @@ def get_vendor(triple):
return "ARM"
if arch == "m1":
return "Apple"
if arch in ["arc", "UHD"]:
return "Intel"
if arch in ["turing", "ampere"]:
return "NVIDIA"
if arch == "ardeno":
@@ -149,7 +151,7 @@ def get_device_type(triple):
return "Unknown"
if arch == "cpu":
return "CPU"
if arch in ["turing", "ampere"]:
if arch in ["turing", "ampere", "arc"]:
return "DiscreteGPU"
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]:
if product == "ivega10":
@@ -343,6 +345,37 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "arc":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
cap["subgroupSize"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
"Arithmetic",
"Ballot",
"Shuffle",
"ShuffleRelative",
"Clustered",
"Quad",
]
cap["shaderFloat16"] = True
cap["shaderFloat64"] = False
cap["shaderInt8"] = True
cap["shaderInt16"] = True
cap["shaderInt64"] = False
cap["storageBuffer16BitAccess"] = True
cap["storagePushConstant16"] = True
cap["uniformAndStorageBuffer16BitAccess"] = True
cap["storageBuffer8BitAccess"] = True
cap["storagePushConstant8"] = True
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "cpu":
if product == "swiftshader":
cap["maxComputeSharedMemorySize"] = 16384

View File

@@ -22,7 +22,8 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_vulkan_device_name():
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
@@ -108,6 +109,9 @@ def get_vulkan_target_triple(device_name):
triple = f"rdna3-7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna2-unknown-{system_os}"
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"
else:
triple = None
return triple
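The Intel entries slot into the existing device-name to Vulkan-target-triple mapping. A condensed, illustrative version with only two branches; the OS-suffix lookup here is a simplification of what the full function does.

```python
# Condensed, illustrative sketch of the device-name -> target-triple matching,
# now including the Intel Arc A770/A750 entries.
import platform

def guess_triple(device_name):
    system_os = {"Linux": "linux", "Darwin": "macos", "Windows": "windows"}.get(
        platform.system(), "unknown"
    )
    if any(x in device_name for x in ("A770", "A750")):   # Intel Arc targets
        return f"arc-770-{system_os}"
    if any(x in device_name for x in ("AMD", "Radeon")):  # generic AMD fallback
        return f"rdna2-unknown-{system_os}"
    return None

print(guess_triple("Intel(R) Arc(TM) A770 Graphics"))  # e.g. "arc-770-linux"
```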
@@ -139,7 +143,7 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
def get_iree_vulkan_args(extra_args=[]):
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
vulkan_triple_flag = None

View File

@@ -14,8 +14,10 @@
import argparse
import os
import subprocess
parser = argparse.ArgumentParser(description="SHARK runner.")
parser.add_argument(
"--device",
type=str,
@@ -54,7 +56,7 @@ parser.add_argument(
)
parser.add_argument(
"--shark_prefix",
default="latest",
default=None,
help="gs://shark_tank/<this_flag>/model_directories",
)
parser.add_argument(

View File

@@ -21,9 +21,17 @@ from shark.iree_utils.benchmark_utils import (
from shark.parser import shark_args
from datetime import datetime
import time
from typing import Optional
import csv
import os
TF_CPU_DEVICE = "/CPU:0"
TF_GPU_DEVICE = "/GPU:0"
def _bytes_to_mb_str(bytes_: Optional[int]) -> str:
return "" if bytes_ is None else f"{bytes_ / 1e6:.6f}"
class OnnxFusionOptions(object):
def __init__(self):
@@ -70,6 +78,7 @@ class SharkBenchmarkRunner(SharkRunner):
self.vmfb_file = None
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.import_args = {}
SharkRunner.__init__(
self,
mlir_module,
@@ -104,7 +113,6 @@ class SharkBenchmarkRunner(SharkRunner):
def benchmark_torch(self, modelname):
import torch
import torch._dynamo as dynamo
from tank.model_utils import get_torch_model
if self.device == "cuda":
@@ -116,31 +124,54 @@ class SharkBenchmarkRunner(SharkRunner):
torch_device = torch.device(
"cuda:0" if self.device == "cuda" else "cpu"
)
HFmodel, input = get_torch_model(modelname)[:2]
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
frontend_model = HFmodel.model
frontend_model.to(torch_device)
input.to(torch_device)
# frontend_model = torch.compile(frontend_model, mode="max-autotune", backend="inductor")
# TODO: re-enable as soon as pytorch CUDA context issues are resolved
try:
frontend_model = torch.compile(
frontend_model, mode="max-autotune", backend="inductor"
)
except RuntimeError:
frontend_model = HFmodel.model
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(input)
if self.device == "cuda":
torch.cuda.reset_peak_memory_stats()
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(input)
if i == shark_args.num_iterations - 1:
end = time.time()
break
end = time.time()
if self.device == "cuda":
stats = torch.cuda.memory_stats()
device_peak_b = stats["allocated_bytes.all.peak"]
frontend_model.to(torch.device("cpu"))
input.to(torch.device("cpu"))
torch.cuda.empty_cache()
else:
device_peak_b = None
print(
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
if self.device == "cuda":
# Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
torch_device = torch.device("cpu")
return [
f"{shark_args.num_iterations/(end-begin)}",
f"{((end-begin)/shark_args.num_iterations)*1000}",
"", # host_peak_b (CPU usage) is not reported by PyTorch.
_bytes_to_mb_str(device_peak_b),
]
def benchmark_tf(self, modelname):
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
@@ -155,38 +186,55 @@ class SharkBenchmarkRunner(SharkRunner):
from tank.model_utils_tf import get_tf_model
# tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
tf_device = "/CPU:0"
# tf_device = TF_GPU_DEVICE if self.device == "cuda" else TF_CPU_DEVICE
tf_device = TF_CPU_DEVICE
with tf.device(tf_device):
(
model,
input,
) = get_tf_model(
modelname
modelname, self.import_args
)[:2]
frontend_model = model
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(*input)
if tf_device == TF_GPU_DEVICE:
tf.config.experimental.reset_memory_stats(tf_device)
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(*input)
if i == shark_args.num_iterations - 1:
end = time.time()
break
end = time.time()
if tf_device == TF_GPU_DEVICE:
memory_info = tf.config.experimental.get_memory_info(tf_device)
device_peak_b = memory_info["peak"]
else:
# tf.config.experimental does not currently support measuring
# CPU memory usage.
device_peak_b = None
print(
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
return [
f"{shark_args.num_iterations/(end-begin)}",
f"{((end-begin)/shark_args.num_iterations)*1000}",
"", # host_peak_b (CPU usage) is not reported by TensorFlow.
_bytes_to_mb_str(device_peak_b),
]
def benchmark_c(self):
result = run_benchmark_module(self.benchmark_cl)
print(f"Shark-IREE-C benchmark:{result} iter/second")
return [f"{result}", f"{1000/result}"]
iter_per_second, host_peak_b, device_peak_b = run_benchmark_module(
self.benchmark_cl
)
print(f"Shark-IREE-C benchmark:{iter_per_second} iter/second")
return [
f"{iter_per_second}",
f"{1000/iter_per_second}",
_bytes_to_mb_str(host_peak_b),
_bytes_to_mb_str(device_peak_b),
]
def benchmark_python(self, inputs):
input_list = [x for x in inputs]
@@ -196,8 +244,7 @@ class SharkBenchmarkRunner(SharkRunner):
begin = time.time()
for i in range(shark_args.num_iterations):
out = self.run("forward", input_list)
if i == shark_args.num_iterations - 1:
end = time.time()
end = time.time()
print(
f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
@@ -306,11 +353,21 @@ for currently supported models. Exiting benchmark ONNX."
return comp_str
def benchmark_all_csv(
self, inputs: tuple, modelname, dynamic, device_str, frontend
self,
inputs: tuple,
modelname,
dynamic,
device_str,
frontend,
import_args,
mode="native",
):
self.setup_cl(inputs)
self.import_args = import_args
self.mode = mode
field_names = [
"model",
"batch_size",
"engine",
"dialect",
"device",
@@ -324,8 +381,19 @@ for currently supported models. Exiting benchmark ONNX."
"tags",
"notes",
"datetime",
"host_memory_mb",
"device_memory_mb",
"measured_host_memory_mb",
"measured_device_memory_mb",
]
engines = ["frontend", "shark_python", "shark_iree_c"]
# "frontend" must be the first element.
if self.mode == "native":
engines = ["shark_python", "shark_iree_c"]
if self.mode == "baseline":
engines = ["frontend"]
if self.mode == "all":
engines = ["frontend", "shark_python", "shark_iree_c"]
if shark_args.onnx_bench == True:
engines.append("onnxruntime")
@@ -336,75 +404,78 @@ for currently supported models. Exiting benchmark ONNX."
with open("bench_results.csv", mode="a", newline="") as f:
writer = csv.DictWriter(f, fieldnames=field_names)
bench_result = {}
bench_result["model"] = modelname
bench_info = {}
bench_info["model"] = modelname
bench_info["batch_size"] = str(import_args["batch_size"])
bench_info["dialect"] = self.mlir_dialect
bench_info["iterations"] = shark_args.num_iterations
if dynamic == True:
bench_result["shape_type"] = "dynamic"
bench_info["shape_type"] = "dynamic"
else:
bench_result["shape_type"] = "static"
bench_result["device"] = device_str
bench_info["shape_type"] = "static"
bench_info["device"] = device_str
if "fp16" in modelname:
bench_result["data_type"] = "float16"
bench_info["data_type"] = "float16"
else:
bench_result["data_type"] = inputs[0].dtype
bench_info["data_type"] = inputs[0].dtype
for e in engines:
(
bench_result["param_count"],
bench_result["tags"],
bench_result["notes"],
) = ["", "", ""]
engine_result = {}
self.frontend_result = None
if e == "frontend":
bench_result["engine"] = frontend
engine_result["engine"] = frontend
if check_requirements(frontend):
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
engine_result["host_memory_mb"],
engine_result["device_memory_mb"],
) = self.benchmark_frontend(modelname)
self.frontend_result = bench_result["ms/iter"]
bench_result["vs. PyTorch/TF"] = "baseline"
self.frontend_result = engine_result["ms/iter"]
engine_result["vs. PyTorch/TF"] = "baseline"
(
bench_result["param_count"],
bench_result["tags"],
bench_result["notes"],
engine_result["param_count"],
engine_result["tags"],
engine_result["notes"],
) = self.get_metadata(modelname)
else:
self.frontend_result = None
continue
elif e == "shark_python":
bench_result["engine"] = "shark_python"
engine_result["engine"] = "shark_python"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
) = self.benchmark_python(inputs)
bench_result[
engine_result[
"vs. PyTorch/TF"
] = self.compare_bench_results(
self.frontend_result, bench_result["ms/iter"]
self.frontend_result, engine_result["ms/iter"]
)
elif e == "shark_iree_c":
bench_result["engine"] = "shark_iree_c"
engine_result["engine"] = "shark_iree_c"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
engine_result["host_memory_mb"],
engine_result["device_memory_mb"],
) = self.benchmark_c()
bench_result[
engine_result[
"vs. PyTorch/TF"
] = self.compare_bench_results(
self.frontend_result, bench_result["ms/iter"]
self.frontend_result, engine_result["ms/iter"]
)
elif e == "onnxruntime":
bench_result["engine"] = "onnxruntime"
engine_result["engine"] = "onnxruntime"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
) = self.benchmark_onnx(modelname, inputs)
bench_result["dialect"] = self.mlir_dialect
bench_result["iterations"] = shark_args.num_iterations
bench_result["datetime"] = str(datetime.now())
writer.writerow(bench_result)
engine_result["datetime"] = str(datetime.now())
writer.writerow(bench_info | engine_result)
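The hunk above splits each CSV row into a per-run bench_info dict and a per-engine engine_result dict and writes their union. A minimal, self-contained sketch of that pattern, with a simplified column set rather than SHARK's full field list:

    import csv
    from datetime import datetime

    field_names = ["model", "batch_size", "engine", "iter/sec", "ms/iter", "datetime"]

    bench_info = {"model": "resnet50", "batch_size": "1"}      # constant for the whole run
    engine_result = {                                          # varies per engine
        "engine": "shark_iree_c",
        "iter/sec": 42.0,
        "ms/iter": 23.8,
        "datetime": str(datetime.now()),
    }

    with open("bench_results.csv", mode="a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=field_names)
        # Dict union (Python 3.9+); keys from engine_result win on collision.
        writer.writerow(bench_info | engine_result)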

View File

@@ -127,33 +127,105 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
and os.path.isfile(os.path.join(model_dir, "hash.npy"))
):
print(f"""Using cached models from {WORKDIR}...""")
print(
f"""Model artifacts for {model_name} found at {WORKDIR}..."""
)
return True
return False
def _internet_connected():
import requests as req
try:
req.get("http://1.1.1.1")
return True
except:
return False
def get_git_revision_short_hash() -> str:
import subprocess
if shark_args.shark_prefix is not None:
prefix_kw = shark_args.shark_prefix
else:
import json
dir_path = os.path.dirname(os.path.realpath(__file__))
src = os.path.join(dir_path, "..", "tank_version.json")
with open(src, "r") as f:
data = json.loads(f.read())
prefix_kw = data["version"]
print(f"Checking for updates from gs://shark_tank/{prefix_kw}")
return prefix_kw
def get_sharktank_prefix():
tank_prefix = ""
if not _internet_connected():
print(
"No internet connection. Using the model already present in the tank."
)
tank_prefix = "none"
else:
desired_prefix = get_git_revision_short_hash()
storage_client_a = storage.Client.create_anonymous_client()
base_bucket_name = "shark_tank"
base_bucket = storage_client_a.bucket(base_bucket_name)
dir_blobs = base_bucket.list_blobs(prefix=f"{desired_prefix}")
for blob in dir_blobs:
dir_blob_name = blob.name.split("/")
if desired_prefix in dir_blob_name[0]:
tank_prefix = dir_blob_name[0]
break
else:
continue
if tank_prefix == "":
print(
f"shark_tank bucket not found matching ({desired_prefix}). Defaulting to nightly."
)
tank_prefix = "nightly"
return tank_prefix
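Condensed, the prefix lookup above lists the public shark_tank bucket anonymously and takes the first top-level directory whose name contains the desired prefix, falling back to nightly. A standalone sketch (the helper name is ours; the bucket name and fallback follow the diff):

    from google.cloud import storage

    def find_tank_prefix(desired_prefix, bucket_name="shark_tank"):
        # Anonymous client: the bucket is publicly readable, no credentials needed.
        client = storage.Client.create_anonymous_client()
        bucket = client.bucket(bucket_name)
        for blob in bucket.list_blobs(prefix=desired_prefix):
            top_level = blob.name.split("/")[0]
            if desired_prefix in top_level:
                return top_level
        # Nothing matched the requested prefix.
        return "nightly"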
# Downloads the torch model from gs://shark_tank dir.
def download_model(
model_name,
dynamic=False,
tank_url="gs://shark_tank/latest",
tank_url=None,
frontend=None,
tuned=None,
import_args={"batch_size": "1"},
):
model_name = model_name.replace("/", "_")
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_" + frontend
shark_args.shark_prefix = get_sharktank_prefix()
if import_args["batch_size"] != 1:
model_dir_name = (
model_name
+ "_"
+ frontend
+ "_BS"
+ str(import_args["batch_size"])
)
else:
model_dir_name = model_name + "_" + frontend
model_dir = os.path.join(WORKDIR, model_dir_name)
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
if not tank_url:
tank_url = "gs://shark_tank/" + shark_args.shark_prefix
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
if not check_dir_exists(
model_dir_name, frontend=frontend, dynamic=dyn_str
):
print(
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
f"Downloading artifacts for model {model_name} from: {full_gs_url}"
)
download_public_file(full_gs_url, model_dir)
elif shark_args.force_update_tank == True:
print(
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
@@ -179,6 +251,7 @@ def download_model(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
except FileNotFoundError:
print(f"Model artifact hash not found at {model_dir}.")
upstream_hash = None
if local_hash != upstream_hash and shark_args.update_tank == True:
print(f"Updating artifacts for model {model_name}...")
@@ -186,17 +259,28 @@ def download_model(
elif local_hash != upstream_hash:
print(
"Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
"Hash does not match upstream in gs://shark_tank/. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
)
else:
print(
"Local and upstream hashes match. Using cached model artifacts."
)
model_dir = os.path.join(WORKDIR, model_dir_name)
tuned_str = "" if tuned is None else "_" + tuned
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
filename = os.path.join(model_dir, model_name + suffix)
if not os.path.exists(filename):
from tank.generate_sharktank import gen_shark_files
print(
"The model data was not found. Trying to generate artifacts locally."
)
gen_shark_files(model_name, frontend, WORKDIR, import_args)
assert os.path.exists(filename), f"MLIR not found at {filename}"
with open(filename, mode="rb") as f:
mlir_file = f.read()
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
@@ -204,13 +288,3 @@ def download_model(
inputs_tuple = tuple([inputs[key] for key in inputs])
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
return mlir_file, function_name, inputs_tuple, golden_out_tuple
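Putting the new interface together: the tank URL now defaults to gs://shark_tank/<resolved prefix>, and the local artifact directory gains a _BS<batch_size> suffix whenever the requested batch size is not 1. A hedged usage sketch (model name and batch size are arbitrary examples):

    from shark.shark_downloader import download_model

    # Artifacts land in WORKDIR/resnet50_torch_BS4 here; with the default
    # batch size of 1 the directory would simply be WORKDIR/resnet50_torch.
    mlir_module, func_name, inputs, golden_out = download_model(
        "resnet50",
        frontend="torch",
        import_args={"batch_size": "4"},
    )
    print(func_name, [arr.shape for arr in inputs])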
def _internet_connected():
import requests as req
try:
req.get("http://1.1.1.1")
return True
except:
return False

View File

@@ -9,8 +9,8 @@ import hashlib
def create_hash(file_name):
with open(file_name, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(2**20):
file_hash = hashlib.blake2b(digest_size=64)
while chunk := f.read(2**10):
file_hash.update(chunk)
return file_hash.hexdigest()
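The importer stores this digest as a 0-d numpy array next to the generated MLIR, and the downloader later compares it against upstream_hash.npy to decide whether artifacts need refreshing. A small sketch of that round trip (the directory and file names are illustrative):

    import hashlib
    import os
    import numpy as np

    def create_hash(file_name, chunk_size=2**10):
        # blake2b with an explicit 64-byte digest, reading the file in 1 KiB chunks.
        with open(file_name, "rb") as f:
            file_hash = hashlib.blake2b(digest_size=64)
            while chunk := f.read(chunk_size):
                file_hash.update(chunk)
        return file_hash.hexdigest()

    model_dir = "gen_shark_tank/resnet50_torch"                 # illustrative path
    np.save(os.path.join(model_dir, "hash"),
            np.array(create_hash(os.path.join(model_dir, "resnet50_torch.mlir"))))

    local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
    upstream_hash = str(np.load(os.path.join(model_dir, "upstream_hash.npy")))
    print("cached artifacts are current" if local_hash == upstream_hash
          else "hash mismatch: rerun with --update_tank")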
@@ -165,8 +165,17 @@ class SharkImporter:
if self.frontend == "torch":
with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
mlir_file.write(mlir_data)
mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
hash_gen_attempts = 2
for i in range(hash_gen_attempts):
try:
mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
except FileNotFoundError as err:
if i < hash_gen_attempts:
continue
else:
raise err
np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
return
def import_debug(
@@ -297,6 +306,7 @@ def transform_fx(fx_g):
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.zeros,
]:
node.kwargs = kwargs_dict
# Inputs and outputs of aten.var.mean should be upcasted to fp32.

View File

@@ -22,7 +22,7 @@ bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-large-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344",""
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos"
@@ -35,3 +35,12 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"",""
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""

View File

@@ -70,7 +70,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -146,7 +146,6 @@ if __name__ == "__main__":
backend_config = "cuda"
args = [
"--iree-cuda-llvm-target-arch=sm_80",
"--iree-hal-cuda-disable-loop-nounroll-wa",
"--iree-enable-fusion-with-reduction-ops",
]

View File

@@ -91,7 +91,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -86,7 +86,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -26,16 +26,17 @@ from apps.stable_diffusion.src.utils.stable_args import (
def create_hash(file_name):
with open(file_name, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(2**20):
file_hash = hashlib.blake2b(digest_size=64)
while chunk := f.read(2**10):
file_hash.update(chunk)
return file_hash.hexdigest()
def save_torch_model(torch_model_list):
def save_torch_model(torch_model_list, local_tank_cache, import_args):
from tank.model_utils import (
get_hf_model,
get_hf_seq2seq_model,
get_vision_model,
get_hf_img_cls_model,
get_fp16_model,
@@ -58,8 +59,7 @@ def save_torch_model(torch_model_list):
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.use_tuned = False
args.local_tank_cache = WORKDIR
args.local_tank_cache = local_tank_cache
precision_values = ["fp16"]
seq_lengths = [64, 77]
@@ -74,24 +74,41 @@ def save_torch_model(torch_model_list):
width=512,
height=512,
use_base_vae=False,
custom_vae="",
debug=True,
sharktank_dir=WORKDIR,
sharktank_dir=local_tank_cache,
generate_vmfb=False,
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
model, input, _ = get_vision_model(
torch_model_name, import_args
)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name)
model, input, _ = get_hf_model(torch_model_name, import_args)
elif model_type == "hf_seq2seq":
model, input, _ = get_hf_seq2seq_model(
torch_model_name, import_args
)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
model, input, _ = get_hf_img_cls_model(
torch_model_name, import_args
)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name)
model, input, _ = get_fp16_model(torch_model_name, import_args)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
)
if import_args["batch_size"] != 1:
torch_model_dir = os.path.join(
local_tank_cache,
str(torch_model_name)
+ "_torch"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
torch_model_dir = os.path.join(
local_tank_cache, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
@@ -115,13 +132,18 @@ def save_torch_model(torch_model_list):
)
def save_tf_model(tf_model_list):
def save_tf_model(tf_model_list, local_tank_cache, import_args):
from tank.model_utils_tf import (
get_causal_image_model,
get_masked_lm_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
get_tfhf_seq2seq_model,
)
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
@@ -145,16 +167,38 @@ def save_tf_model(tf_model_list):
input = None
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_causal_lm_model(tf_model_name)
if model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
if model_type == "keras":
model, input, _ = get_keras_model(tf_model_name)
if model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name)
model, input, _ = get_masked_lm_model(
tf_model_name, import_args
)
elif model_type == "img":
model, input, _ = get_causal_image_model(
tf_model_name, import_args
)
elif model_type == "keras":
model, input, _ = get_keras_model(tf_model_name, import_args)
elif model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name, import_args)
elif model_type == "tfhf_seq2seq":
model, input, _ = get_tfhf_seq2seq_model(
tf_model_name, import_args
)
elif model_type == "hf_causallm":
model, input, _ = get_causal_lm_model(
tf_model_name, import_args
)
tf_model_name = tf_model_name.replace("/", "_")
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
if import_args["batch_size"] != 1:
tf_model_dir = os.path.join(
local_tank_cache,
str(tf_model_name)
+ "_tf"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
tf_model_dir = os.path.join(
local_tank_cache, str(tf_model_name) + "_tf"
)
os.makedirs(tf_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
@@ -166,13 +210,9 @@ def save_tf_model(tf_model_list):
dir=tf_model_dir,
model_name=tf_model_name,
)
mlir_hash = create_hash(
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
)
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
def save_tflite_model(tflite_model_list):
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
from shark.tflite_utils import TFLitePreprocessor
with open(tflite_model_list) as csvfile:
@@ -184,18 +224,18 @@ def save_tflite_model(tflite_model_list):
print("tflite_model_name", tflite_model_name)
print("tflite_model_link", tflite_model_link)
tflite_model_name_dir = os.path.join(
WORKDIR, str(tflite_model_name) + "_tflite"
local_tank_cache, str(tflite_model_name) + "_tflite"
)
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
# Preprocess to get SharkImporter input args
# Preprocess to get SharkImporter input import_args
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()
# Use SharkImporter to get SharkInference input args
# Use SharkImporter to get SharkInference input import_args
my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
@@ -219,6 +259,71 @@ def save_tflite_model(tflite_model_list):
)
def check_requirements(frontend):
import importlib
has_pkgs = False
if frontend == "torch":
tv_spec = importlib.util.find_spec("torchvision")
has_pkgs = tv_spec is not None
elif frontend in ["tensorflow", "tf"]:
tf_spec = importlib.util.find_spec("tensorflow")
has_pkgs = tf_spec is not None
return has_pkgs
class NoImportException(Exception):
"Raised when requirements are not met for OTF model artifact generation."
pass
def gen_shark_files(modelname, frontend, tank_dir, importer_args):
# If a model's artifacts are requested by shark_downloader but they don't exist in the cloud, we call this function to generate the artifacts on-the-fly.
# TODO: Add TFlite support.
import tempfile
import_args = importer_args
if check_requirements(frontend):
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tf_model_list.csv"
)
custom_model_csv = tempfile.NamedTemporaryFile(
dir=os.path.dirname(__file__),
delete=True,
)
# Create a temporary .csv with only the desired entry.
if frontend == "tf":
with open(tf_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_tf_model(custom_model_csv.name, tank_dir, import_args)
elif frontend == "torch":
with open(torch_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_torch_model(custom_model_csv.name, tank_dir, import_args)
else:
raise NoImportException
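This is the path the downloader falls back to when artifacts are missing both locally and in the bucket: a temporary one-row CSV is written for the requested model and the matching save_*_model helper regenerates its artifacts. A usage sketch, assuming the frontend's packages are installed (model name and cache directory are placeholders):

    from tank.generate_sharktank import NoImportException, gen_shark_files

    try:
        gen_shark_files(
            "squeezenet1_0",       # must match a row in torch_model_list.csv
            "torch",
            "./gen_shark_tank",    # local tank cache directory to populate
            {"batch_size": "1"},
        )
    except NoImportException:
        print("torchvision/tensorflow not installed; cannot generate artifacts locally.")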
# Validates whether the file is present or not.
def is_valid_file(arg):
if not os.path.exists(arg):
@@ -228,7 +333,7 @@ def is_valid_file(arg):
if __name__ == "__main__":
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
# Note, all of these flags are overridden by the import of import_args from stable_args.py, flags are duplicated temporarily to preserve functionality
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--torch_model_csv",
@@ -256,23 +361,26 @@ if __name__ == "__main__":
# )
# parser.add_argument("--upload", type=bool, default=False)
# old_args = parser.parse_args()
# old_import_args = parser.parse_import_args()
import_args = {
"batch_size": "1",
}
print(import_args)
home = str(Path.home())
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_shark_tank")
torch_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tf_model_list.csv"
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
)
save_torch_model(
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv")
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
WORKDIR,
import_args,
)
save_torch_model(torch_model_csv)
save_tf_model(tf_model_csv)
save_tflite_model(tflite_model_csv)
save_torch_model(torch_model_csv, WORKDIR, import_args)
save_tf_model(tf_model_csv, WORKDIR, import_args)
save_tflite_model(tflite_model_csv, WORKDIR, import_args)

View File

@@ -31,4 +31,12 @@ xlm-roberta-base,False,False,-,-,-
facebook/convnext-tiny-224,False,False,-,-,-
efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv"
mnasnet1_0,False,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
bert-large-uncased,True,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
efficientnet_b0,True,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
gpt2,True,False,110M,"nlp;transformer-decoder;auto-regressive","12 layers, 768 hidden units, 12 attention heads"
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"

View File

@@ -1,5 +1,4 @@
from shark.shark_inference import SharkInference
from shark.parser import shark_args
import torch
import numpy as np
@@ -7,6 +6,8 @@ import sys
torch.manual_seed(0)
BATCH_SIZE = 1
vision_models = [
"alexnet",
"resnet101",
@@ -17,6 +18,8 @@ vision_models = [
"wide_resnet50_2",
"mobilenet_v3_small",
"mnasnet1_0",
"efficientnet_b0",
"efficientnet_b7",
]
hf_img_cls_models = [
"google/vit-base-patch16-224",
@@ -25,17 +28,23 @@ hf_img_cls_models = [
"microsoft/beit-base-patch16-224-pt22k-ft22k",
"nvidia/mit-b0",
]
hf_seq2seq_models = [
"t5-base",
"t5-large",
]
def get_torch_model(modelname):
def get_torch_model(modelname, import_args):
if modelname in vision_models:
return get_vision_model(modelname)
return get_vision_model(modelname, import_args)
elif modelname in hf_img_cls_models:
return get_hf_img_cls_model(modelname)
return get_hf_img_cls_model(modelname, import_args)
elif modelname in hf_seq2seq_models:
return get_hf_seq2seq_model(modelname, import_args)
elif "fp16" in modelname:
return get_fp16_model(modelname)
return get_fp16_model(modelname, import_args)
else:
return get_hf_model(modelname)
return get_hf_model(modelname, import_args)
##################### Hugging Face Image Classification Models ###################################
@@ -78,13 +87,14 @@ class HuggingFaceImageClassification(torch.nn.Module):
return self.model.forward(inputs)[0]
def get_hf_img_cls_model(name):
def get_hf_img_cls_model(name, import_args):
model = HuggingFaceImageClassification(name)
# you can use preprocess_input_image to get the test_input, or just use a random value.
test_input = preprocess_input_image(name)
# test_input = torch.FloatTensor(1, 3, 224, 224).uniform_(-1, 1)
# print("test_input.shape: ", test_input.shape)
# test_input.shape: torch.Size([1, 3, 224, 224])
test_input = test_input.repeat(int(import_args["batch_size"]), 1, 1, 1)
actual_out = model(test_input)
# print("actual_out.shape ", actual_out.shape)
# actual_out.shape torch.Size([1, 1000])
@@ -114,18 +124,58 @@ class HuggingFaceLanguage(torch.nn.Module):
return self.model.forward(tokens)[0]
def get_hf_model(name):
def get_hf_model(name, import_args):
from transformers import (
BertTokenizer,
)
model = HuggingFaceLanguage(name)
# TODO: Currently the test input is set to (1,128)
test_input = torch.randint(2, (1, 128))
test_input = torch.randint(2, (int(import_args["batch_size"]), 128))
actual_out = model(test_input)
return model, test_input, actual_out
##################### Hugging Face Seq2SeqLM Models ###################################
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
T5_MAX_SEQUENCE_LENGTH = 512
class HFSeq2SeqLanguageModel(torch.nn.Module):
def __init__(self, model_name):
super().__init__()
from transformers import AutoTokenizer, T5Model
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenization_kwargs = {
"pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
"padding": True,
"return_tensors": "pt",
}
self.model = T5Model.from_pretrained(model_name, return_dict=True)
def preprocess_input(self, text):
return self.tokenizer(text, **self.tokenization_kwargs)
def forward(self, input_ids, decoder_input_ids):
return self.model.forward(
input_ids, decoder_input_ids=decoder_input_ids
)[0]
def get_hf_seq2seq_model(name, import_args):
m = HFSeq2SeqLanguageModel(name)
encoded_input_ids = m.preprocess_input(
"Studies have been shown that owning a dog is good for you"
).input_ids
decoder_input_ids = m.preprocess_input("Studies show that").input_ids
decoder_input_ids = m.model._shift_right(decoder_input_ids)
test_input = (encoded_input_ids, decoder_input_ids)
actual_out = m.forward(*test_input)
return m, test_input, actual_out
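The important detail here is that the decoder prompt is passed through T5's _shift_right before it becomes decoder_input_ids, mirroring how T5 is fed at training time. A short sketch of how the helper is consumed (the model name is just an example; transformers and torch must be installed):

    from tank.model_utils import get_hf_seq2seq_model

    model, test_input, golden = get_hf_seq2seq_model("t5-base", {"batch_size": "1"})
    encoder_ids, decoder_ids = test_input   # both padded to a multiple of 512
    out = model(encoder_ids, decoder_ids)   # decoder hidden states, matches golden
    print(out.shape)                        # (batch, decoder_len, d_model)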
################################################################################
##################### Torch Vision Models ###################################
@@ -141,27 +191,55 @@ class VisionModule(torch.nn.Module):
return self.model.forward(input)
def get_vision_model(torch_model):
def get_vision_model(torch_model, import_args):
import torchvision.models as models
vision_models_dict = {
"alexnet": models.alexnet(weights="DEFAULT"),
"resnet18": models.resnet18(weights="DEFAULT"),
"resnet50": models.resnet50(weights="DEFAULT"),
"resnet50_fp16": models.resnet50(weights="DEFAULT"),
"resnet101": models.resnet101(weights="DEFAULT"),
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
"mobilenet_v3_small": models.mobilenet_v3_small(weights="DEFAULT"),
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
}
if isinstance(torch_model, str):
fp16_model = None
if "fp16" in torch_model:
fp16_model = True
torch_model = vision_models_dict[torch_model]
default_image_size = (224, 224)
modelname = torch_model
if modelname == "alexnet":
torch_model = models.alexnet(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "resnet18":
torch_model = models.resnet18(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "resnet50":
torch_model = models.resnet50(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "resnet50_fp16":
torch_model = models.resnet50(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "resnet50_fp16":
torch_model = models.resnet50(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "resnet101":
torch_model = models.resnet101(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "squeezenet1_0":
torch_model = models.squeezenet1_0(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "wide_resnet50_2":
torch_model = models.wide_resnet50_2(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "mobilenet_v3_small":
torch_model = models.mobilenet_v3_small(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "mnasnet1_0":
torch_model = models.mnasnet1_0(weights="DEFAULT")
input_image_size = default_image_size
if modelname == "efficientnet_b0":
torch_model = models.efficientnet_b0(weights="DEFAULT")
input_image_size = (224, 224)
if modelname == "efficientnet_b7":
torch_model = models.efficientnet_b7(weights="DEFAULT")
input_image_size = (600, 600)
fp16_model = False
if "fp16" in modelname:
fp16_model = True
model = VisionModule(torch_model)
test_input = torch.randn(1, 3, 224, 224)
test_input = torch.randn(
int(import_args["batch_size"]), 3, *input_image_size
)
actual_out = model(test_input)
if fp16_model is not None:
test_input_fp16 = test_input.to(
@@ -202,13 +280,14 @@ class BertHalfPrecisionModel(torch.nn.Module):
return self.model.forward(tokens)[0]
def get_fp16_model(torch_model):
def get_fp16_model(torch_model, import_args):
from transformers import AutoTokenizer
modelname = torch_model.replace("_fp16", "")
model = BertHalfPrecisionModel(modelname)
tokenizer = AutoTokenizer.from_pretrained(modelname)
text = "Replace me by any text you like."
text = [text] * int(import_args["batch_size"])
test_input_fp16 = tokenizer(
text,
truncation=True,

View File

@@ -1,17 +1,19 @@
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
import numpy as np
from transformers import (
AutoModelForSequenceClassification,
BertTokenizer,
TFBertModel,
)
BATCH_SIZE = 1
MAX_SEQUENCE_LENGTH = 128
################################## MHLO/TF models #########################################
# TODO : Generate these lists or fetch model source from tank/tf/tf_model_list.csv
keras_models = ["resnet50", "efficientnet-v2-s"]
keras_models = [
"resnet50",
"efficientnet_b0",
"efficientnet_b7",
"efficientnet-v2-s",
]
maskedlm_models = [
"albert-base-v2",
"bert-base-uncased",
@@ -32,36 +34,61 @@ maskedlm_models = [
"hf-internal-testing/tiny-random-flaubert",
"xlm-roberta",
]
causallm_models = [
"gpt2",
]
tfhf_models = [
"microsoft/MiniLM-L12-H384-uncased",
]
tfhf_seq2seq_models = [
"t5-base",
"t5-large",
]
img_models = [
"google/vit-base-patch16-224",
"facebook/convnext-tiny-224",
]
def get_tf_model(name):
def get_tf_model(name, import_args):
if name in keras_models:
return get_keras_model(name)
return get_keras_model(name, import_args)
elif name in maskedlm_models:
return get_causal_lm_model(name)
return get_masked_lm_model(name, import_args)
elif name in causallm_models:
return get_causal_lm_model(name, import_args)
elif name in tfhf_models:
return get_TFhf_model(name)
return get_TFhf_model(name, import_args)
elif name in img_models:
return get_causal_image_model(name)
return get_causal_image_model(name, import_args)
elif name in tfhf_seq2seq_models:
return get_tfhf_seq2seq_model(name, import_args)
else:
raise Exception(
"TF model not found! Please check that the modelname has been input correctly."
)
##################### Tensorflow Hugging Face LM Models ###################################
##################### Tensorflow Hugging Face Bert Models ###################################
from transformers import (
AutoModelForSequenceClassification,
BertTokenizer,
TFBertModel,
)
BERT_MAX_SEQUENCE_LENGTH = 128
# Create a set of 2-dimensional inputs
tf_bert_input = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
]
@@ -81,27 +108,37 @@ class TFHuggingFaceLanguage(tf.Module):
return self.m.predict(input_ids, attention_mask, token_type_ids)
def get_TFhf_model(name):
def get_TFhf_model(name, import_args):
model = TFHuggingFaceLanguage(name)
tokenizer = BertTokenizer.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased"
)
text = "Replace me by any text you'd like."
text = [text] * BATCH_SIZE
encoded_input = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
)
for key in encoded_input:
encoded_input[key] = tf.expand_dims(
tf.convert_to_tensor(encoded_input[key]), 0
)
test_input = (
encoded_input["input_ids"],
encoded_input["attention_mask"],
encoded_input["token_type_ids"],
max_length=BERT_MAX_SEQUENCE_LENGTH,
)
test_input = [
tf.reshape(
tf.convert_to_tensor(encoded_input["input_ids"], dtype=tf.int32),
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
),
tf.reshape(
tf.convert_to_tensor(
encoded_input["attention_mask"], dtype=tf.int32
),
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
),
tf.reshape(
tf.convert_to_tensor(
encoded_input["token_type_ids"], dtype=tf.int32
),
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
),
]
actual_out = model.forward(*test_input)
return model, test_input, actual_out
@@ -115,34 +152,40 @@ def compare_tensors_tf(tf_tensor, numpy_tensor):
return np.allclose(tf_to_numpy, numpy_tensor, rtol, atol)
##################### Tensorflow Hugging Face Masked LM Models ###################################
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
import tensorflow as tf
# Create a set of input signature.
input_signature_maskedlm = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
]
# For supported models please see here:
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCasualLM
# Tokenizer for language models
def preprocess_input(
model_name, text="This is just used to compile the model"
model_name, max_length, text="This is just used to compile the model"
):
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = [text] * BATCH_SIZE
inputs = tokenizer(
text,
padding="max_length",
return_tensors="tf",
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_length,
)
return inputs
##################### Tensorflow Hugging Face Masked LM Models ###################################
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
MASKED_LM_MAX_SEQUENCE_LENGTH = 128
# Create a set of input signature.
input_signature_maskedlm = [
tf.TensorSpec(
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
]
# For supported models please see here:
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForMaskedLM
class MaskedLM(tf.Module):
def __init__(self, model_name):
super(MaskedLM, self).__init__()
@@ -156,19 +199,143 @@ class MaskedLM(tf.Module):
return self.m.predict(input_ids, attention_mask)
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
def get_masked_lm_model(
hf_name, import_args, text="Hello, this is the default text."
):
model = MaskedLM(hf_name)
encoded_input = preprocess_input(hf_name, text)
encoded_input = preprocess_input(
hf_name, MASKED_LM_MAX_SEQUENCE_LENGTH, text
)
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
actual_out = model.forward(*test_input)
return model, test_input, actual_out
##################### Tensorflow Hugging Face Causal LM Models ###################################
from transformers import AutoConfig, TFAutoModelForCausalLM, TFGPT2Model
CAUSAL_LM_MAX_SEQUENCE_LENGTH = 1024
input_signature_causallm = [
tf.TensorSpec(
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
]
# For supported models please see here:
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCausalLM
# For more background, see:
# https://huggingface.co/blog/tf-xla-generate
class CausalLM(tf.Module):
def __init__(self, model_name):
super(CausalLM, self).__init__()
# Decoder-only models need left padding.
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, padding_side="left", pad_token="</s>"
)
self.tokenization_kwargs = {
"pad_to_multiple_of": CAUSAL_LM_MAX_SEQUENCE_LENGTH,
"padding": True,
"return_tensors": "tf",
}
self.model = TFGPT2Model.from_pretrained(model_name, return_dict=True)
self.model.predict = lambda x, y: self.model(
input_ids=x, attention_mask=y
)[0]
def preprocess_input(self, text):
return self.tokenizer(text, **self.tokenization_kwargs)
@tf.function(input_signature=input_signature_causallm, jit_compile=True)
def forward(self, input_ids, attention_mask):
return self.model.predict(input_ids, attention_mask)
def get_causal_lm_model(
hf_name, import_args, text="Hello, this is the default text."
):
model = CausalLM(hf_name)
batched_text = [text] * BATCH_SIZE
encoded_input = model.preprocess_input(batched_text)
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
actual_out = model.forward(*test_input)
return model, test_input, actual_out
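For the decoder-only path the tokenizer is created with left padding and pads every prompt to a multiple of 1024 tokens, so the XLA-compiled forward keeps a static shape. A small usage sketch (gpt2 is the only entry in causallm_models; tensorflow and transformers must be installed):

    from tank.model_utils_tf import get_causal_lm_model

    model, (input_ids, attention_mask), golden = get_causal_lm_model(
        "gpt2", {"batch_size": "1"}
    )
    print(input_ids.shape)   # (1, 1024): left-padded to CAUSAL_LM_MAX_SEQUENCE_LENGTH
    print(golden.shape)      # (1, 1024, 768): GPT-2 last hidden states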
##################### Tensorflow Hugging Face Seq2SeqLM Models ###################################
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
T5_MAX_SEQUENCE_LENGTH = 512
input_signature_t5 = [
tf.TensorSpec(
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
dtype=tf.int32,
name="input_ids",
),
tf.TensorSpec(
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
dtype=tf.int32,
name="attention_mask",
),
]
class TFHFSeq2SeqLanguageModel(tf.Module):
def __init__(self, model_name):
super(TFHFSeq2SeqLanguageModel, self).__init__()
from transformers import (
AutoTokenizer,
AutoConfig,
TFAutoModelForSeq2SeqLM,
TFT5Model,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenization_kwargs = {
"pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
"padding": True,
"return_tensors": "tf",
}
self.model = TFT5Model.from_pretrained(model_name, return_dict=True)
self.model.predict = lambda x, y: self.model(x, decoder_input_ids=y)[0]
def preprocess_input(self, text):
return self.tokenizer(text, **self.tokenization_kwargs)
@tf.function(input_signature=input_signature_t5, jit_compile=True)
def forward(self, input_ids, decoder_input_ids):
return self.model.predict(input_ids, decoder_input_ids)
def get_tfhf_seq2seq_model(name, import_args):
m = TFHFSeq2SeqLanguageModel(name)
text = "Studies have been shown that owning a dog is good for you"
batched_text = [text] * BATCH_SIZE
encoded_input_ids = m.preprocess_input(batched_text).input_ids
text = "Studies show that"
batched_text = [text] * BATCH_SIZE
decoder_input_ids = m.preprocess_input(batched_text).input_ids
decoder_input_ids = m.model._shift_right(decoder_input_ids)
test_input = (encoded_input_ids, decoder_input_ids)
actual_out = m.forward(*test_input)
return m, test_input, actual_out
##################### TensorFlow Keras Resnet Models #########################################################
# Static shape, including batch size (1).
# Can be dynamic once dynamic shape support is ready.
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
RESNET_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
EFFICIENTNET_V2_S_INPUT_SHAPE = [BATCH_SIZE, 384, 384, 3]
EFFICIENTNET_B0_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
EFFICIENTNET_B7_INPUT_SHAPE = [BATCH_SIZE, 600, 600, 3]
class ResNetModule(tf.Module):
@@ -195,25 +362,79 @@ class ResNetModule(tf.Module):
return tf.keras.applications.resnet50.preprocess_input(image)
class EfficientNetModule(tf.Module):
class EfficientNetB0Module(tf.Module):
def __init__(self):
super(EfficientNetModule, self).__init__()
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
super(EfficientNetB0Module, self).__init__()
self.m = tf.keras.applications.efficientnet.EfficientNetB0(
weights="imagenet",
include_top=True,
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
input_shape=tuple(EFFICIENTNET_B0_INPUT_SHAPE[1:]),
)
self.m.predict = lambda x: self.m.call(x, training=False)
@tf.function(
input_signature=[tf.TensorSpec(EFFICIENTNET_INPUT_SHAPE, tf.float32)],
input_signature=[
tf.TensorSpec(EFFICIENTNET_B0_INPUT_SHAPE, tf.float32)
],
jit_compile=True,
)
def forward(self, inputs):
return self.m.predict(inputs)
def input_shape(self):
return EFFICIENTNET_INPUT_SHAPE
return EFFICIENTNET_B0_INPUT_SHAPE
def preprocess_input(self, image):
return tf.keras.applications.efficientnet.preprocess_input(image)
class EfficientNetB7Module(tf.Module):
def __init__(self):
super(EfficientNetB7Module, self).__init__()
self.m = tf.keras.applications.efficientnet.EfficientNetB7(
weights="imagenet",
include_top=True,
input_shape=tuple(EFFICIENTNET_B7_INPUT_SHAPE[1:]),
)
self.m.predict = lambda x: self.m.call(x, training=False)
@tf.function(
input_signature=[
tf.TensorSpec(EFFICIENTNET_B7_INPUT_SHAPE, tf.float32)
],
jit_compile=True,
)
def forward(self, inputs):
return self.m.predict(inputs)
def input_shape(self):
return EFFICIENTNET_B7_INPUT_SHAPE
def preprocess_input(self, image):
return tf.keras.applications.efficientnet.preprocess_input(image)
class EfficientNetV2SModule(tf.Module):
def __init__(self):
super(EfficientNetV2SModule, self).__init__()
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
weights="imagenet",
include_top=True,
input_shape=tuple(EFFICIENTNET_V2_S_INPUT_SHAPE[1:]),
)
self.m.predict = lambda x: self.m.call(x, training=False)
@tf.function(
input_signature=[
tf.TensorSpec(EFFICIENTNET_V2_S_INPUT_SHAPE, tf.float32)
],
jit_compile=True,
)
def forward(self, inputs):
return self.m.predict(inputs)
def input_shape(self):
return EFFICIENTNET_V2_S_INPUT_SHAPE
def preprocess_input(self, image):
return tf.keras.applications.efficientnet_v2.preprocess_input(image)
@@ -224,12 +445,17 @@ def load_image(path_to_image, width, height, channels):
image = tf.image.decode_image(image, channels=channels)
image = tf.image.resize(image, (width, height))
image = image[tf.newaxis, :]
image = tf.tile(image, [BATCH_SIZE, 1, 1, 1])
return image
def get_keras_model(modelname):
def get_keras_model(modelname, import_args):
if modelname == "efficientnet-v2-s":
model = EfficientNetModule()
model = EfficientNetV2SModule()
elif modelname == "efficientnet_b0":
model = EfficientNetB0Module()
elif modelname == "efficientnet_b7":
model = EfficientNetB7Module()
else:
model = ResNetModule()
@@ -256,7 +482,7 @@ import requests
# Create a set of input signature.
input_signature_img_cls = [
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
tf.TensorSpec(shape=[BATCH_SIZE, 3, 224, 224], dtype=tf.float32),
]
@@ -304,11 +530,14 @@ def preprocess_input_image(model_name):
)
# inputs: {'pixel_values': <tf.Tensor: shape=(1, 3, 224, 224), dtype=float32, numpy=array([[[[]]]], dtype=float32)>}
inputs = feature_extractor(images=image, return_tensors="tf")
inputs["pixel_values"] = tf.tile(
inputs["pixel_values"], [BATCH_SIZE, 1, 1, 1]
)
return [inputs[str(*inputs)]]
def get_causal_image_model(hf_name):
def get_causal_image_model(hf_name, import_args):
model = AutoModelImageClassfication(hf_name)
test_input = preprocess_input_image(hf_name)
# TFSequenceClassifierOutput(loss=None, logits=<tf.Tensor: shape=(1, 1000), dtype=float32, numpy=

View File

@@ -4,10 +4,8 @@ from shark.iree_utils._common import (
get_supported_device_list,
)
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
from parameterized import parameterized
from shark.shark_downloader import download_model
from shark.shark_inference import SharkInference
from shark.parser import shark_args
from parameterized import parameterized
import iree.compiler as ireec
import pytest
import unittest
@@ -15,8 +13,8 @@ import numpy as np
import csv
import tempfile
import os
import sys
import shutil
import multiprocessing
def load_csv_and_convert(filename, gen=False):
@@ -48,7 +46,9 @@ def load_csv_and_convert(filename, gen=False):
)
# This is a pytest workaround
if gen:
with open("tank/dict_configs.py", "w+") as out:
with open(
os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
) as out:
out.write("ALL = [\n")
for c in model_configs:
out.write(str(c) + ",\n")
@@ -68,7 +68,9 @@ def get_valid_test_params():
dynamic_list = (True, False)
# TODO: This is soooo ugly, but for some reason creating the dict at runtime
# results in strange pytest failures.
load_csv_and_convert("tank/all_models.csv", True)
load_csv_and_convert(
os.path.join(os.path.dirname(__file__), "all_models.csv"), True
)
from tank.dict_configs import ALL
config_list = ALL
@@ -135,9 +137,12 @@ class SharkModuleTester:
self.config = config
def create_and_check_module(self, dynamic, device):
shark_args.update_tank = self.update_tank
shark_args.force_update_tank = self.force_update_tank
shark_args.shark_prefix = self.shark_tank_prefix
shark_args.local_tank_cache = self.local_tank_cache
shark_args.force_update_tank = self.update_tank
shark_args.dispatch_benchmarks = self.benchmark_dispatches
if self.benchmark_dispatches is not None:
_m = self.config["model_name"].split("/")
_m.extend([self.config["framework"], str(dynamic), device])
@@ -161,17 +166,40 @@ class SharkModuleTester:
if "winograd" in self.config["flags"]:
shark_args.use_winograd = True
model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
tank_url=self.tank_url,
frontend=self.config["framework"],
)
import_config = {
"batch_size": self.batch_size,
}
from shark.shark_downloader import download_model
from shark.shark_inference import SharkInference
from tank.generate_sharktank import NoImportException
dl_gen_attempts = 2
for i in range(dl_gen_attempts):
try:
model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
frontend=self.config["framework"],
import_args=import_config,
)
except NoImportException as err:
pytest.xfail(
reason=f"Artifacts for this model/config must be generated locally. Please make sure {self.config['framework']} is installed."
)
except AssertionError as err:
if i < dl_gen_attempts - 1:
continue
else:
pytest.xfail(
"Generating OTF may require exiting the subprocess for files to be available."
)
break
is_bench = True if self.benchmark is not None else False
shark_module = SharkInference(
model,
device=device,
mlir_dialect=self.config["dialect"],
is_benchmark=self.benchmark,
is_benchmark=is_bench,
)
try:
@@ -185,6 +213,10 @@ class SharkModuleTester:
result = shark_module(func_name, inputs)
golden_out, result = self.postprocess_outputs(golden_out, result)
if self.tf32 == "true":
print("Validating with relaxed tolerances.")
atol = 1e-02
rtol = 1e-03
try:
np.testing.assert_allclose(
golden_out,
@@ -197,23 +229,31 @@ class SharkModuleTester:
self.save_reproducers()
if self.ci == True:
self.upload_repro()
if self.benchmark == True:
self.benchmark_module(shark_module, inputs, dynamic, device)
if self.benchmark is not None:
self.benchmark_module(
shark_module, inputs, dynamic, device, mode=self.benchmark
)
print(msg)
pytest.xfail(
reason=f"Numerics Mismatch: Use -s flag to print stderr during pytests."
)
if self.benchmark == True:
self.benchmark_module(shark_module, inputs, dynamic, device)
if self.benchmark is not None:
self.benchmark_module(
shark_module, inputs, dynamic, device, mode=self.benchmark
)
if self.save_repro == True:
self.save_reproducers()
def benchmark_module(self, shark_module, inputs, dynamic, device):
def benchmark_module(
self, shark_module, inputs, dynamic, device, mode="native"
):
model_config = {
"batch_size": self.batch_size,
}
shark_args.enable_tf32 = self.tf32
if shark_args.enable_tf32 == True:
shark_module.compile()
shark_args.enable_tf32 = False
shark_args.onnx_bench = self.onnx_bench
shark_module.shark_runner.benchmark_all_csv(
@@ -222,6 +262,8 @@ class SharkModuleTester:
dynamic,
device,
self.config["framework"],
import_args=model_config,
mode=mode,
)
def save_reproducers(self):
@@ -271,6 +313,9 @@ class SharkModuleTest(unittest.TestCase):
@parameterized.expand(param_list, name_func=shark_test_name_func)
def test_module(self, dynamic, device, config):
self.module_tester = SharkModuleTester(config)
self.module_tester.batch_size = self.pytestconfig.getoption(
"batchsize"
)
self.module_tester.benchmark = self.pytestconfig.getoption("benchmark")
self.module_tester.save_repro = self.pytestconfig.getoption(
"save_repro"
@@ -290,7 +335,12 @@ class SharkModuleTest(unittest.TestCase):
self.module_tester.update_tank = self.pytestconfig.getoption(
"update_tank"
)
self.module_tester.tank_url = self.pytestconfig.getoption("tank_url")
self.module_tester.force_update_tank = self.pytestconfig.getoption(
"force_update_tank"
)
self.module_tester.shark_tank_prefix = self.pytestconfig.getoption(
"tank_prefix"
)
self.module_tester.benchmark_dispatches = self.pytestconfig.getoption(
"benchmark_dispatches"
)
@@ -307,19 +357,26 @@ class SharkModuleTest(unittest.TestCase):
if config["xfail_vkm"] == "True" and device in ["metal", "vulkan"]:
pytest.xfail(reason=config["xfail_reason"])
if os.name == "nt" and "enabled_windows" not in config["xfail_other"]:
if (
self.pytestconfig.getoption("ci") == True
and os.name == "nt"
and "enabled_windows" not in config["xfail_other"]
):
pytest.xfail(reason="this model skipped on windows")
# Special cases that need to be marked.
if "macos" in config["xfail_other"] and device in [
"metal",
"vulkan",
]:
if get_vulkan_triple_flag() is not None:
if "m1-moltenvk-macos" in get_vulkan_triple_flag():
pytest.xfail(
reason="conv-related issue on MacStudio, returns VK_ERROR_DEVICE_LOST."
)
if (
"macos" in config["xfail_other"]
and device
in [
"metal",
"vulkan",
]
and sys.platform == "darwin"
):
pytest.skip(
reason="conv-related issue on MacStudio, returns VK_ERROR_DEVICE_LOST."
)
if (
config["model_name"]
in [
@@ -342,6 +399,10 @@ class SharkModuleTest(unittest.TestCase):
pytest.xfail(
reason="Numerics issues: https://github.com/nod-ai/SHARK/issues/476"
)
if config["framework"] == "tf" and self.module_tester.batch_size != 1:
pytest.xfail(
reason="Configurable batch sizes temp. unavailable for tensorflow models."
)
safe_name = (
f"{config['model_name']}_{config['framework']}_{dynamic}_{device}"
)

View File

@@ -19,3 +19,10 @@ facebook/convnext-tiny-224,img
google/vit-base-patch16-224,img
efficientnet-v2-s,keras
bert-large-uncased,hf
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq
efficientnet_b0,keras
efficientnet_b7,keras
gpt2,hf_causallm
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq

View File

@@ -1,4 +1,6 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
@@ -18,4 +20,4 @@ nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encod
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"

View File

@@ -1,4 +1,3 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
prompthero/openjourney,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A

tank_version.json Normal file
View File

@@ -0,0 +1,3 @@
{
"version": "2023-03-31_02d52bb"
}
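This file pins the default shark_tank prefix used by the downloader when --shark_prefix is not supplied; resolving it is just a JSON read relative to the package, roughly as sketched below (path handling is illustrative):

    import json
    import os

    def default_tank_prefix(repo_root="."):
        # Mirrors the lookup in shark_downloader: read tank_version.json and use
        # its "version" field as the gs://shark_tank/<prefix> directory name.
        with open(os.path.join(repo_root, "tank_version.json"), "r") as f:
            prefix = json.load(f)["version"]
        print(f"Checking for updates from gs://shark_tank/{prefix}")
        return prefix

    # default_tank_prefix() -> "2023-03-31_02d52bb"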