Compare commits

...

24 Commits

Author SHA1 Message Date
Daniel Garvey
594c6b8ea2 fix ckpt dir (#1258) 2023-03-28 14:31:01 -07:00
Ean Garvey
96b1560da5 Make batch size configurable via pytest and fix sharktank generation. (#1227)
* Fix sharktank generation and add batch_size pytest option for torch.

* Disable torch dynamo until py3.11 is supported

* Compile torchmodel without dynamo if torch.compile fails

* Use release versions of TF/Keras for importer.

* Pin torchvision and remove debug prints.

* Remove duplicates from torch model list.

* Update generate_sharktank.py

* xfail a few models that fail sharktank generation/numerics
2023-03-28 14:33:39 -05:00
Abhishek Varma
0ef6a0e234 [SD] Fix Stencil scribble crash by updating image resize (#1255)
-- This commit updates the Stencil resize feature to cap the size of
   images within [128, 768], as supported by the SD pipeline.
-- This fixes scribble crashing on larger images.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-28 10:13:11 -07:00
Gaurav Shukla
641d535f44 [SD] Fix device path issue for cpu (#1256)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-28 10:09:49 -07:00
Daniel Garvey
5bb7846227 single entry point exe for all cli apps (#1158)
usage:
add --app="img2img" (or "inpaint", "outpaint", "txt2img")
2023-03-28 11:15:21 -05:00
yzhang93
8f84258fb8 Fix check for use_tuned conditions (#1252) 2023-03-27 11:21:25 -07:00
Ean Garvey
7619e76bbd Disable and xfail some models that fail validation/compilation. (#1251)
* Roll back T5 models for torch, as the inputs cause issues that aren't trivial to resolve
* xfail efficientnet-b0 on torch+cuda -- see CUDA requesting shared memory size larger than allowed size openxla/iree#12771
2023-03-27 12:42:53 -05:00
Daniel Garvey
9267eadbfa disable openjourney gen for nightly (#1249) 2023-03-27 11:55:34 -05:00
Phaneesh Barwaria
431132b8ee Fix img2img mode switch (#1247)
* add updated scheduler value in global config

* clear scheduler global variable with others
2023-03-27 07:01:22 -07:00
cstueckrath
fb35e13e7a fix Python version detection bug (#1246)
* fix Python version detection bug

* Update setup_venv.ps1
2023-03-27 07:00:40 -07:00
yzhang93
17a67897d1 Add SD v2.1 768x768 tuned model (#1244)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-24 10:39:15 -07:00
Gaurav Shukla
da449b73aa [SD] Disable lora training tab for now (#1241)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-24 09:16:24 -07:00
Kyle Herndon
0b0526699a Fix incorrect device argument initialization for LoRA training by extracting the device type and number and formatting it for pytorch (#1237)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
2023-03-24 01:10:50 -07:00
Boian Petkantchin
4fac46f7bb In models testing, fix paths to be relative to the script dir, not cwd (#1128)
authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-03-22 15:26:52 -05:00
Daniel Garvey
49925950f1 fix false positives (#1193) 2023-03-22 15:25:39 -05:00
Thomas
807947c0c8 Remove deprecated cli option iree-hal-cuda-disable-loop-nounroll-wa (#1235) 2023-03-22 12:05:15 -05:00
Abhishek Varma
593428bda4 [SD] Fix for transformers/__init__.py issue in PyInstaller (#1233)
-- This commit fixes the transformers/__init__.py issue in PyInstaller.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-22 08:43:53 -07:00
Abhishek Varma
cede9b4fec [SD] Fix custom_vae as a required parameter in inpaint (#1232) 2023-03-22 04:30:17 -07:00
Prashant Kumar
c2360303f0 Add the int8 quantized model. 2023-03-22 16:28:13 +05:30
jinchen62
420366c1b8 Move schedulers to global obj (#1225) 2023-03-21 22:40:43 -07:00
Ean Garvey
d31bae488c Set iree-input-type to tm_tensor for SD (#1228) 2023-03-21 19:07:31 -07:00
Kyle Herndon
c23fcf3748 Fix incorrect device argument initialization for LoRA training (#1231)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-21 19:07:18 -07:00
jinchen62
7dbbb1726a Fix SD obj not defined if fail to get models from pretrained (#1222) 2023-03-21 07:55:17 -07:00
Abhishek Varma
8b8cc7fd33 [SD] Update LoRA inference to handle various checkpoints (#1215) 2023-03-21 06:52:20 -07:00
48 changed files with 746 additions and 370 deletions

View File

@@ -120,7 +120,7 @@ jobs:
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
@@ -158,4 +158,6 @@ jobs:
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
python build_tools/stable_diffusion_testing.py --device=vulkan

View File

@@ -114,12 +114,12 @@ source shark.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
```shell
python3.11 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc

View File

@@ -2,6 +2,7 @@ import sys
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -15,8 +16,6 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -25,15 +24,15 @@ init_import_mlir = args.import_mlir
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# Both width and height should be > 384 and multiple of 8.
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
width, height = image.size
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 384:
n_size = 384
if min_size < 128:
n_size = 128
if width == min_size:
width = n_size
height = n_size / aspect_ratio
@@ -46,6 +45,22 @@ def resize_stencil(image: Image.Image):
n_height = height // 8
n_width *= 8
n_height *= 8
min_size = min(width, height)
if min_size > 768:
n_size = 768
if width == min_size:
height = n_size
width = n_size * aspect_ratio
else:
width = n_size
height = n_size / aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
new_image = image.resize((n_width, n_height))
return new_image, n_width, n_height
@@ -77,6 +92,7 @@ def img2img_inf(
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -84,8 +100,6 @@ def img2img_inf(
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -100,10 +114,6 @@ def img2img_inf(
image = init_image.convert("RGB")
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -118,14 +128,9 @@ def img2img_inf(
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -159,7 +164,7 @@ def img2img_inf(
height,
width,
device,
use_lora=use_lora,
use_lora=args.use_lora,
use_stencil=use_stencil,
)
if (
@@ -183,8 +188,8 @@ def img2img_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
args.use_tuned = False
@@ -205,7 +210,7 @@ def img2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
)
)
else:
@@ -225,11 +230,11 @@ def img2img_inf(
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(args.scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -274,7 +279,7 @@ def img2img_inf(
return generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -380,3 +385,7 @@ if __name__ == "__main__":
extra_info = {"STRENGTH": args.strength}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()
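
Note on the resize_stencil rework above: the old single lower bound (smaller side >= 384) becomes a two-sided clamp. The smaller side is raised to at least 128; when both sides exceed 768, the image is scaled down so its larger side is 768. Aspect ratio is preserved and both sides are rounded down to multiples of 8. A standalone sketch of the consolidated logic (the else branch of the first clamp falls outside the hunk shown and is reconstructed by symmetry):

```python
from PIL import Image

def resize_stencil(image: Image.Image):
    # Clamp the smaller side up to at least 128; if both sides exceed 768,
    # scale down so the larger side is 768. Keep aspect ratio, then round
    # both sides down to multiples of 8 (SD pipeline constraint).
    width, height = image.size
    aspect_ratio = width / height  # > 1 for landscape, < 1 for portrait
    if min(width, height) < 128:
        if width <= height:
            width, height = 128, 128 / aspect_ratio
        else:  # reconstructed by symmetry; not shown in the hunk above
            height, width = 128, 128 * aspect_ratio
    if min(width, height) > 768:  # recomputed after the first clamp, as above
        if width <= height:
            height, width = 768, 768 * aspect_ratio
        else:
            width, height = 768, 768 / aspect_ratio
    n_width = (int(width) // 8) * 8
    n_height = (int(height) // 8) * 8
    return image.resize((n_width, n_height)), n_width, n_height
```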

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -13,8 +14,6 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -48,6 +47,7 @@ def inpaint_inf(
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -55,8 +55,6 @@ def inpaint_inf(
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -66,10 +64,6 @@ def inpaint_inf(
args.mask_path = "not none"
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -84,14 +78,9 @@ def inpaint_inf(
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -108,7 +97,7 @@ def inpaint_inf(
height,
width,
device,
use_lora=use_lora,
use_lora=args.use_lora,
use_stencil=None,
)
if (
@@ -133,14 +122,15 @@ def inpaint_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
@@ -148,14 +138,13 @@ def inpaint_inf(
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -202,7 +191,7 @@ def inpaint_inf(
return generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -232,6 +221,7 @@ if __name__ == "__main__":
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
@@ -239,7 +229,6 @@ if __name__ == "__main__":
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
@@ -285,3 +274,7 @@ if __name__ == "__main__":
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,19 @@
from apps.stable_diffusion.src import args
from apps.stable_diffusion.scripts import (
img2img,
txt2img,
# inpaint,
# outpaint,
)
if __name__ == "__main__":
if args.app == "txt2img":
txt2img.main()
elif args.app == "img2img":
img2img.main()
# elif args.app == "inpaint":
# inpaint.main()
# elif args.app == "outpaint":
# outpaint.main()
else:
print(f"args.app value is {args.app} but this isn't supported")

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
@@ -13,8 +14,6 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -51,6 +50,7 @@ def outpaint_inf(
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -58,8 +58,6 @@ def outpaint_inf(
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -68,10 +66,6 @@ def outpaint_inf(
args.img_path = "not none"
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -86,14 +80,9 @@ def outpaint_inf(
else:
args.hf_model_id = custom_model
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -110,7 +99,7 @@ def outpaint_inf(
height,
width,
device,
use_lora=use_lora,
use_lora=args.use_lora,
use_stencil=None,
)
if (
@@ -135,8 +124,8 @@ def outpaint_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
@@ -151,11 +140,11 @@ def outpaint_inf(
args.width,
args.use_base_vae,
args.use_tuned,
use_lora=use_lora,
use_lora=args.use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -211,7 +200,7 @@ def outpaint_inf(
return generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -310,3 +299,7 @@ if __name__ == "__main__":
}
save_output_img(generated_imgs[0], seed, extra_info)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -159,9 +159,6 @@ class LoraDataset(Dataset):
return example
schedulers = None
########## Setting up the model ##########
def lora_train(
prompt: str,
@@ -187,8 +184,6 @@ def lora_train(
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
print(
"Note LoRA training is not compatible with the latest torch-mlir branch"
)
@@ -227,7 +222,12 @@ def lora_train(
args.max_length = max_length
args.height = height
args.width = width
args.device = device
device_str = device.split("=>", 1)[1].strip().split("://")
if len(device_str) > 1:
device_str = device_str[0] + ":" + device_str[1]
else:
device_str = device_str[0]
args.device = device_str
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(
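
For context on the device-path hunk above: SHARK device strings take the form `<name> => <type>://<index>` (see the `get_available_devices` change later in this compare, which appends `"device => cpu"`), while PyTorch expects `type:index`. A minimal sketch of the extraction; the GPU example string is an assumption:

```python
def to_torch_device(device: str) -> str:
    # "Some GPU => cuda://0" -> "cuda:0";  "device => cpu" -> "cpu"
    parts = device.split("=>", 1)[1].strip().split("://")
    return parts[0] + ":" + parts[1] if len(parts) > 1 else parts[0]

assert to_torch_device("Some GPU => cuda://0") == "cuda:0"
assert to_torch_device("device => cpu") == "cpu"
```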

View File

@@ -1,4 +1,5 @@
import torch
import transformers
import time
from apps.stable_diffusion.src import (
args,
@@ -11,7 +12,6 @@ from apps.stable_diffusion.src import (
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
@@ -43,6 +43,7 @@ def txt2img_inf(
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -50,8 +51,6 @@ def txt2img_inf(
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -59,10 +58,6 @@ def txt2img_inf(
args.scheduler = scheduler
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -80,14 +75,9 @@ def txt2img_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_lora = ""
if lora_weights == "None" and not lora_hf_id:
use_lora = ""
elif not lora_hf_id:
use_lora = lora_weights
else:
use_lora = lora_hf_id
args.use_lora = use_lora
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -101,7 +91,7 @@ def txt2img_inf(
height,
width,
device,
use_lora=use_lora,
use_lora=args.use_lora,
use_stencil=None,
)
if (
@@ -127,8 +117,8 @@ def txt2img_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
@@ -145,11 +135,11 @@ def txt2img_inf(
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -190,7 +180,7 @@ def txt2img_inf(
return generated_imgs, text_output
if __name__ == "__main__":
def main():
if args.clear_all:
clear_all()
@@ -200,7 +190,6 @@ if __name__ == "__main__":
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
use_lora = args.use_lora
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
@@ -216,7 +205,8 @@ if __name__ == "__main__":
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=use_lora,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
)
for current_batch in range(args.batch_count):
@@ -256,3 +246,7 @@ if __name__ == "__main__":
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
@@ -12,8 +13,6 @@ from apps.stable_diffusion.src import (
)
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -41,15 +40,16 @@ def upscaler_inf(
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -62,10 +62,6 @@ def upscaler_inf(
image = init_image.convert("RGB").resize((height, width))
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
@@ -83,6 +79,10 @@ def upscaler_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
args.height = 128
@@ -97,7 +97,7 @@ def upscaler_inf(
args.height,
args.width,
device,
use_lora=None,
use_lora=args.use_lora,
use_stencil=None,
)
if (
@@ -118,8 +118,8 @@ def upscaler_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
@@ -135,11 +135,14 @@ def upscaler_inf(
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.get_sd_obj().low_res_scheduler = schedulers["DDPM"]
global_obj.set_sd_scheduler(scheduler)
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
"DDPM"
)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -232,6 +235,7 @@ if __name__ == "__main__":
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
ddpm_scheduler=schedulers["DDPM"],
)

View File

@@ -43,7 +43,7 @@ hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
a = Analysis(
['scripts/txt2img.py'],
['scripts/main.py'],
pathex=['.'],
binaries=binaries,
datas=datas,

View File

@@ -18,6 +18,7 @@ from apps.stable_diffusion.src.utils import (
get_path_stem,
get_extended_name,
get_stencil_model_id,
update_lora_weight,
)
@@ -99,7 +100,8 @@ class SharkifyStableDiffusionModel:
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = ""
use_lora: str = "",
use_quantize: str = None,
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -107,6 +109,7 @@ class SharkifyStableDiffusionModel:
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
@@ -200,6 +203,7 @@ class SharkifyStableDiffusionModel:
use_tuned=self.use_tuned,
model_name=self.model_name["vae_encode"],
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_vae_encode
@@ -255,6 +259,7 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_vae
@@ -281,13 +286,14 @@ class SharkifyStableDiffusionModel:
use_tuned=self.use_tuned,
model_name=self.model_name["vae"],
extra_args=get_opt_flags("vae", precision="fp32"),
base_model_id=self.base_model_id,
)
return shark_vae
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -295,6 +301,8 @@ class SharkifyStableDiffusionModel:
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
@@ -333,6 +341,7 @@ class SharkifyStableDiffusionModel:
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_controlled_unet
@@ -386,6 +395,7 @@ class SharkifyStableDiffusionModel:
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_cnet
@@ -399,7 +409,7 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
self.unet.load_attn_procs(use_lora)
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
@@ -444,6 +454,7 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_unet
@@ -481,18 +492,21 @@ class SharkifyStableDiffusionModel:
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
)
return shark_unet
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
def forward(self, input):
return self.text_encoder(input)[0]
@@ -512,6 +526,7 @@ class SharkifyStableDiffusionModel:
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
)
return shark_clip
@@ -539,6 +554,7 @@ class SharkifyStableDiffusionModel:
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
# configiration.
def compile_all(self, base_model_id, need_vae_encode, need_stencil):
self.base_model_id = base_model_id
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
@@ -556,7 +572,12 @@ class SharkifyStableDiffusionModel:
compiled_controlnet = self.get_control_net()
compiled_controlled_unet = self.get_controlled_unet()
else:
compiled_unet = self.get_unet()
# TODO: Plug the experimental "int8" support at right place.
if self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
compiled_unet = get_unet()
else:
compiled_unet = self.get_unet()
if self.custom_vae != "":
print("Plugging in custom Vae")
compiled_vae = self.get_vae()

View File

@@ -20,6 +20,15 @@ hf_model_variant_map = {
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
bucket_key = "gs://shark_tank/prashant_nod"
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 and args.width != 512 and args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
@@ -41,6 +50,12 @@ def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
# TODO: Get the quantize model from model_db.json
if args.use_quantize == "int8":
bk, mk, flags = get_quantize_model()
return get_shark_model(bk, mk, flags)
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"

View File

@@ -321,6 +321,7 @@ class StableDiffusionPipeline:
use_stencil: str = None,
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
is_inpaint = cls.__name__ in [
"InpaintPipeline",
@@ -349,6 +350,7 @@ class StableDiffusionPipeline:
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
use_quantize=use_quantize,
)
if cls.__name__ in [
"Image2ImagePipeline",

View File

@@ -33,4 +33,5 @@ from apps.stable_diffusion.src.utils.utils import (
clear_all,
save_output_img,
get_generation_text_info,
update_lora_weight,
)

View File

@@ -76,18 +76,19 @@ def load_winograd_configs():
return winograd_config_dir
def load_lower_configs():
def load_lower_configs(base_model_id=None):
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
if not base_model_id:
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
@@ -114,7 +115,14 @@ def load_lower_configs():
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
and args.width == 768
):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
@@ -212,7 +220,7 @@ def annotate_with_lower_configs(
return bytecode
def sd_model_annotation(mlir_model, model_name):
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
@@ -220,7 +228,7 @@ def sd_model_annotation(mlir_model, model_name):
winograd_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
)
@@ -232,7 +240,7 @@ def sd_model_annotation(mlir_model, model_name):
)
else:
use_winograd = False
lowering_config_dir = load_lower_configs()
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
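
Two things change in sd_model_annotation above: `base_model_id` is now threaded in from compile time rather than re-fetched from args, and SD 2.1 at 768x768 resolves to a dedicated tuned config. A simplified mirror of the filename selection, omitting the branch above the hunk:

```python
def lowering_config_name(annotation_model, version, precision, device, spec,
                         height, width):
    # Simplified mirror of the selection in load_lower_configs above.
    if not spec or spec in ["rdna3", "sm_80"]:
        if version in ["v2_1", "v2_1base"] and (height, width) == (768, 768):
            # SD 2.1 at 768x768 now resolves to its own tuned config file.
            return f"{annotation_model}_v2_1_768_{precision}_{device}.json"
        return f"{annotation_model}_{version}_{precision}_{device}.json"
    return f"{annotation_model}_{version}_{precision}_{device}_{spec}.json"

# e.g. lowering_config_name("unet", "v2_1", "fp16", "vulkan", "rdna3", 768, 768)
# -> "unet_v2_1_768_fp16_vulkan.json"
```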

View File

@@ -22,6 +22,12 @@ p = argparse.ArgumentParser(
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
)
p.add_argument(
"-p",
"--prompts",
@@ -340,6 +346,14 @@ p.add_argument(
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################

View File

@@ -9,6 +9,8 @@ from pathlib import Path
import numpy as np
from random import randint
import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
@@ -21,7 +23,7 @@ from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
download_from_original_stable_diffusion_ckpt,
)
@@ -78,7 +80,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
mlir_model, device=args.device, mlir_dialect="tm_tensor"
)
return _compile_module(shark_module, model_name, extra_args)
@@ -95,6 +97,7 @@ def compile_through_fx(
debug=False,
generate_vmfb=True,
extra_args=[],
base_model_id=None,
):
from shark.parser import shark_args
@@ -116,19 +119,21 @@ def compile_through_fx(
if use_tuned:
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
mlir_module = sd_model_annotation(mlir_module, model_name)
mlir_module = sd_model_annotation(
mlir_module, model_name, base_model_id
)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
mlir_dialect="tm_tensor",
)
if generate_vmfb:
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
mlir_dialect="tm_tensor",
)
del mlir_module
gc.collect()
@@ -264,8 +269,9 @@ def set_init_device_flags():
if (
args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
@@ -299,6 +305,20 @@ def set_init_device_flags():
]:
args.use_tuned = False
elif (
args.height == 768
and args.width == 768
and (
base_model_id
not in [
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
]
or "rdna3" not in args.iree_vulkan_target_triple
)
):
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
else:
@@ -368,7 +388,7 @@ def get_available_devices():
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
available_devices.append("device => cpu")
return available_devices
@@ -454,7 +474,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
num_in_channels = 9 if is_inpaint else 4
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
@@ -464,6 +484,115 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
print("Loading complete")
def processLoRA(model, use_lora, splitting_prefix):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
state_dict = torch.load(use_lora)
alpha = 0.75
visited = []
# directly update weight in model
process_unet = "te" not in splitting_prefix
for key in state_dict:
if ".alpha" in key or key in visited:
continue
curr_layer = model
if ("text" not in key and process_unet) or (
"text" in key and not process_unet
):
layer_infos = (
key.split(".")[0].split(splitting_prefix)[-1].split("_")
)
else:
continue
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = (
state_dict[pair_keys[0]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
weight_down = (
state_dict[pair_keys[1]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
curr_layer.weight.data += alpha * torch.mm(
weight_up, weight_down
).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
return model
def update_lora_weight_for_unet(unet, use_lora):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume if it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
return unet
main_file_name = get_path_stem(use_lora)
if ".bin" in use_lora:
main_file_name += ".bin"
elif ".safetensors" in use_lora:
main_file_name += ".safetensors"
elif ".pt" in use_lora:
main_file_name += ".pt"
else:
sys.exit("Only .bin and .safetensors format for LoRA is supported")
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_")
def update_lora_weight(model, use_lora, model_name):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora)
try:
return processLoRA(model, use_lora, "lora_te_")
except:
return None
def load_vmfb(vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model

View File

@@ -1,5 +1,6 @@
import os
import sys
import transformers
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
@@ -66,7 +67,7 @@ from apps.stable_diffusion.web.ui import (
)
# init global sd pipeline and config
global_obj.init()
global_obj._init()
def register_button_click(button, selectedid, inputs, outputs):
@@ -94,6 +95,8 @@ with gr.Blocks(
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.Tabs(visible=False) as experimental_tabs:
with gr.TabItem(label="LoRA Training", id=5):
lora_train_web.render()

View File

@@ -178,6 +178,7 @@ footer {
/* Hide "remove buttons" from ui dropdowns */
#custom_model .token-remove.remove-all,
#lora_weights .token-remove.remove-all,
#scheduler .token-remove.remove-all,
#device .token-remove.remove-all,
#stencil_model .token-remove.remove-all {

View File

@@ -77,10 +77,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",

View File

@@ -72,10 +72,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",

View File

@@ -69,10 +69,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",

View File

@@ -73,10 +73,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files(),
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",

View File

@@ -65,6 +65,21 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Input Image", type="pil"
).style(height=300)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -226,6 +241,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
],
outputs=[upscaler_gallery, std_output],
show_progress=args.progress_bar,

View File

@@ -74,25 +74,50 @@ def resource_path(relative_path):
return os.path.join(base_path, relative_path)
def get_custom_model_path():
return Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd(), "models")
def get_custom_model_path(model="models"):
prefix = Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd())
match model:
case "models":
return Path(prefix, "models")
case "vae":
return Path(prefix, "models/vae")
case "lora":
return Path(prefix, "models/lora")
case _:
return ""
def get_custom_model_pathfile(custom_model_name):
return os.path.join(get_custom_model_path(), custom_model_name)
def get_custom_model_pathfile(custom_model_name, model="models"):
return os.path.join(get_custom_model_path(model), custom_model_name)
def get_custom_model_files():
def get_custom_model_files(model="models"):
ckpt_files = []
for extn in custom_model_filetypes:
file_types = custom_model_filetypes
if model == "lora":
file_types = custom_model_filetypes + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_custom_model_path(), extn))
for x in glob.glob(
os.path.join(get_custom_model_path(model), extn)
)
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_custom_vae_or_lora_weights(weights, hf_id, model):
use_weight = ""
if weights == "None" and not hf_id:
use_weight = ""
elif not hf_id:
use_weight = get_custom_model_pathfile(weights, model)
else:
use_weight = hf_id
return use_weight
def cancel_sd():
# Try catch it, as gc can delete global_obj.sd_obj while switching model
try:
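
`get_custom_vae_or_lora_weights` above replaces the if/elif ladder that was duplicated across the `*_inf` functions. The precedence it encodes: an explicit HF ID wins, then a selected local file (resolved through `get_custom_model_pathfile`), else empty. A standalone restatement of that rule (`pick_weights` is an illustrative name):

```python
def pick_weights(dropdown_choice: str, hf_id: str) -> str:
    # HF ID takes precedence; "None" with no HF ID means no extra weights.
    if hf_id:
        return hf_id
    if dropdown_choice == "None":
        return ""
    return dropdown_choice  # the real helper resolves this to a full path

assert pick_weights("None", "") == ""
assert pick_weights("my_lora.safetensors", "") == "my_lora.safetensors"
assert pick_weights("my_lora.safetensors", "user/sd-lora") == "user/sd-lora"
```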

View File

@@ -8,49 +8,64 @@ Also we could avoid memory leak when switching models by clearing the cache.
"""
def init():
global sd_obj
global config_obj
sd_obj = None
config_obj = None
def _init():
global _sd_obj
global _config_obj
global _schedulers
_sd_obj = None
_config_obj = None
_schedulers = None
def set_sd_obj(value):
global sd_obj
sd_obj = value
global _sd_obj
_sd_obj = value
def set_cfg_obj(value):
global config_obj
config_obj = value
def set_schedulers(value):
global sd_obj
sd_obj.scheduler = value
def get_sd_obj():
return sd_obj
def get_cfg_obj():
return config_obj
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
global sd_obj
sd_obj.status = value
global _sd_obj
_sd_obj.status = value
def set_cfg_obj(value):
global _config_obj
_config_obj = value
def set_schedulers(value):
global _schedulers
_schedulers = value
def get_sd_obj():
return _sd_obj
def get_sd_status():
global sd_obj
return sd_obj.status
return _sd_obj.status
def get_cfg_obj():
return _config_obj
def get_scheduler(key):
return _schedulers[key]
def clear_cache():
global sd_obj
global config_obj
del sd_obj
del config_obj
global _sd_obj
global _config_obj
global _schedulers
del _sd_obj
del _config_obj
del _schedulers
gc.collect()
_sd_obj = None
_config_obj = None
_schedulers = None
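
The global_obj rewrite above makes all module state private (underscore-prefixed) and caches the scheduler dict at module level, so `set_schedulers` now stores the whole dict and `set_sd_scheduler` attaches a named entry to the cached pipeline. A small sketch of the call sequence, assuming the SHARK repo is on the path; the stand-in pipeline object is mine:

```python
import types
import apps.stable_diffusion.web.utils.global_obj as global_obj

global_obj._init()                                   # once, at app startup (was init())
global_obj.set_sd_obj(types.SimpleNamespace(scheduler=None))   # stand-in pipeline
global_obj.set_schedulers({"DDPM": "ddpm-scheduler-object"})   # cache the dict
assert global_obj.get_scheduler("DDPM") == "ddpm-scheduler-object"
global_obj.set_sd_scheduler("DDPM")                  # attach the named entry
assert global_obj.get_sd_obj().scheduler == "ddpm-scheduler-object"
```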

View File

@@ -87,11 +87,22 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]
counter = 0
for import_opt in import_options:
for model_name in hf_model_names:
if model_name in to_skip:
continue
for use_tune in tuned_options:
if (
model_name == "stabilityai/stable-diffusion-2-1"
and use_tune == tuned_options[0]
):
continue
elif (
model_name == "stabilityai/stable-diffusion-2-1-base"
and use_tune == tuned_options[1]
):
continue
command = (
[
executable, # executable is the python from the venv used to run this
@@ -174,9 +185,23 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
else:
print(command)
print("failed to generate image for this configuration")
if "2_1_base" in model_name:
with open(dumpfile_name, "r+") as f:
output = f.readlines()
print("\n".join(output))
if model_name == "CompVis/stable-diffusion-v1-4":
print("failed a known successful model.")
exit(1)
if os.name == "nt":
counter += 1
if counter % 2 == 0:
extra_flags.append(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
else:
if counter != 1:
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)

View File

@@ -70,3 +70,9 @@ def pytest_addoption(parser):
default="./temp_dispatch_benchmarks",
help="Directory in which dispatch benchmarks are saved.",
)
parser.addoption(
"--batchsize",
default=1,
type=int,
help="Batch size for the tested model.",
)
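
A hedged sketch of how a test might consume the new option via pytest's standard `request.config.getoption`; the fixture and test names are assumptions, not from the diff:

```python
import pytest

@pytest.fixture
def batch_size(request):
    # Reads the --batchsize option registered in the conftest.py hunk above.
    return request.config.getoption("--batchsize")

def test_model_batch(batch_size):
    assert batch_size >= 1  # default is 1; override with `pytest --batchsize=4`
```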

View File

@@ -6,36 +6,16 @@ from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Diffusers 0.13.1 fails with transformers __init.py errros in BLIP. So remove it for now until we fork it
pix2pix_init = Path(get_python_lib() + "/diffusers/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(get_python_lib() + "/diffusers/pipelines/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(
get_python_lib() + "/diffusers/pipelines/stable_diffusion/__init__.py"
# Temorary workaround for transformers/__init__.py.
path_to_tranformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
for line in fileinput.input(pix2pix_init, inplace=True):
if "StableDiffusionPix2PixZeroPipeline" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
if path_to_tranformers_hook.is_file():
pass
else:
with open(path_to_tranformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")

View File

@@ -2,8 +2,8 @@
--pre
numpy>1.22.4
torchvision
pytorch-triton
torchvision==0.16.0.dev20230322
tabulate
tqdm
@@ -15,8 +15,8 @@ iree-tools-tf
# TensorFlow and JAX.
gin-config
tf-nightly
keras>=2.10
tensorflow>2.11
keras
#tf-models-nightly
#tensorflow-text-nightly
transformers

View File

@@ -16,7 +16,7 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers
diffusers @ git+https://github.com/huggingface/diffusers@main
scipy
ftfy
gradio

View File

@@ -45,7 +45,7 @@ if ($arguments -eq "--force"){
Remove-Item .\shark.venv -Force -Recurse
if (Test-Path .\shark.venv\) {
Write-Host 'could not remove .\shark-venv - please try running ".\setup_venv.ps1 --force" again!'
break
exit 1
}
}
}
@@ -78,12 +78,12 @@ if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is u
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
{
Write-Host "Please install Python 3.11 and try again"
break
exit 34
}
Write-Host "Installing Build Dependencies"
# make sure we really use 3.11 from list, even if it's not the default.
if (!($PyVer.length -ne 0)) {py -3.11 -m venv .\shark.venv\}
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
python -m pip install --upgrade pip

View File

@@ -129,11 +129,11 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu117."
echo "Successfully Installed torch + cu118."
else
echo "Could not install torch + cu117." >&2
echo "Could not install torch + cu118." >&2
fi
fi

View File

@@ -52,7 +52,7 @@ def get_iree_device_args(device, extra_args=[]):
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg"]:
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
return ["--iree-llvmcpu-target-cpu-features=host"]
elif frontend in ["tensorflow", "tf", "mhlo"]:
return [

View File

@@ -30,11 +30,10 @@ def get_iree_gpu_args():
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
) and (shark_args.enable_tf32 == True):
return [
"--iree-hal-cuda-disable-loop-nounroll-wa",
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
]
else:
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
return []
# Get the default gpu args given the architecture.

View File

@@ -78,6 +78,7 @@ class SharkBenchmarkRunner(SharkRunner):
self.vmfb_file = None
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.import_args = {}
SharkRunner.__init__(
self,
mlir_module,
@@ -112,24 +113,31 @@ class SharkBenchmarkRunner(SharkRunner):
def benchmark_torch(self, modelname):
import torch
import torch._dynamo as dynamo
from tank.model_utils import get_torch_model
if self.device == "cuda":
torch.set_default_tensor_type(torch.cuda.FloatTensor)
if self.enable_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
print(
"Currently disabled TensorFloat32 calculations in pytorch benchmarks."
)
# torch.backends.cuda.matmul.allow_tf32 = True
else:
torch.set_default_tensor_type(torch.FloatTensor)
torch_device = torch.device(
"cuda:0" if self.device == "cuda" else "cpu"
)
HFmodel, input = get_torch_model(modelname)[:2]
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
frontend_model = HFmodel.model
frontend_model.to(torch_device)
input.to(torch_device)
# frontend_model = torch.compile(frontend_model, mode="max-autotune", backend="inductor")
try:
frontend_model = torch.compile(
frontend_model, mode="max-autotune", backend="inductor"
)
except RuntimeError:
frontend_model = HFmodel.model
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(input)
@@ -178,7 +186,7 @@ class SharkBenchmarkRunner(SharkRunner):
model,
input,
) = get_tf_model(
modelname
modelname, self.import_args
)[:2]
frontend_model = model
@@ -338,11 +346,19 @@ for currently supported models. Exiting benchmark ONNX."
return comp_str
def benchmark_all_csv(
self, inputs: tuple, modelname, dynamic, device_str, frontend
self,
inputs: tuple,
modelname,
dynamic,
device_str,
frontend,
import_args,
):
self.setup_cl(inputs)
self.import_args = import_args
field_names = [
"model",
"batch_size",
"engine",
"dialect",
"device",
@@ -375,6 +391,7 @@ for currently supported models. Exiting benchmark ONNX."
writer = csv.DictWriter(f, fieldnames=field_names)
bench_info = {}
bench_info["model"] = modelname
bench_info["batch_size"] = str(import_args["batch_size"])
bench_info["dialect"] = self.mlir_dialect
bench_info["iterations"] = shark_args.num_iterations
if dynamic == True:

View File

@@ -139,11 +139,21 @@ def download_model(
tank_url="gs://shark_tank/latest",
frontend=None,
tuned=None,
import_args={"batch_size": "1"},
):
model_name = model_name.replace("/", "_")
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_" + frontend
if import_args["batch_size"] != 1:
model_dir_name = (
model_name
+ "_"
+ frontend
+ "_BS"
+ str(import_args["batch_size"])
)
else:
model_dir_name = model_name + "_" + frontend
model_dir = os.path.join(WORKDIR, model_dir_name)
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
@@ -201,7 +211,9 @@ def download_model(
from tank.generate_sharktank import gen_shark_files
tank_dir = WORKDIR
gen_shark_files(model_name, frontend, tank_dir)
gen_shark_files(model_name, frontend, tank_dir, import_args)
with open(filename, mode="rb") as f:
mlir_file = f.read()
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))

View File

@@ -35,18 +35,12 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,False,False,"Torchvision imports issue",""
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
gpt2,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""

View File

@@ -70,7 +70,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -146,7 +146,6 @@ if __name__ == "__main__":
backend_config = "cuda"
args = [
"--iree-cuda-llvm-target-arch=sm_80",
"--iree-hal-cuda-disable-loop-nounroll-wa",
"--iree-enable-fusion-with-reduction-ops",
]

View File

@@ -91,7 +91,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -86,7 +86,7 @@ if __name__ == "__main__":
backend_config = "dylib"
# backend = "cuda"
# backend_config = "cuda"
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
flatbuffer_blob = compile_str(
compiler_module,
target_backends=[backend],

View File

@@ -33,7 +33,7 @@ def create_hash(file_name):
return file_hash.hexdigest()
def save_torch_model(torch_model_list, local_tank_cache):
def save_torch_model(torch_model_list, local_tank_cache, import_args):
from tank.model_utils import (
get_hf_model,
get_hf_seq2seq_model,
@@ -59,7 +59,6 @@ def save_torch_model(torch_model_list, local_tank_cache):
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.use_tuned = False
args.local_tank_cache = local_tank_cache
precision_values = ["fp16"]
@@ -75,6 +74,7 @@ def save_torch_model(torch_model_list, local_tank_cache):
width=512,
height=512,
use_base_vae=False,
custom_vae="",
debug=True,
sharktank_dir=local_tank_cache,
generate_vmfb=False,
@@ -82,19 +82,33 @@ def save_torch_model(torch_model_list, local_tank_cache):
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
model, input, _ = get_vision_model(
torch_model_name, import_args
)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name)
model, input, _ = get_hf_model(torch_model_name, import_args)
elif model_type == "hf_seq2seq":
model, input, _ = get_hf_seq2seq_model(torch_model_name)
model, input, _ = get_hf_seq2seq_model(
torch_model_name, import_args
)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
model, input, _ = get_hf_img_cls_model(
torch_model_name, import_args
)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name)
model, input, _ = get_fp16_model(torch_model_name, import_args)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
local_tank_cache, str(torch_model_name) + "_torch"
)
if import_args["batch_size"] != 1:
torch_model_dir = os.path.join(
local_tank_cache,
str(torch_model_name)
+ "_torch"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
torch_model_dir = os.path.join(
local_tank_cache, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
@@ -118,7 +132,7 @@ def save_torch_model(torch_model_list, local_tank_cache):
)
def save_tf_model(tf_model_list, local_tank_cache):
def save_tf_model(tf_model_list, local_tank_cache, import_args):
from tank.model_utils_tf import (
get_causal_image_model,
get_masked_lm_model,
@@ -150,20 +164,38 @@ def save_tf_model(tf_model_list, local_tank_cache):
input = None
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_causal_lm_model(tf_model_name)
model, input, _ = get_masked_lm_model(
tf_model_name, import_args
)
elif model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
model, input, _ = get_causal_image_model(
tf_model_name, import_args
)
elif model_type == "keras":
model, input, _ = get_keras_model(tf_model_name)
model, input, _ = get_keras_model(tf_model_name, import_args)
elif model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name)
model, input, _ = get_TFhf_model(tf_model_name, import_args)
elif model_type == "tfhf_seq2seq":
model, input, _ = get_tfhf_seq2seq_model(tf_model_name)
model, input, _ = get_tfhf_seq2seq_model(
tf_model_name, import_args
)
elif model_type == "hf_causallm":
model, input, _ = get_causal_lm_model(
tf_model_name, import_args
)
tf_model_name = tf_model_name.replace("/", "_")
tf_model_dir = os.path.join(
local_tank_cache, str(tf_model_name) + "_tf"
)
if import_args["batch_size"] != 1:
tf_model_dir = os.path.join(
local_tank_cache,
str(tf_model_name)
+ "_tf"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
tf_model_dir = os.path.join(
local_tank_cache, str(tf_model_name) + "_tf"
)
os.makedirs(tf_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
@@ -175,13 +207,9 @@ def save_tf_model(tf_model_list, local_tank_cache):
dir=tf_model_dir,
model_name=tf_model_name,
)
mlir_hash = create_hash(
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
)
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
def save_tflite_model(tflite_model_list, local_tank_cache):
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
from shark.tflite_utils import TFLitePreprocessor
with open(tflite_model_list) as csvfile:
@@ -198,13 +226,13 @@ def save_tflite_model(tflite_model_list, local_tank_cache):
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
# Preprocess to get SharkImporter input args
# Preprocess to get SharkImporter input import_args
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()
# Use SharkImporter to get SharkInference input args
# Use SharkImporter to get SharkInference input import_args
my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
@@ -228,43 +256,69 @@ def save_tflite_model(tflite_model_list, local_tank_cache):
)
def gen_shark_files(modelname, frontend, tank_dir):
def check_requirements(frontend):
import importlib
has_pkgs = False
if frontend == "torch":
tv_spec = importlib.util.find_spec("torchvision")
has_pkgs = tv_spec is not None
elif frontend in ["tensorflow", "tf"]:
tf_spec = importlib.util.find_spec("tensorflow")
has_pkgs = tf_spec is not None
return has_pkgs
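check_requirements leans on importlib.util.find_spec, which returns None when a package is absent without actually importing it. A compact equivalent with the frontend-to-package mapping spelled out as a dict:

import importlib.util

def frontend_available(frontend):
    # Package each frontend needs before on-the-fly generation can run.
    required = {"torch": "torchvision", "tf": "tensorflow", "tensorflow": "tensorflow"}
    pkg = required.get(frontend)
    return pkg is not None and importlib.util.find_spec(pkg) is not None

print(frontend_available("torch"))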
class NoImportException(Exception):
"Raised when requirements are not met for OTF model artifact generation."
pass
def gen_shark_files(modelname, frontend, tank_dir, importer_args):
# If a model's artifacts are requested by shark_downloader but they don't exist in the cloud, we call this function to generate the artifacts on-the-fly.
# TODO: Add TFlite support.
import tempfile
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
custom_model_csv = tempfile.NamedTemporaryFile(
dir=os.path.dirname(__file__),
delete=True,
)
# Create a temporary .csv with only the desired entry.
if frontend == "tf":
with open(tf_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_tf_model(custom_model_csv.name, tank_dir)
import_args = importer_args
if check_requirements(frontend):
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tf_model_list.csv"
)
custom_model_csv = tempfile.NamedTemporaryFile(
dir=os.path.dirname(__file__),
delete=True,
)
# Create a temporary .csv with only the desired entry.
if frontend == "tf":
with open(tf_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_tf_model(custom_model_csv.name, tank_dir, import_args)
if frontend == "torch":
with open(torch_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_torch_model(custom_model_csv.name, tank_dir)
elif frontend == "torch":
with open(torch_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_torch_model(custom_model_csv.name, tank_dir, import_args)
else:
raise NoImportException
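Callers should treat NoImportException as "generation is impossible here", as the test suite below does. A hedged usage sketch; the tank directory and batch size are illustrative:

from tank.generate_sharktank import NoImportException, gen_shark_files

try:
    gen_shark_files(
        "bert-base-uncased",   # modelname; must appear in the frontend's model CSV
        "torch",               # frontend
        "./gen_shark_tank",    # tank_dir, illustrative
        {"batch_size": 4},     # importer_args forwarded to save_torch_model
    )
except NoImportException:
    print("torchvision/tensorflow not installed; cannot generate artifacts on the fly.")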
# Validates whether the file is present or not.
@@ -276,7 +330,7 @@ def is_valid_file(arg):
if __name__ == "__main__":
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
# Note, all of these flags are overridden by the import of import_args from stable_args.py, flags are duplicated temporarily to preserve functionality
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--torch_model_csv",
@@ -304,8 +358,11 @@ if __name__ == "__main__":
# )
# parser.add_argument("--upload", type=bool, default=False)
# old_args = parser.parse_args()
# old_import_args = parser.parse_import_args()
import_args = {
"batch_size": "1",
}
print(import_args)
home = str(Path.home())
WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_shark_tank")
torch_model_csv = os.path.join(
@@ -319,7 +376,8 @@ if __name__ == "__main__":
save_torch_model(
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
WORKDIR,
import_args,
)
save_torch_model(torch_model_csv, WORKDIR)
save_tf_model(tf_model_csv, WORKDIR)
save_tflite_model(tflite_model_csv, WORKDIR)
save_torch_model(torch_model_csv, WORKDIR, import_args)
save_tf_model(tf_model_csv, WORKDIR, import_args)
save_tflite_model(tflite_model_csv, WORKDIR, import_args)

View File

@@ -1,5 +1,4 @@
from shark.shark_inference import SharkInference
from shark.parser import shark_args
import torch
import numpy as np
@@ -35,17 +34,17 @@ hf_seq2seq_models = [
]
def get_torch_model(modelname):
def get_torch_model(modelname, import_args):
if modelname in vision_models:
return get_vision_model(modelname)
return get_vision_model(modelname, import_args)
elif modelname in hf_img_cls_models:
return get_hf_img_cls_model(modelname)
return get_hf_img_cls_model(modelname, import_args)
elif modelname in hf_seq2seq_models:
return get_hf_seq2seq_model(modelname)
return get_hf_seq2seq_model(modelname, import_args)
elif "fp16" in modelname:
return get_fp16_model(modelname)
return get_fp16_model(modelname, import_args)
else:
return get_hf_model(modelname)
return get_hf_model(modelname, import_args)
##################### Hugging Face Image Classification Models ###################################
@@ -88,14 +87,14 @@ class HuggingFaceImageClassification(torch.nn.Module):
return self.model.forward(inputs)[0]
def get_hf_img_cls_model(name):
def get_hf_img_cls_model(name, import_args):
model = HuggingFaceImageClassification(name)
# you can use preprocess_input_image to get the test_input or just random value.
test_input = preprocess_input_image(name)
# test_input = torch.FloatTensor(1, 3, 224, 224).uniform_(-1, 1)
# print("test_input.shape: ", test_input.shape)
# test_input.shape: torch.Size([1, 3, 224, 224])
test_input = test_input.repeat(BATCH_SIZE, 1, 1, 1)
test_input = test_input.repeat(import_args["batch_size"], 1, 1, 1)
actual_out = model(test_input)
# print("actual_out.shape ", actual_out.shape)
# actual_out.shape torch.Size([1, 1000])
@@ -125,14 +124,13 @@ class HuggingFaceLanguage(torch.nn.Module):
return self.model.forward(tokens)[0]
def get_hf_model(name):
def get_hf_model(name, import_args):
from transformers import (
BertTokenizer,
)
model = HuggingFaceLanguage(name)
# TODO: Currently the test input is set to (1,128)
test_input = torch.randint(2, (BATCH_SIZE, 128))
test_input = torch.randint(2, (import_args["batch_size"], 128))
actual_out = model(test_input)
return model, test_input, actual_out
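With import_args threaded through, callers control the leading batch dimension via a plain dict. A minimal sketch of the same shape rule get_hf_model applies above:

import torch

import_args = {"batch_size": 4}
# Same (batch_size, sequence_length) rule as get_hf_model.
test_input = torch.randint(2, (import_args["batch_size"], 128))
print(test_input.shape)  # torch.Size([4, 128])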
@@ -165,7 +163,7 @@ class HFSeq2SeqLanguageModel(torch.nn.Module):
)[0]
def get_hf_seq2seq_model(name):
def get_hf_seq2seq_model(name, import_args):
m = HFSeq2SeqLanguageModel(name)
encoded_input_ids = m.preprocess_input(
"Studies have been shown that owning a dog is good for you"
@@ -193,7 +191,7 @@ class VisionModule(torch.nn.Module):
return self.model.forward(input)
def get_vision_model(torch_model):
def get_vision_model(torch_model, import_args):
import torchvision.models as models
default_image_size = (224, 224)
@@ -239,7 +237,7 @@ def get_vision_model(torch_model):
fp16_model = True
torch_model, input_image_size = vision_models_dict[torch_model]
model = VisionModule(torch_model)
test_input = torch.randn(BATCH_SIZE, 3, 224, 224)
test_input = torch.randn(import_args["batch_size"], 3, *input_image_size)
actual_out = model(test_input)
if fp16_model is not None:
test_input_fp16 = test_input.to(
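The vision path now also honors each model's native input resolution instead of hard-coding 224x224. Sketched standalone; the size tuple here is illustrative, with per-model values living in vision_models_dict:

import torch

import_args = {"batch_size": 2}
input_image_size = (224, 224)  # looked up per-model in vision_models_dict
test_input = torch.randn(import_args["batch_size"], 3, *input_image_size)
print(test_input.shape)  # torch.Size([2, 3, 224, 224])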
@@ -280,14 +278,14 @@ class BertHalfPrecisionModel(torch.nn.Module):
return self.model.forward(tokens)[0]
def get_fp16_model(torch_model):
def get_fp16_model(torch_model, import_args):
from transformers import AutoTokenizer
modelname = torch_model.replace("_fp16", "")
model = BertHalfPrecisionModel(modelname)
tokenizer = AutoTokenizer.from_pretrained(modelname)
text = "Replace me by any text you like."
text = [text] * BATCH_SIZE
text = [text] * import_args["batch_size"]
test_input_fp16 = tokenizer(
text,
truncation=True,

View File

@@ -1,10 +1,5 @@
import tensorflow as tf
import numpy as np
from transformers import (
AutoModelForSequenceClassification,
BertTokenizer,
TFBertModel,
)
BATCH_SIZE = 1
@@ -52,19 +47,19 @@ img_models = [
]
def get_tf_model(name):
def get_tf_model(name, import_args):
if name in keras_models:
return get_keras_model(name)
return get_keras_model(name, import_args)
elif name in maskedlm_models:
return get_masked_lm_model(name)
return get_masked_lm_model(name, import_args)
elif name in causallm_models:
return get_causal_lm_model(name)
return get_causal_lm_model(name, import_args)
elif name in tfhf_models:
return get_TFhf_model(name)
return get_TFhf_model(name, import_args)
elif name in img_models:
return get_causal_image_model(name)
return get_causal_image_model(name, import_args)
elif name in tfhf_seq2seq_models:
return get_tfhf_seq2seq_model(name)
return get_tfhf_seq2seq_model(name, import_args)
else:
raise Exception(
"TF model not found! Please check that the modelname has been input correctly."
@@ -72,6 +67,12 @@ def get_tf_model(name):
##################### Tensorflow Hugging Face Bert Models ###################################
from transformers import (
AutoModelForSequenceClassification,
BertTokenizer,
TFBertModel,
)
BERT_MAX_SEQUENCE_LENGTH = 128
# Create a set of 2-dimensional inputs
@@ -104,7 +105,7 @@ class TFHuggingFaceLanguage(tf.Module):
return self.m.predict(input_ids, attention_mask, token_type_ids)
def get_TFhf_model(name):
def get_TFhf_model(name, import_args):
model = TFHuggingFaceLanguage(name)
tokenizer = BertTokenizer.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased"
@@ -166,7 +167,6 @@ def preprocess_input(
##################### Tensorflow Hugging Face Masked LM Models ###################################
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
import tensorflow as tf
MASKED_LM_MAX_SEQUENCE_LENGTH = 128
@@ -196,7 +196,9 @@ class MaskedLM(tf.Module):
return self.m.predict(input_ids, attention_mask)
def get_masked_lm_model(hf_name, text="Hello, this is the default text."):
def get_masked_lm_model(
hf_name, import_args, text="Hello, this is the default text."
):
model = MaskedLM(hf_name)
encoded_input = preprocess_input(
hf_name, MASKED_LM_MAX_SEQUENCE_LENGTH, text
@@ -251,7 +253,9 @@ class CausalLM(tf.Module):
return self.model.predict(input_ids, attention_mask)
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
def get_causal_lm_model(
hf_name, import_args, text="Hello, this is the default text."
):
model = CausalLM(hf_name)
batched_text = [text] * BATCH_SIZE
encoded_input = model.preprocess_input(batched_text)
@@ -306,7 +310,7 @@ class TFHFSeq2SeqLanguageModel(tf.Module):
return self.model.predict(input_ids, decoder_input_ids)
def get_tfhf_seq2seq_model(name):
def get_tfhf_seq2seq_model(name, import_args):
m = TFHFSeq2SeqLanguageModel(name)
text = "Studies have been shown that owning a dog is good for you"
batched_text = [text] * BATCH_SIZE
@@ -442,7 +446,7 @@ def load_image(path_to_image, width, height, channels):
return image
def get_keras_model(modelname):
def get_keras_model(modelname, import_args):
if modelname == "efficientnet-v2-s":
model = EfficientNetV2SModule()
elif modelname == "efficientnet_b0":
@@ -530,7 +534,7 @@ def preprocess_input_image(model_name):
return [inputs[str(*inputs)]]
def get_causal_image_model(hf_name):
def get_causal_image_model(hf_name, import_args):
model = AutoModelImageClassfication(hf_name)
test_input = preprocess_input_image(hf_name)
# TFSequenceClassifierOutput(loss=None, logits=<tf.Tensor: shape=(1, 1000), dtype=float32, numpy=

View File

@@ -8,6 +8,7 @@ from parameterized import parameterized
from shark.shark_downloader import download_model
from shark.shark_inference import SharkInference
from shark.parser import shark_args
from tank.generate_sharktank import NoImportException
import iree.compiler as ireec
import pytest
import unittest
@@ -48,7 +49,9 @@ def load_csv_and_convert(filename, gen=False):
)
# This is a pytest workaround
if gen:
with open("tank/dict_configs.py", "w+") as out:
with open(
os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
) as out:
out.write("ALL = [\n")
for c in model_configs:
out.write(str(c) + ",\n")
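The recurring fix in this file is resolving paths against the test module's directory rather than the current working directory, so pytest behaves the same regardless of where it is invoked from. The pattern, in isolation:

import os

# Anchor data files next to this module, not the caller's CWD.
HERE = os.path.dirname(os.path.abspath(__file__))
all_models_csv = os.path.join(HERE, "all_models.csv")
dict_configs_py = os.path.join(HERE, "dict_configs.py")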
@@ -68,7 +71,9 @@ def get_valid_test_params():
dynamic_list = (True, False)
# TODO: This is soooo ugly, but for some reason creating the dict at runtime
# results in strange pytest failures.
load_csv_and_convert("tank/all_models.csv", True)
load_csv_and_convert(
os.path.join(os.path.dirname(__file__), "all_models.csv"), True
)
from tank.dict_configs import ALL
config_list = ALL
@@ -135,6 +140,9 @@ class SharkModuleTester:
self.config = config
def create_and_check_module(self, dynamic, device):
import_config = {
"batch_size": self.batch_size,
}
shark_args.local_tank_cache = self.local_tank_cache
shark_args.force_update_tank = self.update_tank
shark_args.dispatch_benchmarks = self.benchmark_dispatches
@@ -161,11 +169,17 @@ class SharkModuleTester:
if "winograd" in self.config["flags"]:
shark_args.use_winograd = True
model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
tank_url=self.tank_url,
frontend=self.config["framework"],
)
try:
model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
tank_url=self.tank_url,
frontend=self.config["framework"],
import_args=import_config,
)
except NoImportException:
pytest.xfail(
reason=f"Artifacts for this model/config must be generated locally. Please make sure {self.config['framework']} is installed."
)
shark_module = SharkInference(
model,
@@ -210,10 +224,12 @@ class SharkModuleTester:
self.save_reproducers()
def benchmark_module(self, shark_module, inputs, dynamic, device):
model_config = {
"batch_size": self.batch_size,
}
shark_args.enable_tf32 = self.tf32
if shark_args.enable_tf32 == True:
shark_module.compile()
shark_args.enable_tf32 = False
shark_args.onnx_bench = self.onnx_bench
shark_module.shark_runner.benchmark_all_csv(
@@ -222,6 +238,7 @@ class SharkModuleTester:
dynamic,
device,
self.config["framework"],
import_args=model_config,
)
def save_reproducers(self):
@@ -271,6 +288,9 @@ class SharkModuleTest(unittest.TestCase):
@parameterized.expand(param_list, name_func=shark_test_name_func)
def test_module(self, dynamic, device, config):
self.module_tester = SharkModuleTester(config)
self.module_tester.batch_size = self.pytestconfig.getoption(
"batchsize"
)
self.module_tester.benchmark = self.pytestconfig.getoption("benchmark")
self.module_tester.save_repro = self.pytestconfig.getoption(
"save_repro"
@@ -342,6 +362,10 @@ class SharkModuleTest(unittest.TestCase):
pytest.xfail(
reason="Numerics issues: https://github.com/nod-ai/SHARK/issues/476"
)
if config["framework"] == "tf" and self.module_tester.batch_size != 1:
pytest.xfail(
reason="Configurable batch sizes temp. unavailable for tensorflow models."
)
safe_name = (
f"{config['model_name']}_{config['framework']}_{dynamic}_{device}"
)

View File

@@ -1,4 +1,6 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
@@ -19,9 +21,3 @@ mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"

View File

@@ -1,4 +1,3 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
prompthero/openjourney,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A