Compare commits

..

33 Commits

Author SHA1 Message Date
yzhang93
8f84258fb8 Fix check for use_tuned conditions (#1252) 2023-03-27 11:21:25 -07:00
Ean Garvey
7619e76bbd Disable and xfail some models that fail validation/compilation. (#1251)
* Rollback T5 models for torch as the inputs give some issues that aren't trivial to resolve
* xfail efficientnet-b0 on torch+cuda -- see CUDA requesting shared memory size larger than allowed size openxla/iree#12771
2023-03-27 12:42:53 -05:00
Daniel Garvey
9267eadbfa disable openjourney gen for nightly (#1249) 2023-03-27 11:55:34 -05:00
Phaneesh Barwaria
431132b8ee Fix img2img mode switch (#1247)
* add updated scheduler value in global config

* clear scheduler global variable with others
2023-03-27 07:01:22 -07:00
cstueckrath
fb35e13e7a fix Python version detection bug (#1246)
* fix Python version detection bug

* Update setup_venv.ps1
2023-03-27 07:00:40 -07:00
yzhang93
17a67897d1 Add SD v2.1 768x768 tuned model (#1244)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-24 10:39:15 -07:00
Gaurav Shukla
da449b73aa [SD] Disable lora training tab for now (#1241)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-24 09:16:24 -07:00
Kyle Herndon
0b0526699a Fix incorrect device argument initialization for LoRA training by extracting the device type and number and formatting it for pytorch (#1237)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
2023-03-24 01:10:50 -07:00
Boian Petkantchin
4fac46f7bb In models testing fix paths to be relative to the script dir not cwd (#1128)
authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-03-22 15:26:52 -05:00
Daniel Garvey
49925950f1 fix false positives (#1193) 2023-03-22 15:25:39 -05:00
Thomas
807947c0c8 Remove deprecated cli option iree-hal-cuda-disable-loop-nounroll-wa (#1235) 2023-03-22 12:05:15 -05:00
Abhishek Varma
593428bda4 [SD] Fix for transformers/__init__.py issue in PyInstaller (#1233)
-- This commit fixes the transformers/__init__.py issue in PyInstaller.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-22 08:43:53 -07:00
Abhishek Varma
cede9b4fec [SD] Fix custom_vae as a required parameter in inpaint (#1232) 2023-03-22 04:30:17 -07:00
Prashant Kumar
c2360303f0 Add the int8 quantized model. 2023-03-22 16:28:13 +05:30
jinchen62
420366c1b8 Move schedulers to global obj (#1225) 2023-03-21 22:40:43 -07:00
Ean Garvey
d31bae488c Set iree-input-type to tm_tensor for SD (#1228) 2023-03-21 19:07:31 -07:00
Kyle Herndon
c23fcf3748 Fix incorrect device argument initialization for LoRA training (#1231)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-21 19:07:18 -07:00
jinchen62
7dbbb1726a Fix SD obj not defined if fail to get models from pretrained (#1222) 2023-03-21 07:55:17 -07:00
Abhishek Varma
8b8cc7fd33 [SD] Update LoRA inference to handle various checkpoints (#1215) 2023-03-21 06:52:20 -07:00
Ean Garvey
e3c96a2b9d Move sentencepiece to importer requirements. (#1218) 2023-03-21 00:39:57 -05:00
Ean Garvey
5e3f50647d Set --vulkan_large_heap_block_size default to 2gb. (#1220) 2023-03-20 21:07:09 -07:00
gpetters94
7899e1803a Add fix for attention slicing fp16 (#1217) 2023-03-20 19:11:29 -07:00
mariecwhite
d105246b9c Fix t5 models 2023-03-21 10:39:59 +11:00
mariecwhite
90c958bca2 Add T5-base and T5-large Torch and TF Models (#1116) 2023-03-20 17:32:50 -05:00
mariecwhite
f99903e023 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:22:05 +11:00
mariecwhite
c6f44ef1b3 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:14:45 +11:00
mariecwhite
8dcd4d5aeb Make batch size configurable 2023-03-20 18:03:17 -04:00
Phoenix Meadowlark
d319f4684e Add peak memory reporting for IREE, TF and PyTorch (#1216) 2023-03-20 15:40:49 -05:00
Ean Garvey
54d7b6d83e Generate model artifacts in pytests if they don't exist in the cloud. (#1121)
* Add gen_shark_files fn to shark_downloader for OTF artifact generation

* add generate_sharktank as a tank/ python module.

* Fix some paths in tank generation.
2023-03-20 12:13:19 -05:00
m68k-fr
4a622532e5 [Web] Stop images (#1212) 2023-03-19 14:37:30 -07:00
cstueckrath
650b2ada58 add pytorch_lightning to requirements (#1211)
* add pytorch_lightning to requirements

this will additionally add lightning-utilities and torchmetrics

* Update shark_sd.spec

* Update shark_sd_cli.spec
2023-03-19 12:29:54 -07:00
m68k-fr
f87f8949f3 [Web] CSS fix for gradio V3.22.1 (#1210) 2023-03-19 06:13:59 -07:00
m68k-fr
7dc9bf8148 [Web] Move "stop Batch" button to "Advanced Options" toggle (#1209) 2023-03-18 20:54:42 -07:00
52 changed files with 1310 additions and 592 deletions

View File

@@ -2,6 +2,7 @@ import sys
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -12,10 +13,9 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -76,11 +76,13 @@ def img2img_inf(
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
@@ -96,10 +98,6 @@ def img2img_inf(
image = init_image.convert("RGB")
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -114,14 +112,9 @@ def img2img_inf(
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -155,7 +148,7 @@ def img2img_inf(
height,
width,
device,
use_lora=use_lora,
use_lora=args.use_lora,
use_stencil=use_stencil,
)
if (
@@ -164,6 +157,7 @@ def img2img_inf(
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -178,8 +172,8 @@ def img2img_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
args.use_tuned = False
@@ -200,7 +194,7 @@ def img2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
)
)
else:
@@ -220,11 +214,11 @@ def img2img_inf(
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(args.scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -232,6 +226,7 @@ def img2img_inf(
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"STRENGTH": strength}
text_output = ""
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
@@ -252,23 +247,20 @@ def img2img_inf(
cpu_scheduling,
use_stencil=use_stencil,
)
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, strength={args.strength}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -10,10 +11,9 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -47,11 +47,13 @@ def inpaint_inf(
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
@@ -62,10 +64,6 @@ def inpaint_inf(
args.mask_path = "not none"
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -80,14 +78,9 @@ def inpaint_inf(
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -104,7 +97,7 @@ def inpaint_inf(
height,
width,
device,
use_lora=use_lora,
use_lora=args.use_lora,
use_stencil=None,
)
if (
@@ -114,6 +107,7 @@ def inpaint_inf(
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -128,14 +122,15 @@ def inpaint_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
@@ -143,14 +138,13 @@ def inpaint_inf(
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -159,6 +153,7 @@ def inpaint_inf(
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
@@ -180,23 +175,20 @@ def inpaint_inf(
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":
@@ -229,6 +221,7 @@ if __name__ == "__main__":
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
@@ -236,7 +229,6 @@ if __name__ == "__main__":
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
@@ -10,10 +11,9 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -50,11 +50,13 @@ def outpaint_inf(
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
@@ -64,10 +66,6 @@ def outpaint_inf(
args.img_path = "not none"
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -82,14 +80,9 @@ def outpaint_inf(
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -106,7 +99,7 @@ def outpaint_inf(
height,
width,
device,
use_lora=use_lora,
use_lora=args.use_lora,
use_stencil=None,
)
if (
@@ -116,6 +109,7 @@ def outpaint_inf(
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -130,8 +124,8 @@ def outpaint_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
@@ -146,11 +140,11 @@ def outpaint_inf(
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=use_lora,
use_lora=args.use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -163,6 +157,7 @@ def outpaint_inf(
top = True if "up" in directions else False
bottom = True if "down" in directions else False
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
@@ -189,23 +184,20 @@ def outpaint_inf(
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":

View File

@@ -159,9 +159,6 @@ class LoraDataset(Dataset):
return example
schedulers = None
########## Setting up the model ##########
def lora_train(
prompt: str,
@@ -187,8 +184,6 @@ def lora_train(
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
print(
"Note LoRA training is not compatible with the latest torch-mlir branch"
)
@@ -227,7 +222,12 @@ def lora_train(
args.max_length = max_length
args.height = height
args.width = width
args.device = device
device_str = device.split("=>", 1)[1].strip().split("://")
if len(device_str) > 1:
device_str = device_str[0] + ":" + device_str[1]
else:
device_str = device_str[0]
args.device = device_str
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(

View File

@@ -1,4 +1,5 @@
import torch
import transformers
import time
from apps.stable_diffusion.src import (
args,
@@ -9,10 +10,9 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -43,11 +43,13 @@ def txt2img_inf(
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
@@ -56,10 +58,6 @@ def txt2img_inf(
args.scheduler = scheduler
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -77,14 +75,9 @@ def txt2img_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -98,7 +91,7 @@ def txt2img_inf(
height,
width,
device,
use_lora=use_lora,
use_lora=args.use_lora,
use_stencil=None,
)
if (
@@ -108,6 +101,7 @@ def txt2img_inf(
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -123,8 +117,8 @@ def txt2img_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
@@ -141,17 +135,18 @@ def txt2img_inf(
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
@@ -169,25 +164,20 @@ def txt2img_inf(
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += (
f"\nsteps={steps}, guidance_scale={guidance_scale}, seed={seeds}"
)
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
# text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":
@@ -200,7 +190,6 @@ if __name__ == "__main__":
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
use_lora = args.use_lora
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
@@ -216,7 +205,8 @@ if __name__ == "__main__":
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
)
for current_batch in range(args.batch_count):

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
@@ -12,8 +13,6 @@ from apps.stable_diffusion.src import (
)
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -41,15 +40,16 @@ def upscaler_inf(
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -62,10 +62,6 @@ def upscaler_inf(
image = init_image.convert("RGB").resize((height, width))
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -83,6 +79,10 @@ def upscaler_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
args.height = 128
@@ -97,7 +97,7 @@ def upscaler_inf(
args.height,
args.width,
device,
use_lora=None,
use_lora=args.use_lora,
use_stencil=None,
)
if (
@@ -118,8 +118,8 @@ def upscaler_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
@@ -135,11 +135,14 @@ def upscaler_inf(
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.get_sd_obj().low_res_scheduler = schedulers["DDPM"]
global_obj.set_sd_scheduler(scheduler)
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
"DDPM"
)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -232,6 +235,7 @@ if __name__ == "__main__":
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ddpm_scheduler=schedulers["DDPM"],
)

View File

@@ -21,6 +21,7 @@ datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')

View File

@@ -22,6 +22,7 @@ datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')

View File

@@ -18,6 +18,7 @@ from apps.stable_diffusion.src.utils import (
get_path_stem,
get_extended_name,
get_stencil_model_id,
update_lora_weight,
)
@@ -99,7 +100,8 @@ class SharkifyStableDiffusionModel:
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = ""
use_lora: str = "",
use_quantize: str = None,
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -107,6 +109,7 @@ class SharkifyStableDiffusionModel:
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
@@ -200,6 +203,7 @@ class SharkifyStableDiffusionModel:
use_tuned=self.use_tuned,
model_name=self.model_name["vae_encode"],
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_vae_encode
@@ -255,6 +259,7 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_vae
@@ -281,13 +286,14 @@ class SharkifyStableDiffusionModel:
use_tuned=self.use_tuned,
model_name=self.model_name["vae"],
extra_args=get_opt_flags("vae", precision="fp32"),
base_model_id=self.base_model_id,
)
return shark_vae
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -295,6 +301,8 @@ class SharkifyStableDiffusionModel:
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
@@ -333,6 +341,7 @@ class SharkifyStableDiffusionModel:
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_controlled_unet
@@ -386,6 +395,7 @@ class SharkifyStableDiffusionModel:
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_cnet
@@ -399,7 +409,7 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
self.unet.load_attn_procs(use_lora)
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
@@ -444,6 +454,7 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_unet
@@ -481,18 +492,21 @@ class SharkifyStableDiffusionModel:
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_unet
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
def forward(self, input):
return self.text_encoder(input)[0]
@@ -512,6 +526,7 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
)
return shark_clip
@@ -539,6 +554,7 @@ class SharkifyStableDiffusionModel:
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
# configiration.
def compile_all(self, base_model_id, need_vae_encode, need_stencil):
self.base_model_id = base_model_id
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
@@ -556,7 +572,12 @@ class SharkifyStableDiffusionModel:
compiled_controlnet = self.get_control_net()
compiled_controlled_unet = self.get_controlled_unet()
else:
compiled_unet = self.get_unet()
# TODO: Plug the experimental "int8" support at right place.
if self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
compiled_unet = get_unet()
else:
compiled_unet = self.get_unet()
if self.custom_vae != "":
print("Plugging in custom Vae")
compiled_vae = self.get_vae()

View File

@@ -20,6 +20,15 @@ hf_model_variant_map = {
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
    """Return the artifact location and flags for the experimental int8 UNet.

    The quantized UNet is only published for a 512x512 resolution with a
    max token length of 77; any other configuration aborts with an error.

    Returns:
        tuple: (GCS bucket key, model key, IREE opt flags list).
    """
    bucket_key = "gs://shark_tank/prashant_nod"
    model_key = "unet_int8"
    # The quantized artifact reuses the fp16 optimization flags.
    iree_flags = get_opt_flags("unet", precision="fp16")
    # BUG FIX: the original used `and`, which only aborted when ALL THREE
    # values differed from the supported configuration (e.g. height=768,
    # width=512, max_length=77 slipped through). Any single mismatch must
    # abort, hence `or`.
    if args.height != 512 or args.width != 512 or args.max_length != 77:
        sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
    return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
    # Map a HuggingFace model ID to its (variant, version) pair via the
    # module-level hf_model_variant_map; raises KeyError for unknown IDs.
    return hf_model_variant_map[hf_model_id]
@@ -41,6 +50,12 @@ def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
# TODO: Get the quantize model from model_db.json
if args.use_quantize == "int8":
bk, mk, flags = get_quantize_model()
return get_shark_model(bk, mk, flags)
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"

View File

@@ -31,6 +31,9 @@ from apps.stable_diffusion.src.utils import (
end_profiling,
)
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
class StableDiffusionPipeline:
def __init__(
@@ -58,6 +61,7 @@ class StableDiffusionPipeline:
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -226,6 +230,7 @@ class StableDiffusionPipeline:
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
@@ -275,6 +280,9 @@ class StableDiffusionPipeline:
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -313,6 +321,7 @@ class StableDiffusionPipeline:
use_stencil: str = None,
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
is_inpaint = cls.__name__ in [
"InpaintPipeline",
@@ -341,6 +350,7 @@ class StableDiffusionPipeline:
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
use_quantize=use_quantize,
)
if cls.__name__ in [
"Image2ImagePipeline",

View File

@@ -32,4 +32,6 @@ from apps.stable_diffusion.src.utils.utils import (
get_extended_name,
clear_all,
save_output_img,
get_generation_text_info,
update_lora_weight,
)

View File

@@ -76,18 +76,19 @@ def load_winograd_configs():
return winograd_config_dir
def load_lower_configs():
def load_lower_configs(base_model_id=None):
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
if not base_model_id:
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
@@ -114,7 +115,14 @@ def load_lower_configs():
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
and args.width == 768
):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
@@ -212,7 +220,7 @@ def annotate_with_lower_configs(
return bytecode
def sd_model_annotation(mlir_model, model_name):
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
@@ -220,7 +228,7 @@ def sd_model_annotation(mlir_model, model_name):
winograd_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
)
@@ -232,7 +240,7 @@ def sd_model_annotation(mlir_model, model_name):
)
else:
use_winograd = False
lowering_config_dir = load_lower_configs()
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)

View File

@@ -340,6 +340,14 @@ p.add_argument(
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
@@ -360,7 +368,7 @@ p.add_argument(
p.add_argument(
"--vulkan_large_heap_block_size",
default="4147483648",
default="2073741824",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)

View File

@@ -9,6 +9,8 @@ from pathlib import Path
import numpy as np
from random import randint
import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
@@ -21,7 +23,7 @@ from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
download_from_original_stable_diffusion_ckpt,
)
@@ -78,7 +80,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
mlir_model, device=args.device, mlir_dialect="tm_tensor"
)
return _compile_module(shark_module, model_name, extra_args)
@@ -95,6 +97,7 @@ def compile_through_fx(
debug=False,
generate_vmfb=True,
extra_args=[],
base_model_id=None,
):
from shark.parser import shark_args
@@ -116,19 +119,21 @@ def compile_through_fx(
if use_tuned:
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
mlir_module = sd_model_annotation(mlir_module, model_name)
mlir_module = sd_model_annotation(
mlir_module, model_name, base_model_id
)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
mlir_dialect="tm_tensor",
)
if generate_vmfb:
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
mlir_dialect="tm_tensor",
)
del mlir_module
gc.collect()
@@ -264,8 +269,9 @@ def set_init_device_flags():
if (
args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
@@ -299,6 +305,20 @@ def set_init_device_flags():
]:
args.use_tuned = False
elif (
args.height == 768
and args.width == 768
and (
base_model_id
not in [
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
]
or "rdna3" not in args.iree_vulkan_target_triple
)
):
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
else:
@@ -454,7 +474,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
num_in_channels = 9 if is_inpaint else 4
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
@@ -464,6 +484,115 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
print("Loading complete")
def processLoRA(model, use_lora, splitting_prefix):
    """Merge standalone LoRA weights from *use_lora* directly into *model*.

    Loads a LoRA state dict (.safetensors via safetensors, anything else via
    torch.load) and adds ``alpha * up @ down`` to the weight of each matching
    layer in-place. ``splitting_prefix`` selects which half of the state dict
    applies: a prefix without "te" (e.g. "lora_unet_") processes UNet keys,
    one with "te" (e.g. "lora_te_") processes text-encoder keys.

    Returns the mutated *model*.
    """
    state_dict = ""
    if ".safetensors" in use_lora:
        state_dict = load_file(use_lora)
    else:
        state_dict = torch.load(use_lora)
    # Fixed merge strength for the LoRA delta. TODO confirm whether this
    # should be user-configurable; 0.75 matches common merge scripts.
    alpha = 0.75
    visited = []
    # directly update weight in model
    # UNet keys carry no "text" marker; text-encoder keys do.
    process_unet = "te" not in splitting_prefix
    for key in state_dict:
        # Skip alpha scalars and the partner of an already-merged pair.
        if ".alpha" in key or key in visited:
            continue
        curr_layer = model
        if ("text" not in key and process_unet) or (
            "text" in key and not process_unet
        ):
            # Layer path is encoded in the key with "_" separators after
            # the prefix, e.g. "lora_unet_down_blocks_0_attentions_0...".
            layer_infos = (
                key.split(".")[0].split(splitting_prefix)[-1].split("_")
            )
        else:
            continue
        # find the target layer
        # Walk the model by attribute. Because layer names themselves may
        # contain "_", a failed getattr grows temp_name with the next
        # segment and retries; the loop condition is intentionally always
        # true (len >= 0 > -1) and exits via `break` once the path is
        # fully resolved. NOTE(review): a malformed key would exhaust
        # layer_infos and raise IndexError here — presumably acceptable
        # for trusted checkpoints.
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)
        # Pair up the down/up projection keys, ordered [up, down].
        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))
        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            # Conv layer: factors are stored as (out, in, 1, 1); drop the
            # trailing singleton dims, matmul in fp32, then restore them.
            weight_up = (
                state_dict[pair_keys[0]]
                .squeeze(3)
                .squeeze(2)
                .to(torch.float32)
            )
            weight_down = (
                state_dict[pair_keys[1]]
                .squeeze(3)
                .squeeze(2)
                .to(torch.float32)
            )
            curr_layer.weight.data += alpha * torch.mm(
                weight_up, weight_down
            ).unsqueeze(2).unsqueeze(3)
        else:
            # Linear layer: plain matrix product of the two factors.
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
        # update visited list
        for item in pair_keys:
            visited.append(item)
    return model
def update_lora_weight_for_unet(unet, use_lora):
    """Load standalone LoRA weights into *unet*.

    *use_lora* is either a HuggingFace model ID (no file extension) or a
    local checkpoint path ending in .bin/.safetensors/.pt. Tries diffusers'
    ``load_attn_procs`` first and falls back to manually merging the weights
    via ``processLoRA`` when that fails.

    Returns the updated UNet.
    """
    extensions = [".bin", ".safetensors", ".pt"]
    # IDIOM: generator instead of a materialized list inside any().
    if not any(extension in use_lora for extension in extensions):
        # We assume it is a HF ID with standalone LoRA weights.
        unet.load_attn_procs(use_lora)
        return unet
    main_file_name = get_path_stem(use_lora)
    if ".bin" in use_lora:
        main_file_name += ".bin"
    elif ".safetensors" in use_lora:
        main_file_name += ".safetensors"
    elif ".pt" in use_lora:
        main_file_name += ".pt"
    else:
        # Defensive: unreachable given the extension check above.
        # BUG FIX: the message previously omitted the supported .pt format.
        sys.exit("Only .bin, .safetensors and .pt formats for LoRA are supported")
    try:
        dir_name = os.path.dirname(use_lora)
        unet.load_attn_procs(dir_name, weight_name=main_file_name)
        return unet
    except Exception:
        # BUG FIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. Any load failure falls back to a manual merge.
        return processLoRA(unet, use_lora, "lora_unet_")
def update_lora_weight(model, use_lora, model_name):
    """Dispatch LoRA loading to the handler matching *model_name*.

    UNet models go through ``update_lora_weight_for_unet``; everything else
    (text encoders) is merged in-place with the "lora_te_" prefix.
    Returns the updated model, or None if the text-encoder merge fails
    (callers treat None as "no LoRA applied").
    """
    if "unet" in model_name:
        return update_lora_weight_for_unet(model, use_lora)
    try:
        return processLoRA(model, use_lora, "lora_te_")
    except Exception:
        # BUG FIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. The silent None on failure is preserved since
        # callers rely on it as the best-effort contract.
        return None
def load_vmfb(vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
@@ -629,3 +758,14 @@ def save_output_img(output_img, img_seed, extra_info={}):
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
def get_generation_text_info(seeds, device):
    """Build the multi-line, human-readable summary of generation settings.

    Pulls prompts/model/scheduler/sampling parameters from the global
    ``args`` and combines them with the given *seeds* and *device* label.
    """
    summary_lines = [
        f"prompt={args.prompts}",
        f"negative prompt={args.negative_prompts}",
        f"model_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}",
        f"scheduler={args.scheduler}, device={device}",
        f"steps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}",
        f"size={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}",
    ]
    return "\n".join(summary_lines)

View File

@@ -1,5 +1,6 @@
import os
import sys
import transformers
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
@@ -66,7 +67,7 @@ from apps.stable_diffusion.web.ui import (
)
# init global sd pipeline and config
global_obj.init()
global_obj._init()
def register_button_click(button, selectedid, inputs, outputs):
@@ -94,6 +95,8 @@ with gr.Blocks(
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.Tabs(visible=False) as experimental_tabs:
with gr.TabItem(label="LoRA Training", id=5):
lora_train_web.render()

View File

@@ -7,93 +7,100 @@ Procedure to upgrade the dark theme:
*/
:root {
--color-accent-soft: var(--neutral-700);
--color-background-primary: var(--neutral-950);
--color-background-secondary: var(--neutral-900);
--color-border-accent: var(--neutral-600);
--color-border-primary: var(--neutral-700);
--text-color-code-background: var(--neutral-800);
--text-color-link-active: var(--secondary-500);
--text-color-link: var(--secondary-500);
--text-color-link-hover: var(--secondary-400);
--text-color-link-visited: var(--secondary-600);
--text-color-subdued: var(--neutral-400);
--body-background-color: var(--color-background-primary);
--body-background-fill: var(--background-fill-primary);
--body-text-color: var(--neutral-100);
--color-accent-soft: var(--neutral-700);
--background-fill-primary: var(--neutral-950);
--background-fill-secondary: var(--neutral-900);
--border-color-accent: var(--neutral-600);
--border-color-primary: var(--neutral-700);
--link-text-color-active: var(--secondary-500);
--link-text-color: var(--secondary-500);
--link-text-color-hover: var(--secondary-400);
--link-text-color-visited: var(--secondary-600);
--body-text-color-subdued: var(--neutral-400);
--shadow-spread: 1px;
--block-background: var(--neutral-800);
--block-border-color: var(--color-border-primary);
--block-border-width: 1px;
--block-info-color: var(--text-color-subdued);
--block-label-background: var(--color-background-secondary);
--block-label-border-color: var(--color-border-primary);
--block-label-border-width: 1px;
--block-label-color: var(--neutral-200);
--block-shadow: none;
--block-title-background: none;
--block-title-border-color: none;
--block-title-border-width: 0px;
--block-title-color: var(--neutral-200);
--panel-background: var(--color-background-secondary);
--panel-border-color: var(--color-border-primary);
--checkbox-background: var(--neutral-800);
--checkbox-background-focus: var(--checkbox-background);
--checkbox-background-hover: var(--checkbox-background);
--checkbox-background-selected: var(--secondary-600);
--block-background-fill: var(--neutral-800);
--block-border-color: var(--border-color-primary);
--block_border_width: None;
--block-info-text-color: var(--body-text-color-subdued);
--block-label-background-fill: var(--background-fill-secondary);
--block-label-border-color: var(--border-color-primary);
--block_label_border_width: None;
--block-label-text-color: var(--neutral-200);
--block_shadow: None;
--block_title_background_fill: None;
--block_title_border_color: None;
--block_title_border_width: None;
--block-title-text-color: var(--neutral-200);
--panel-background-fill: var(--background-fill-secondary);
--panel-border-color: var(--border-color-primary);
--panel_border_width: None;
--checkbox-background-color: var(--neutral-800);
--checkbox-background-color-focus: var(--checkbox-background-color);
--checkbox-background-color-hover: var(--checkbox-background-color);
--checkbox-background-color-selected: var(--secondary-600);
--checkbox-border-color: var(--neutral-700);
--checkbox-border-color-focus: var(--secondary-500);
--checkbox-border-color-hover: var(--neutral-600);
--checkbox-border-color-selected: var(--secondary-600);
--checkbox-label-background: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-selected: var(--checkbox-label-background);
--checkbox-label-border-color: var(--color-border-primary);
--checkbox-label-border-color-hover: var(--color-border-primary);
--checkbox-text-color: var(--body-text-color);
--checkbox-text-color-selected: var(--checkbox-text-color);
--error-background: var(--color-background-primary);
--error-border-color: var(--color-border-primary);
--error-border-width: var(--error-border-width);
--error-color: #ef4444;
--input-background: var(--neutral-800);
--input-background-focus: var(--secondary-600);
--input-background-hover: var(--input-background);
--input-border-color: var(--color-border-primary);
--checkbox-border-width: var(--input-border-width);
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
--checkbox-label-border-color: var(--border-color-primary);
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
--checkbox-label-border-width: var(--input-border-width);
--checkbox-label-text-color: var(--body-text-color);
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
--error-background-fill: var(--background-fill-primary);
--error-border-color: var(--border-color-primary);
--error_border_width: None;
--error-text-color: #ef4444;
--input-background-fill: var(--neutral-800);
--input-background-fill-focus: var(--secondary-600);
--input-background-fill-hover: var(--input-background-fill);
--input-border-color: var(--border-color-primary);
--input-border-color-focus: var(--neutral-700);
--input-border-color-hover: var(--color-border-primary);
--input-border-color-hover: var(--input-border-color);
--input_border_width: None;
--input-placeholder-color: var(--neutral-500);
--input-shadow: var(--input-shadow);
--input_shadow: None;
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
--loader-color: var(--color-accent);
--stat-color-background: linear-gradient(to right, var(--primary-400), var(--primary-600));
--loader_color: None;
--slider_color: None;
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
--table-border-color: var(--neutral-700);
--table-even-background: var(--neutral-950);
--table-odd-background: var(--neutral-900);
--table-even-background-fill: var(--neutral-950);
--table-odd-background-fill: var(--neutral-900);
--table-row-focus: var(--color-accent-soft);
--button-cancel-background: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-border-width: var(--input-border-width);
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-cancel-border-color: #dc2626;
--button-cancel-border-color-hover: var(--button-cancel-border-color);
--button-cancel-text-color: white;
--button-cancel-text-color-hover: var(--button-cancel-text-color);
--button-primary-background: linear-gradient(to bottom right, var(--primary-600), var(--primary-700));
--button-primary-background-hover: linear-gradient(to bottom right, var(--primary-600), var(--primary-600));
--button-primary-border-color: var(--primary-600);
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
--button-primary-border-color: var(--primary-500);
--button-primary-border-color-hover: var(--button-primary-border-color);
--button-primary-text-color: white;
--button-primary-text-color-hover: var(--button-primary-text-color);
--button-secondary-background: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-border-color: var(--neutral-600);
--button-secondary-border-color-hover: var(--button-secondary-border-color);
--button-secondary-text-color: white;
--button-secondary-text-color-hover: var(--button-secondary-text-color);
--block-border-width: 1px;
--block-label-border-width: 1px;
--form-gap-width: 1px;
--error-border-width: 1px;
--input-border-width: 1px;
}
/* SHARK theme */
body {
background-color: var(--color-background-primary);
}
/* display in full width for desktop devices */
@media (min-width: 1536px)
@@ -131,7 +138,7 @@ body {
}
#prompt_box textarea, #negative_prompt_box textarea {
background-color: var(--color-background-primary) !important;
background-color: var(--background-fill-primary) !important;
}
#prompt_examples {
@@ -143,7 +150,6 @@ body {
}
#ui_body {
background-color: var(--color-background-secondary) !important;
padding: var(--size-2) !important;
border-radius: 0.5em !important;
}
@@ -172,6 +178,7 @@ footer {
/* Hide "remove buttons" from ui dropdowns */
#custom_model .token-remove.remove-all,
#lora_weights .token-remove.remove-all,
#scheduler .token-remove.remove-all,
#device .token-remove.remove-all,
#stencil_model .token-remove.remove-all {

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list,
predefined_models,
cancel_sd,
)
@@ -76,10 +77,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -144,21 +145,23 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Denoising Strength",
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -168,6 +171,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -189,8 +193,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -253,6 +255,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list,
predefined_paint_models,
cancel_sd,
)
@@ -71,10 +72,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -146,21 +147,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -170,6 +173,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -191,8 +195,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -255,6 +257,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -111,22 +111,25 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
label="CFG Scale",
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -148,8 +151,6 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
)
with gr.Column(scale=6):
train_lora = gr.Button("Train LoRA")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
@@ -201,4 +202,4 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
prompt_submit = prompt.submit(**kwargs)
train_click = train_lora.click(**kwargs)
clear_queue.click(fn=None, cancels=[prompt_submit, train_click])
stop_batch.click(fn=None, cancels=[prompt_submit, train_click])

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list,
predefined_paint_models,
cancel_sd,
)
@@ -68,10 +69,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -165,21 +166,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
1, 100, value=20, step=1, label="Steps"
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -189,6 +192,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -210,8 +214,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -275,6 +277,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list_txt2img,
predefined_models,
cancel_sd,
)
with gr.Blocks(title="Text-to-Image") as txt2img_web:
@@ -72,10 +73,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -140,22 +141,25 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="CFG Scale",
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -177,8 +181,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
@@ -247,8 +249,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
from apps.stable_diffusion.web.utils.png_metadata import (

View File

@@ -65,6 +65,21 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -129,21 +144,23 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Noise Level",
)
with gr.Row():
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
@@ -153,6 +170,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -174,8 +192,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -225,6 +241,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
],
outputs=[upscaler_gallery, std_output],
show_progress=args.progress_bar,
@@ -233,6 +251,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
)

View File

@@ -5,6 +5,10 @@ import glob
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
@dataclass
@@ -70,24 +74,56 @@ def resource_path(relative_path):
return os.path.join(base_path, relative_path)
def get_custom_model_path():
return Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd(), "models")
def get_custom_model_path(model="models"):
match model:
case "models":
return Path(Path.cwd(), "models")
case "vae":
return Path(Path.cwd(), "models/vae")
case "lora":
return Path(Path.cwd(), "models/lora")
case _:
return ""
def get_custom_model_pathfile(custom_model_name):
return os.path.join(get_custom_model_path(), custom_model_name)
def get_custom_model_pathfile(custom_model_name, model="models"):
return os.path.join(get_custom_model_path(model), custom_model_name)
def get_custom_model_files():
def get_custom_model_files(model="models"):
ckpt_files = []
for extn in custom_model_filetypes:
file_types = custom_model_filetypes
if model == "lora":
file_types = custom_model_filetypes + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_custom_model_path(), extn))
for x in glob.glob(
os.path.join(get_custom_model_path(model), extn)
)
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_custom_vae_or_lora_weights(weights, hf_id, model):
use_weight = ""
if weights == "None" and not hf_id:
use_weight = ""
elif not hf_id:
use_weight = get_custom_model_pathfile(weights, model)
else:
use_weight = hf_id
return use_weight
def cancel_sd():
# Try catch it, as gc can delete global_obj.sd_obj while switching model
try:
global_obj.set_sd_status(SD_STATE_CANCEL)
except Exception:
pass
nodlogo_loc = resource_path("logos/nod-logo.png")
available_devices = get_available_devices()

View File

@@ -8,39 +8,64 @@ Also we could avoid memory leak when switching models by clearing the cache.
"""
def init():
global sd_obj
global config_obj
sd_obj = None
config_obj = None
def _init():
global _sd_obj
global _config_obj
global _schedulers
_sd_obj = None
_config_obj = None
_schedulers = None
def set_sd_obj(value):
global sd_obj
sd_obj = value
global _sd_obj
_sd_obj = value
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
global _sd_obj
_sd_obj.status = value
def set_cfg_obj(value):
global config_obj
config_obj = value
global _config_obj
_config_obj = value
def set_schedulers(value):
global sd_obj
sd_obj.scheduler = value
global _schedulers
_schedulers = value
def get_sd_obj():
return sd_obj
return _sd_obj
def get_sd_status():
return _sd_obj.status
def get_cfg_obj():
return config_obj
return _config_obj
def get_scheduler(key):
return _schedulers[key]
def clear_cache():
global sd_obj
global config_obj
del sd_obj
del config_obj
global _sd_obj
global _config_obj
global _schedulers
del _sd_obj
del _config_obj
del _schedulers
gc.collect()
_sd_obj = None
_config_obj = None
_schedulers = None

View File

@@ -2,4 +2,4 @@
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python generate_sharktank.py
python tank/generate_sharktank.py

View File

@@ -87,11 +87,22 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]
counter = 0
for import_opt in import_options:
for model_name in hf_model_names:
if model_name in to_skip:
continue
for use_tune in tuned_options:
if (
model_name == "stabilityai/stable-diffusion-2-1"
and use_tune == tuned_options[0]
):
continue
elif (
model_name == "stabilityai/stable-diffusion-2-1-base"
and use_tune == tuned_options[1]
):
continue
command = (
[
executable, # executable is the python from the venv used to run this
@@ -174,9 +185,23 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
else:
print(command)
print("failed to generate image for this configuration")
if "2_1_base" in model_name:
with open(dumpfile_name, "r+") as f:
output = f.readlines()
print("\n".join(output))
if model_name == "CompVis/stable-diffusion-v1-4":
print("failed a known successful model.")
exit(1)
if os.name == "nt":
counter += 1
if counter % 2 == 0:
extra_flags.append(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
else:
if counter != 1:
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)

View File

@@ -6,36 +6,16 @@ from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Diffusers 0.13.1 fails with transformers __init.py errros in BLIP. So remove it for now until we fork it
pix2pix_init = Path(get_python_lib() + "/diffusers/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(get_python_lib() + "/diffusers/pipelines/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(
get_python_lib() + "/diffusers/pipelines/stable_diffusion/__init__.py"
# Temorary workaround for transformers/__init__.py.
path_to_tranformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
for line in fileinput.input(pix2pix_init, inplace=True):
if "StableDiffusionPix2PixZeroPipeline" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
if path_to_tranformers_hook.is_file():
pass
else:
with open(path_to_tranformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")

View File

@@ -33,6 +33,7 @@ lit
pyyaml
python-dateutil
sacremoses
sentencepiece
# web dependecies.
gradio

View File

@@ -16,7 +16,7 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers
diffusers @ git+https://github.com/huggingface/diffusers@main
scipy
ftfy
gradio
@@ -25,6 +25,7 @@ omegaconf
safetensors
opencv-python
scikit-image
pytorch_lightning # for runwayml models
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile

View File

@@ -45,7 +45,7 @@ if ($arguments -eq "--force"){
Remove-Item .\shark.venv -Force -Recurse
if (Test-Path .\shark.venv\) {
Write-Host 'could not remove .\shark-venv - please try running ".\setup_venv.ps1 --force" again!'
break
exit 1
}
}
}
@@ -78,12 +78,12 @@ if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is u
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
{
Write-Host "Please install Python 3.11 and try again"
break
exit 34
}
Write-Host "Installing Build Dependencies"
# make sure we really use 3.11 from list, even if it's not the default.
if (!($PyVer.length -ne 0)) {py -3.11 -m venv .\shark.venv\}
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
python -m pip install --upgrade pip

View File

@@ -35,8 +35,9 @@ def run_cmd(cmd, debug=False):
stderr=subprocess.PIPE,
check=True,
)
result_str = result.stdout.decode()
return result_str
stdout = result.stdout.decode()
stderr = result.stderr.decode()
return stdout, stderr
except subprocess.CalledProcessError as e:
print(e.output)
sys.exit(f"Exiting program due to error running {cmd}")

View File

@@ -90,6 +90,7 @@ def build_benchmark_args(
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
# if time_extractor:
# benchmark_cl.append(time_extractor)
benchmark_cl.append(f"--print_statistics=true")
return benchmark_cl
@@ -129,7 +130,8 @@ def build_benchmark_args_non_tensor_input(
def run_benchmark_module(benchmark_cl):
"""
Run benchmark command, extract result and return iteration/seconds.
Run benchmark command, extract result and return iteration/seconds, host
peak memory, and device peak memory.
# TODO: Add an example of the benchmark command.
Input: benchmark command.
@@ -138,15 +140,22 @@ def run_benchmark_module(benchmark_cl):
assert os.path.exists(
benchmark_path
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
bench_result = run_cmd(" ".join(benchmark_cl))
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
try:
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
match = regex_split.search(bench_stdout)
time_ms = float(match.group(1))
unit = match.group(3)
except AttributeError:
regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
match = regex_split.search(bench_stdout)
time_ms = float(match.group(1))
unit = match.group(2)
return 1.0 / (time * 0.001)
iter_per_second = 1.0 / (time_ms * 0.001)
# Extract peak memory.
host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
host_peak_b = int(host_regex.search(bench_stderr).group(1))
device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")
device_peak_b = int(device_regex.search(bench_stderr).group(1))
return iter_per_second, host_peak_b, device_peak_b

View File

@@ -52,7 +52,7 @@ def get_iree_device_args(device, extra_args=[]):
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg"]:
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
return ["--iree-llvmcpu-target-cpu-features=host"]
elif frontend in ["tensorflow", "tf", "mhlo"]:
return [
@@ -188,21 +188,23 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
benchmark_bash.write(" ".join(benchmark_cl))
benchmark_bash.close()
benchmark_data = run_benchmark_module(benchmark_cl)
iter_per_second, _, _ = run_benchmark_module(
benchmark_cl
)
benchmark_file = open(
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
)
benchmark_file.write(f"DISPATCH: {d_}\n")
benchmark_file.write(str(benchmark_data) + "\n")
benchmark_file.write(str(iter_per_second) + "\n")
benchmark_file.write(
"SHARK BENCHMARK RESULT: "
+ str(1 / (benchmark_data * 0.001))
+ str(1 / (iter_per_second * 0.001))
+ "\n"
)
benchmark_file.close()
benchmark_runtimes[d_] = 1 / (benchmark_data * 0.001)
benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001)
elif ".mlir" in f_ and "benchmark" not in f_:
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")

View File

@@ -30,11 +30,10 @@ def get_iree_gpu_args():
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
) and (shark_args.enable_tf32 == True):
return [
"--iree-hal-cuda-disable-loop-nounroll-wa",
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
]
else:
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
return []
# Get the default gpu args given the architecture.

View File

@@ -22,7 +22,8 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_vulkan_device_name():
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")

View File

@@ -21,9 +21,17 @@ from shark.iree_utils.benchmark_utils import (
from shark.parser import shark_args
from datetime import datetime
import time
from typing import Optional
import csv
import os
TF_CPU_DEVICE = "/CPU:0"
TF_GPU_DEVICE = "/GPU:0"
def _bytes_to_mb_str(bytes_: Optional[int]) -> str:
return "" if bytes_ is None else f"{bytes_ / 1e6:.6f}"
class OnnxFusionOptions(object):
def __init__(self):
@@ -126,18 +134,26 @@ class SharkBenchmarkRunner(SharkRunner):
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(input)
if self.device == "cuda":
torch.cuda.reset_peak_memory_stats()
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(input)
if i == shark_args.num_iterations - 1:
end = time.time()
break
end = time.time()
if self.device == "cuda":
stats = torch.cuda.memory_stats()
device_peak_b = stats["allocated_bytes.all.peak"]
else:
device_peak_b = None
print(
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
return [
f"{shark_args.num_iterations/(end-begin)}",
f"{((end-begin)/shark_args.num_iterations)*1000}",
"", # host_peak_b (CPU usage) is not reported by PyTorch.
_bytes_to_mb_str(device_peak_b),
]
def benchmark_tf(self, modelname):
@@ -155,8 +171,8 @@ class SharkBenchmarkRunner(SharkRunner):
from tank.model_utils_tf import get_tf_model
# tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
tf_device = "/CPU:0"
# tf_device = TF_GPU_DEVICE if self.device == "cuda" else TF_CPU_DEVICE
tf_device = TF_CPU_DEVICE
with tf.device(tf_device):
(
model,
@@ -169,24 +185,41 @@ class SharkBenchmarkRunner(SharkRunner):
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(*input)
if tf_device == TF_GPU_DEVICE:
tf.config.experimental.reset_memory_stats(tf_device)
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(*input)
if i == shark_args.num_iterations - 1:
end = time.time()
break
end = time.time()
if tf_device == TF_GPU_DEVICE:
memory_info = tf.config.experimental.get_memory_info(tf_device)
device_peak_b = memory_info["peak"]
else:
# tf.config.experimental does not currently support measuring
# CPU memory usage.
device_peak_b = None
print(
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
return [
f"{shark_args.num_iterations/(end-begin)}",
f"{((end-begin)/shark_args.num_iterations)*1000}",
"", # host_peak_b (CPU usage) is not reported by TensorFlow.
_bytes_to_mb_str(device_peak_b),
]
def benchmark_c(self):
result = run_benchmark_module(self.benchmark_cl)
print(f"Shark-IREE-C benchmark:{result} iter/second")
return [f"{result}", f"{1000/result}"]
iter_per_second, host_peak_b, device_peak_b = run_benchmark_module(
self.benchmark_cl
)
print(f"Shark-IREE-C benchmark:{iter_per_second} iter/second")
return [
f"{iter_per_second}",
f"{1000/iter_per_second}",
_bytes_to_mb_str(host_peak_b),
_bytes_to_mb_str(device_peak_b),
]
def benchmark_python(self, inputs):
input_list = [x for x in inputs]
@@ -196,8 +229,7 @@ class SharkBenchmarkRunner(SharkRunner):
begin = time.time()
for i in range(shark_args.num_iterations):
out = self.run("forward", input_list)
if i == shark_args.num_iterations - 1:
end = time.time()
end = time.time()
print(
f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
@@ -324,7 +356,12 @@ for currently supported models. Exiting benchmark ONNX."
"tags",
"notes",
"datetime",
"host_memory_mb",
"device_memory_mb",
"measured_host_memory_mb",
"measured_device_memory_mb",
]
# "frontend" must be the first element.
engines = ["frontend", "shark_python", "shark_iree_c"]
if shark_args.onnx_bench == True:
engines.append("onnxruntime")
@@ -336,75 +373,76 @@ for currently supported models. Exiting benchmark ONNX."
with open("bench_results.csv", mode="a", newline="") as f:
writer = csv.DictWriter(f, fieldnames=field_names)
bench_result = {}
bench_result["model"] = modelname
bench_info = {}
bench_info["model"] = modelname
bench_info["dialect"] = self.mlir_dialect
bench_info["iterations"] = shark_args.num_iterations
if dynamic == True:
bench_result["shape_type"] = "dynamic"
bench_info["shape_type"] = "dynamic"
else:
bench_result["shape_type"] = "static"
bench_result["device"] = device_str
bench_info["shape_type"] = "static"
bench_info["device"] = device_str
if "fp16" in modelname:
bench_result["data_type"] = "float16"
bench_info["data_type"] = "float16"
else:
bench_result["data_type"] = inputs[0].dtype
bench_info["data_type"] = inputs[0].dtype
for e in engines:
(
bench_result["param_count"],
bench_result["tags"],
bench_result["notes"],
) = ["", "", ""]
engine_result = {}
if e == "frontend":
bench_result["engine"] = frontend
engine_result["engine"] = frontend
if check_requirements(frontend):
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
engine_result["host_memory_mb"],
engine_result["device_memory_mb"],
) = self.benchmark_frontend(modelname)
self.frontend_result = bench_result["ms/iter"]
bench_result["vs. PyTorch/TF"] = "baseline"
self.frontend_result = engine_result["ms/iter"]
engine_result["vs. PyTorch/TF"] = "baseline"
(
bench_result["param_count"],
bench_result["tags"],
bench_result["notes"],
engine_result["param_count"],
engine_result["tags"],
engine_result["notes"],
) = self.get_metadata(modelname)
else:
self.frontend_result = None
continue
elif e == "shark_python":
bench_result["engine"] = "shark_python"
engine_result["engine"] = "shark_python"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
) = self.benchmark_python(inputs)
bench_result[
engine_result[
"vs. PyTorch/TF"
] = self.compare_bench_results(
self.frontend_result, bench_result["ms/iter"]
self.frontend_result, engine_result["ms/iter"]
)
elif e == "shark_iree_c":
bench_result["engine"] = "shark_iree_c"
engine_result["engine"] = "shark_iree_c"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
engine_result["host_memory_mb"],
engine_result["device_memory_mb"],
) = self.benchmark_c()
bench_result[
engine_result[
"vs. PyTorch/TF"
] = self.compare_bench_results(
self.frontend_result, bench_result["ms/iter"]
self.frontend_result, engine_result["ms/iter"]
)
elif e == "onnxruntime":
bench_result["engine"] = "onnxruntime"
engine_result["engine"] = "onnxruntime"
(
bench_result["iter/sec"],
bench_result["ms/iter"],
engine_result["iter/sec"],
engine_result["ms/iter"],
) = self.benchmark_onnx(modelname, inputs)
bench_result["dialect"] = self.mlir_dialect
bench_result["iterations"] = shark_args.num_iterations
bench_result["datetime"] = str(datetime.now())
writer.writerow(bench_result)
engine_result["datetime"] = str(datetime.now())
writer.writerow(bench_info | engine_result)

View File

@@ -194,8 +194,14 @@ def download_model(
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
filename = os.path.join(model_dir, model_name + suffix)
with open(filename, mode="rb") as f:
mlir_file = f.read()
try:
with open(filename, mode="rb") as f:
mlir_file = f.read()
except FileNotFoundError:
from tank.generate_sharktank import gen_shark_files
tank_dir = WORKDIR
gen_shark_files(model_name, frontend, tank_dir)
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))

View File

@@ -297,6 +297,7 @@ def transform_fx(fx_g):
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.zeros,
]:
node.kwargs = kwargs_dict
# Inputs and outputs of aten.var.mean should be upcasted to fp32.

View File

@@ -35,3 +35,12 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"",""
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"",""
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"",""
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
gpt2,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
1 resnet50 mhlo tf 1e-2 1e-3 default nhcw-nhwc False False False macos
35 wide_resnet50_2 linalg torch 1e-2 1e-3 default nhcw-nhwc/img2col False False False macos
36 efficientnet-v2-s mhlo tf 1e-02 1e-3 default nhcw-nhwc False False False macos
37 mnasnet1_0 linalg torch 1e-2 1e-3 default nhcw-nhwc True True True macos
38 t5-base linalg torch 1e-2 1e-3 default None True True True
39 t5-base mhlo tf 1e-2 1e-3 default None False False False
40 t5-large linalg torch 1e-2 1e-3 default None True True True
41 t5-large mhlo tf 1e-2 1e-3 default None False False False
42 efficientnet_b0 linalg torch 1e-2 1e-3 default nhcw-nhwc False True False
43 efficientnet_b7 linalg torch 1e-2 1e-3 default nhcw-nhwc False False False
44 efficientnet_b0 mhlo tf 1e-2 1e-3 default None nhcw-nhwc False False False
45 efficientnet_b7 mhlo tf 1e-2 1e-3 default None nhcw-nhwc False False False
46 gpt2 mhlo tf 1e-2 1e-3 default None False False False

View File

@@ -70,7 +70,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -146,7 +146,6 @@ if __name__ == "__main__":
backend_config = "cuda"
args = [
"--iree-cuda-llvm-target-arch=sm_80",
"--iree-hal-cuda-disable-loop-nounroll-wa",
"--iree-enable-fusion-with-reduction-ops",
]

View File

@@ -91,7 +91,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -86,7 +86,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -33,9 +33,10 @@ def create_hash(file_name):
return file_hash.hexdigest()
def save_torch_model(torch_model_list):
def save_torch_model(torch_model_list, local_tank_cache):
from tank.model_utils import (
get_hf_model,
get_hf_seq2seq_model,
get_vision_model,
get_hf_img_cls_model,
get_fp16_model,
@@ -59,7 +60,7 @@ def save_torch_model(torch_model_list):
args.use_tuned = False
args.import_mlir = True
args.use_tuned = False
args.local_tank_cache = WORKDIR
args.local_tank_cache = local_tank_cache
precision_values = ["fp16"]
seq_lengths = [64, 77]
@@ -75,7 +76,7 @@ def save_torch_model(torch_model_list):
height=512,
use_base_vae=False,
debug=True,
sharktank_dir=WORKDIR,
sharktank_dir=local_tank_cache,
generate_vmfb=False,
)
model()
@@ -84,13 +85,15 @@ def save_torch_model(torch_model_list):
model, input, _ = get_vision_model(torch_model_name)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name)
elif model_type == "hf_seq2seq":
model, input, _ = get_hf_seq2seq_model(torch_model_name)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
local_tank_cache, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
@@ -115,12 +118,14 @@ def save_torch_model(torch_model_list):
)
def save_tf_model(tf_model_list):
def save_tf_model(tf_model_list, local_tank_cache):
from tank.model_utils_tf import (
get_causal_image_model,
get_masked_lm_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
get_tfhf_seq2seq_model,
)
import tensorflow as tf
@@ -146,15 +151,19 @@ def save_tf_model(tf_model_list):
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_causal_lm_model(tf_model_name)
if model_type == "img":
elif model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
if model_type == "keras":
elif model_type == "keras":
model, input, _ = get_keras_model(tf_model_name)
if model_type == "TFhf":
elif model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name)
elif model_type == "tfhf_seq2seq":
model, input, _ = get_tfhf_seq2seq_model(tf_model_name)
tf_model_name = tf_model_name.replace("/", "_")
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
tf_model_dir = os.path.join(
local_tank_cache, str(tf_model_name) + "_tf"
)
os.makedirs(tf_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
@@ -172,7 +181,7 @@ def save_tf_model(tf_model_list):
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
def save_tflite_model(tflite_model_list):
def save_tflite_model(tflite_model_list, local_tank_cache):
from shark.tflite_utils import TFLitePreprocessor
with open(tflite_model_list) as csvfile:
@@ -184,7 +193,7 @@ def save_tflite_model(tflite_model_list):
print("tflite_model_name", tflite_model_name)
print("tflite_model_link", tflite_model_link)
tflite_model_name_dir = os.path.join(
WORKDIR, str(tflite_model_name) + "_tflite"
local_tank_cache, str(tflite_model_name) + "_tflite"
)
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
@@ -219,6 +228,45 @@ def save_tflite_model(tflite_model_list):
)
def gen_shark_files(modelname, frontend, tank_dir):
# If a model's artifacts are requested by shark_downloader but they don't exist in the cloud, we call this function to generate the artifacts on-the-fly.
# TODO: Add TFlite support.
import tempfile
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
custom_model_csv = tempfile.NamedTemporaryFile(
dir=os.path.dirname(__file__),
delete=True,
)
# Create a temporary .csv with only the desired entry.
if frontend == "tf":
with open(tf_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_tf_model(custom_model_csv.name, tank_dir)
if frontend == "torch":
with open(torch_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_torch_model(custom_model_csv.name, tank_dir)
# Validates whether the file is present or not.
def is_valid_file(arg):
if not os.path.exists(arg):
@@ -259,20 +307,19 @@ if __name__ == "__main__":
# old_args = parser.parse_args()
home = str(Path.home())
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_shark_tank")
torch_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tf_model_list.csv"
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
)
save_torch_model(
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv")
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
WORKDIR,
)
save_torch_model(torch_model_csv)
save_tf_model(tf_model_csv)
save_tflite_model(tflite_model_csv)
save_torch_model(torch_model_csv, WORKDIR)
save_tf_model(tf_model_csv, WORKDIR)
save_tflite_model(tflite_model_csv, WORKDIR)

View File

@@ -31,4 +31,12 @@ xlm-roberta-base,False,False,-,-,-
facebook/convnext-tiny-224,False,False,-,-,-
efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv"
mnasnet1_0,False,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
bert-large-uncased,True,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
efficientnet_b0,True,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
gpt2,True,False,110M,"nlp;transformer-decoder;auto-regressive","12 layers, 768 hidden units, 12 attention heads"
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
1 model_name use_tracing dynamic param_count tags notes
31 facebook/convnext-tiny-224 False False - - -
32 efficientnet-v2-s False False 22M image-classification,cnn Includes MBConv and Fused-MBConv
33 mnasnet1_0 False True - cnn, torchvision, mobile, architecture-search Outperforms other mobile CNNs on Accuracy vs. Latency
34 bert-large-uncased True True 330M nlp;bert-variant;transformer-encoder 24 layers, 1024 hidden units, 16 attention heads
35 t5-base True False 220M nlp;transformer-encoder;transformer-decoder Text-to-Text Transfer Transformer
36 t5-large True False 770M nlp;transformer-encoder;transformer-decoder Text-to-Text Transfer Transformer
37 bert-large-uncased True hf True 330M nlp;bert-variant;transformer-encoder
38 efficientnet_b0 True False 5.3M image-classification;cnn;conv2d;depthwise-conv Smallest EfficientNet variant with 224x224 input
39 efficientnet_b7 True False 66M image-classification;cnn;conv2d;depthwise-conv Largest EfficientNet variant with 600x600 input
40 gpt2 True False 110M nlp;transformer-decoder;auto-regressive 12 layers, 768 hidden units, 12 attention heads
41 t5-base True False 220M nlp;transformer-encoder;transformer-decoder Text-to-Text Transfer Transformer
42 t5-large True False 770M nlp;transformer-encoder;transformer-decoder Text-to-Text Transfer Transformer

View File

@@ -7,6 +7,8 @@ import sys
torch.manual_seed(0)
BATCH_SIZE = 1
vision_models = [
"alexnet",
"resnet101",
@@ -17,6 +19,8 @@ vision_models = [
"wide_resnet50_2",
"mobilenet_v3_small",
"mnasnet1_0",
"efficientnet_b0",
"efficientnet_b7",
]
hf_img_cls_models = [
"google/vit-base-patch16-224",
@@ -25,6 +29,10 @@ hf_img_cls_models = [
"microsoft/beit-base-patch16-224-pt22k-ft22k",
"nvidia/mit-b0",
]
hf_seq2seq_models = [
"t5-base",
"t5-large",
]
def get_torch_model(modelname):
@@ -32,6 +40,8 @@ def get_torch_model(modelname):
return get_vision_model(modelname)
elif modelname in hf_img_cls_models:
return get_hf_img_cls_model(modelname)
elif modelname in hf_seq2seq_models:
return get_hf_seq2seq_model(modelname)
elif "fp16" in modelname:
return get_fp16_model(modelname)
else:
@@ -85,6 +95,7 @@ def get_hf_img_cls_model(name):
# test_input = torch.FloatTensor(1, 3, 224, 224).uniform_(-1, 1)
# print("test_input.shape: ", test_input.shape)
# test_input.shape: torch.Size([1, 3, 224, 224])
test_input = test_input.repeat(BATCH_SIZE, 1, 1, 1)
actual_out = model(test_input)
# print("actual_out.shape ", actual_out.shape)
# actual_out.shape torch.Size([1, 1000])
@@ -121,11 +132,52 @@ def get_hf_model(name):
model = HuggingFaceLanguage(name)
# TODO: Currently the test input is set to (1,128)
test_input = torch.randint(2, (1, 128))
test_input = torch.randint(2, (BATCH_SIZE, 128))
actual_out = model(test_input)
return model, test_input, actual_out
##################### Hugging Face Seq2SeqLM Models ###################################
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
T5_MAX_SEQUENCE_LENGTH = 512
class HFSeq2SeqLanguageModel(torch.nn.Module):
    """Wraps a Hugging Face T5 model for seq2seq model generation.

    Tokenization pads every input up to a multiple of
    T5_MAX_SEQUENCE_LENGTH so traced input shapes stay static.
    """

    def __init__(self, model_name):
        super().__init__()
        from transformers import AutoTokenizer, T5Model

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Keyword arguments applied on every tokenizer call.
        self.tokenization_kwargs = {
            "pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
            "padding": True,
            "return_tensors": "pt",
        }
        self.model = T5Model.from_pretrained(model_name, return_dict=True)

    def preprocess_input(self, text):
        """Tokenize *text* with the module's padding configuration."""
        return self.tokenizer(text, **self.tokenization_kwargs)

    def forward(self, input_ids, decoder_input_ids):
        """Run the encoder/decoder pass and return the first output tensor."""
        outputs = self.model.forward(
            input_ids, decoder_input_ids=decoder_input_ids
        )
        return outputs[0]
def get_hf_seq2seq_model(name):
    """Build a T5 seq2seq module plus a sample (input, output) pair.

    Returns (model, test_input, actual_out), where test_input is the
    (encoder_ids, decoder_ids) tuple fed to forward().
    """
    module = HFSeq2SeqLanguageModel(name)
    encoder_ids = module.preprocess_input(
        "Studies have been shown that owning a dog is good for you"
    ).input_ids
    decoder_ids = module.preprocess_input("Studies show that").input_ids
    # T5 expects decoder inputs shifted right by one position.
    decoder_ids = module.model._shift_right(decoder_ids)
    sample = (encoder_ids, decoder_ids)
    reference_out = module.forward(*sample)
    return module, sample, reference_out
################################################################################
##################### Torch Vision Models ###################################
@@ -144,24 +196,50 @@ class VisionModule(torch.nn.Module):
def get_vision_model(torch_model):
import torchvision.models as models
default_image_size = (224, 224)
vision_models_dict = {
"alexnet": models.alexnet(weights="DEFAULT"),
"resnet18": models.resnet18(weights="DEFAULT"),
"resnet50": models.resnet50(weights="DEFAULT"),
"resnet50_fp16": models.resnet50(weights="DEFAULT"),
"resnet101": models.resnet101(weights="DEFAULT"),
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
"mobilenet_v3_small": models.mobilenet_v3_small(weights="DEFAULT"),
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
"alexnet": (models.alexnet(weights="DEFAULT"), default_image_size),
"resnet18": (models.resnet18(weights="DEFAULT"), default_image_size),
"resnet50": (models.resnet50(weights="DEFAULT"), default_image_size),
"resnet50_fp16": (
models.resnet50(weights="DEFAULT"),
default_image_size,
),
"resnet101": (models.resnet101(weights="DEFAULT"), default_image_size),
"squeezenet1_0": (
models.squeezenet1_0(weights="DEFAULT"),
default_image_size,
),
"wide_resnet50_2": (
models.wide_resnet50_2(weights="DEFAULT"),
default_image_size,
),
"mobilenet_v3_small": (
models.mobilenet_v3_small(weights="DEFAULT"),
default_image_size,
),
"mnasnet1_0": (
models.mnasnet1_0(weights="DEFAULT"),
default_image_size,
),
# EfficientNet input image size varies on the size of the model.
"efficientnet_b0": (
models.efficientnet_b0(weights="DEFAULT"),
(224, 224),
),
"efficientnet_b7": (
models.efficientnet_b7(weights="DEFAULT"),
(600, 600),
),
}
if isinstance(torch_model, str):
fp16_model = None
if "fp16" in torch_model:
fp16_model = True
torch_model = vision_models_dict[torch_model]
torch_model, input_image_size = vision_models_dict[torch_model]
model = VisionModule(torch_model)
test_input = torch.randn(1, 3, 224, 224)
test_input = torch.randn(BATCH_SIZE, 3, 224, 224)
actual_out = model(test_input)
if fp16_model is not None:
test_input_fp16 = test_input.to(
@@ -209,6 +287,7 @@ def get_fp16_model(torch_model):
model = BertHalfPrecisionModel(modelname)
tokenizer = AutoTokenizer.from_pretrained(modelname)
text = "Replace me by any text you like."
text = [text] * BATCH_SIZE
test_input_fp16 = tokenizer(
text,
truncation=True,

View File

@@ -7,11 +7,15 @@ from transformers import (
)
BATCH_SIZE = 1
MAX_SEQUENCE_LENGTH = 128
################################## MHLO/TF models #########################################
# TODO : Generate these lists or fetch model source from tank/tf/tf_model_list.csv
keras_models = ["resnet50", "efficientnet-v2-s"]
keras_models = [
"resnet50",
"efficientnet_b0",
"efficientnet_b7",
"efficientnet-v2-s",
]
maskedlm_models = [
"albert-base-v2",
"bert-base-uncased",
@@ -32,9 +36,16 @@ maskedlm_models = [
"hf-internal-testing/tiny-random-flaubert",
"xlm-roberta",
]
causallm_models = [
"gpt2",
]
tfhf_models = [
"microsoft/MiniLM-L12-H384-uncased",
]
tfhf_seq2seq_models = [
"t5-base",
"t5-large",
]
img_models = [
"google/vit-base-patch16-224",
"facebook/convnext-tiny-224",
@@ -45,23 +56,35 @@ def get_tf_model(name):
if name in keras_models:
return get_keras_model(name)
elif name in maskedlm_models:
return get_masked_lm_model(name)
elif name in causallm_models:
return get_causal_lm_model(name)
elif name in tfhf_models:
return get_TFhf_model(name)
elif name in img_models:
return get_causal_image_model(name)
elif name in tfhf_seq2seq_models:
return get_tfhf_seq2seq_model(name)
else:
raise Exception(
"TF model not found! Please check that the modelname has been input correctly."
)
##################### Tensorflow Hugging Face LM Models ###################################
##################### Tensorflow Hugging Face Bert Models ###################################
BERT_MAX_SEQUENCE_LENGTH = 128
# Create a set of 2-dimensional inputs
tf_bert_input = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
]
@@ -87,21 +110,31 @@ def get_TFhf_model(name):
"microsoft/MiniLM-L12-H384-uncased"
)
text = "Replace me by any text you'd like."
text = [text] * BATCH_SIZE
encoded_input = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
)
for key in encoded_input:
encoded_input[key] = tf.expand_dims(
tf.convert_to_tensor(encoded_input[key]), 0
)
test_input = (
encoded_input["input_ids"],
encoded_input["attention_mask"],
encoded_input["token_type_ids"],
max_length=BERT_MAX_SEQUENCE_LENGTH,
)
test_input = [
tf.reshape(
tf.convert_to_tensor(encoded_input["input_ids"], dtype=tf.int32),
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
),
tf.reshape(
tf.convert_to_tensor(
encoded_input["attention_mask"], dtype=tf.int32
),
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
),
tf.reshape(
tf.convert_to_tensor(
encoded_input["token_type_ids"], dtype=tf.int32
),
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
),
]
actual_out = model.forward(*test_input)
return model, test_input, actual_out
@@ -115,34 +148,41 @@ def compare_tensors_tf(tf_tensor, numpy_tensor):
return np.allclose(tf_to_numpy, numpy_tensor, rtol, atol)
##################### Tensorflow Hugging Face Masked LM Models ###################################
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
import tensorflow as tf
# Create a set of input signature.
input_signature_maskedlm = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
]
# For supported models please see here:
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCausalLM
# Tokenizer for language models
def preprocess_input(
model_name, text="This is just used to compile the model"
model_name, max_length, text="This is just used to compile the model"
):
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = [text] * BATCH_SIZE
inputs = tokenizer(
text,
padding="max_length",
return_tensors="tf",
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_length,
)
return inputs
##################### Tensorflow Hugging Face Masked LM Models ###################################
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
import tensorflow as tf
MASKED_LM_MAX_SEQUENCE_LENGTH = 128
# Create a set of input signature.
input_signature_maskedlm = [
tf.TensorSpec(
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
]
# For supported models please see here:
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForMaskedLM
class MaskedLM(tf.Module):
def __init__(self, model_name):
super(MaskedLM, self).__init__()
@@ -156,19 +196,139 @@ class MaskedLM(tf.Module):
return self.m.predict(input_ids, attention_mask)
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
def get_masked_lm_model(hf_name, text="Hello, this is the default text."):
model = MaskedLM(hf_name)
encoded_input = preprocess_input(hf_name, text)
encoded_input = preprocess_input(
hf_name, MASKED_LM_MAX_SEQUENCE_LENGTH, text
)
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
actual_out = model.forward(*test_input)
return model, test_input, actual_out
##################### Tensorflow Hugging Face Causal LM Models ###################################
from transformers import AutoConfig, TFAutoModelForCausalLM, TFGPT2Model
CAUSAL_LM_MAX_SEQUENCE_LENGTH = 1024
input_signature_causallm = [
tf.TensorSpec(
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
tf.TensorSpec(
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
),
]
# For supported models please see here:
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCausalLM
# For more background, see:
# https://huggingface.co/blog/tf-xla-generate
class CausalLM(tf.Module):
    """Wraps a TF GPT-2 model for causal-LM model generation."""

    def __init__(self, model_name):
        super().__init__()
        # Decoder-only models need left padding.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, padding_side="left", pad_token="</s>"
        )
        # Pad up to the full sequence length so traced shapes are static.
        self.tokenization_kwargs = {
            "pad_to_multiple_of": CAUSAL_LM_MAX_SEQUENCE_LENGTH,
            "padding": True,
            "return_tensors": "tf",
        }
        self.model = TFGPT2Model.from_pretrained(model_name, return_dict=True)

        def _predict(ids, mask):
            return self.model(input_ids=ids, attention_mask=mask)[0]

        self.model.predict = _predict

    def preprocess_input(self, text):
        """Tokenize *text* with left padding and static shapes."""
        return self.tokenizer(text, **self.tokenization_kwargs)

    @tf.function(input_signature=input_signature_causallm, jit_compile=True)
    def forward(self, input_ids, attention_mask):
        """Return the model's first output for the padded batch."""
        return self.model.predict(input_ids, attention_mask)
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
    """Build a CausalLM wrapper and a sample (input, output) pair."""
    wrapper = CausalLM(hf_name)
    # Repeat the prompt across the batch before tokenizing.
    tokens = wrapper.preprocess_input([text] * BATCH_SIZE)
    sample = (tokens["input_ids"], tokens["attention_mask"])
    reference_out = wrapper.forward(*sample)
    return wrapper, sample, reference_out
##################### Tensorflow Hugging Face Seq2SeqLM Models ###################################
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
T5_MAX_SEQUENCE_LENGTH = 512
input_signature_t5 = [
tf.TensorSpec(
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
dtype=tf.int32,
name="input_ids",
),
tf.TensorSpec(
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
dtype=tf.int32,
name="attention_mask",
),
]
class TFHFSeq2SeqLanguageModel(tf.Module):
    """Wraps a Hugging Face TF T5 model for seq2seq model generation.

    Tokenization pads every input up to a multiple of
    T5_MAX_SEQUENCE_LENGTH so traced input shapes stay static.
    """

    def __init__(self, model_name):
        super(TFHFSeq2SeqLanguageModel, self).__init__()
        # Import only the names actually used; AutoConfig and
        # TFAutoModelForSeq2SeqLM were previously imported but unused.
        from transformers import AutoTokenizer, TFT5Model

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenization_kwargs = {
            "pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
            "padding": True,
            "return_tensors": "tf",
        }
        self.model = TFT5Model.from_pretrained(model_name, return_dict=True)
        # Expose a (input_ids, decoder_input_ids) -> first-output shortcut.
        self.model.predict = lambda x, y: self.model(x, decoder_input_ids=y)[0]

    def preprocess_input(self, text):
        """Tokenize *text* with the module's padding configuration."""
        return self.tokenizer(text, **self.tokenization_kwargs)

    @tf.function(input_signature=input_signature_t5, jit_compile=True)
    def forward(self, input_ids, decoder_input_ids):
        """Return the model's first output for encoder/decoder id batches."""
        return self.model.predict(input_ids, decoder_input_ids)
def get_tfhf_seq2seq_model(name):
    """Build a TF T5 wrapper plus a sample (input, output) pair."""
    module = TFHFSeq2SeqLanguageModel(name)

    def _batched_ids(prompt):
        # Repeat the prompt across the batch, then tokenize.
        return module.preprocess_input([prompt] * BATCH_SIZE).input_ids

    encoder_ids = _batched_ids(
        "Studies have been shown that owning a dog is good for you"
    )
    decoder_ids = _batched_ids("Studies show that")
    # T5 expects decoder inputs shifted right by one position.
    decoder_ids = module.model._shift_right(decoder_ids)
    sample = (encoder_ids, decoder_ids)
    reference_out = module.forward(*sample)
    return module, sample, reference_out
##################### TensorFlow Keras Resnet Models #########################################################
# Static shape, including batch size (1).
# Can be dynamic once dynamic shape support is ready.
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
RESNET_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
EFFICIENTNET_V2_S_INPUT_SHAPE = [BATCH_SIZE, 384, 384, 3]
EFFICIENTNET_B0_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
EFFICIENTNET_B7_INPUT_SHAPE = [BATCH_SIZE, 600, 600, 3]
class ResNetModule(tf.Module):
@@ -195,25 +355,79 @@ class ResNetModule(tf.Module):
return tf.keras.applications.resnet50.preprocess_input(image)
class EfficientNetModule(tf.Module):
class EfficientNetB0Module(tf.Module):
def __init__(self):
super(EfficientNetModule, self).__init__()
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
super(EfficientNetB0Module, self).__init__()
self.m = tf.keras.applications.efficientnet.EfficientNetB0(
weights="imagenet",
include_top=True,
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
input_shape=tuple(EFFICIENTNET_B0_INPUT_SHAPE[1:]),
)
self.m.predict = lambda x: self.m.call(x, training=False)
@tf.function(
input_signature=[tf.TensorSpec(EFFICIENTNET_INPUT_SHAPE, tf.float32)],
input_signature=[
tf.TensorSpec(EFFICIENTNET_B0_INPUT_SHAPE, tf.float32)
],
jit_compile=True,
)
def forward(self, inputs):
return self.m.predict(inputs)
def input_shape(self):
return EFFICIENTNET_INPUT_SHAPE
return EFFICIENTNET_B0_INPUT_SHAPE
def preprocess_input(self, image):
return tf.keras.applications.efficientnet.preprocess_input(image)
class EfficientNetB7Module(tf.Module):
    """Keras EfficientNet-B7 classifier wrapped for jit-compiled inference."""

    def __init__(self):
        super().__init__()
        self.m = tf.keras.applications.efficientnet.EfficientNetB7(
            weights="imagenet",
            include_top=True,
            input_shape=tuple(EFFICIENTNET_B7_INPUT_SHAPE[1:]),
        )

        def _inference(batch):
            # Inference-mode call (training=False).
            return self.m.call(batch, training=False)

        self.m.predict = _inference

    @tf.function(
        input_signature=[
            tf.TensorSpec(EFFICIENTNET_B7_INPUT_SHAPE, tf.float32)
        ],
        jit_compile=True,
    )
    def forward(self, inputs):
        """Run a jit-compiled forward pass over *inputs*."""
        return self.m.predict(inputs)

    def input_shape(self):
        """Static input shape expected by forward()."""
        return EFFICIENTNET_B7_INPUT_SHAPE

    def preprocess_input(self, image):
        """Apply the EfficientNet preprocessing pipeline to *image*."""
        return tf.keras.applications.efficientnet.preprocess_input(image)
class EfficientNetV2SModule(tf.Module):
    """Keras EfficientNetV2-S classifier wrapped for jit-compiled inference."""

    def __init__(self):
        super().__init__()
        self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
            weights="imagenet",
            include_top=True,
            input_shape=tuple(EFFICIENTNET_V2_S_INPUT_SHAPE[1:]),
        )

        def _inference(batch):
            # Inference-mode call (training=False).
            return self.m.call(batch, training=False)

        self.m.predict = _inference

    @tf.function(
        input_signature=[
            tf.TensorSpec(EFFICIENTNET_V2_S_INPUT_SHAPE, tf.float32)
        ],
        jit_compile=True,
    )
    def forward(self, inputs):
        """Run a jit-compiled forward pass over *inputs*."""
        return self.m.predict(inputs)

    def input_shape(self):
        """Static input shape expected by forward()."""
        return EFFICIENTNET_V2_S_INPUT_SHAPE

    def preprocess_input(self, image):
        """Apply the EfficientNetV2 preprocessing pipeline to *image*."""
        return tf.keras.applications.efficientnet_v2.preprocess_input(image)
@@ -224,12 +438,17 @@ def load_image(path_to_image, width, height, channels):
image = tf.image.decode_image(image, channels=channels)
image = tf.image.resize(image, (width, height))
image = image[tf.newaxis, :]
image = tf.tile(image, [BATCH_SIZE, 1, 1, 1])
return image
def get_keras_model(modelname):
if modelname == "efficientnet-v2-s":
model = EfficientNetModule()
model = EfficientNetV2SModule()
elif modelname == "efficientnet_b0":
model = EfficientNetB0Module()
elif modelname == "efficientnet_b7":
model = EfficientNetB7Module()
else:
model = ResNetModule()
@@ -256,7 +475,7 @@ import requests
# Create a set of input signature.
input_signature_img_cls = [
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
tf.TensorSpec(shape=[BATCH_SIZE, 3, 224, 224], dtype=tf.float32),
]
@@ -304,6 +523,9 @@ def preprocess_input_image(model_name):
)
# inputs: {'pixel_values': <tf.Tensor: shape=(1, 3, 224, 224), dtype=float32, numpy=array([[[[]]]], dtype=float32)>}
inputs = feature_extractor(images=image, return_tensors="tf")
inputs["pixel_values"] = tf.tile(
inputs["pixel_values"], [BATCH_SIZE, 1, 1, 1]
)
return [inputs[str(*inputs)]]

View File

@@ -48,7 +48,9 @@ def load_csv_and_convert(filename, gen=False):
)
# This is a pytest workaround
if gen:
with open("tank/dict_configs.py", "w+") as out:
with open(
os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
) as out:
out.write("ALL = [\n")
for c in model_configs:
out.write(str(c) + ",\n")
@@ -68,7 +70,9 @@ def get_valid_test_params():
dynamic_list = (True, False)
# TODO: This is soooo ugly, but for some reason creating the dict at runtime
# results in strange pytest failures.
load_csv_and_convert("tank/all_models.csv", True)
load_csv_and_convert(
os.path.join(os.path.dirname(__file__), "all_models.csv"), True
)
from tank.dict_configs import ALL
config_list = ALL

View File

@@ -19,3 +19,10 @@ facebook/convnext-tiny-224,img
google/vit-base-patch16-224,img
efficientnet-v2-s,keras
bert-large-uncased,hf
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq
efficientnet_b0,keras
efficientnet_b7,keras
gpt2,hf_causallm
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq
1 model_name model_type
19 google/vit-base-patch16-224 img
20 efficientnet-v2-s keras
21 bert-large-uncased hf
22 t5-base tfhf_seq2seq
23 t5-large tfhf_seq2seq
24 efficientnet_b0 keras
25 efficientnet_b7 keras
26 gpt2 hf_causallm
27 t5-base tfhf_seq2seq
28 t5-large tfhf_seq2seq

View File

@@ -18,4 +18,6 @@ nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encod
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
1 model_name use_tracing model_type dynamic param_count tags notes
18 mnasnet1_0 False vision True - cnn, torchvision, mobile, architecture-search Outperforms other mobile CNNs on Accuracy vs. Latency
19 resnet50_fp16 False vision True 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
20 bert-base-uncased_fp16 True fp16 False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
21 bert-large-uncased True hf True 330M nlp;bert-variant;transformer-encoder 24 layers, 1024 hidden units, 16 attention heads
22 efficientnet_b0 True vision False 5.3M image-classification;cnn;conv2d;depthwise-conv Smallest EfficientNet variant with 224x224 input
23 efficientnet_b7 True vision False 66M image-classification;cnn;conv2d;depthwise-conv Largest EfficientNet variant with 600x600 input

View File

@@ -1,4 +1,3 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
prompthero/openjourney,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
1 model_name use_tracing model_type dynamic param_count tags notes
2 stabilityai/stable-diffusion-2-1-base True stable_diffusion False ??M stable diffusion 2.1 base, LLM, Text to image N/A
3 stabilityai/stable-diffusion-2-1 True stable_diffusion False ??M stable diffusion 2.1 base, LLM, Text to image N/A
prompthero/openjourney True stable_diffusion False ??M stable diffusion 2.1 base, LLM, Text to image N/A