diff --git a/shark/examples/shark_eager/dynamo_demo.ipynb b/shark/examples/shark_eager/dynamo_demo.ipynb index 08a2b0b0..526ff95b 100644 --- a/shark/examples/shark_eager/dynamo_demo.ipynb +++ b/shark/examples/shark_eager/dynamo_demo.ipynb @@ -36,7 +36,9 @@ " from torchdynamo.optimizations.backends import create_backend\n", " from torchdynamo.optimizations.subgraph import SubGraph\n", "except ModuleNotFoundError:\n", - " print(\"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\")\n", + " print(\n", + " \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n", + " )\n", " exit()\n", "\n", "# torch-mlir imports for compiling\n", @@ -97,7 +99,9 @@ "\n", " for node in fx_g.graph.nodes:\n", " if node.op == \"output\":\n", - " assert len(node.args) == 1, \"Output node must have a single argument\"\n", + " assert (\n", + " len(node.args) == 1\n", + " ), \"Output node must have a single argument\"\n", " node_arg = node.args[0]\n", " if isinstance(node_arg, tuple) and len(node_arg) == 1:\n", " node.args = (node_arg[0],)\n", @@ -116,8 +120,12 @@ " if len(args) == 1 and isinstance(args[0], list):\n", " args = args[0]\n", "\n", - " linalg_module = compile(ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS)\n", - " callable, _ = get_iree_compiled_module(linalg_module, \"cuda\", func_name=\"forward\")\n", + " linalg_module = compile(\n", + " ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n", + " )\n", + " callable, _ = get_iree_compiled_module(\n", + " linalg_module, \"cuda\", func_name=\"forward\"\n", + " )\n", "\n", " def forward(*inputs):\n", " return callable(*inputs)\n", @@ -212,6 +220,7 @@ " assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n", " return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n", "\n", + "\n", "@torchdynamo.optimize(\"torch_mlir\")\n", "def toy_example2(*args):\n", " a, b = args\n", diff --git a/shark/examples/shark_inference/stable_diffusion/main.py b/shark/examples/shark_inference/stable_diffusion/main.py index c156b1d9..f27014f0 100644 --- a/shark/examples/shark_inference/stable_diffusion/main.py +++ b/shark/examples/shark_inference/stable_diffusion/main.py @@ -51,7 +51,7 @@ if __name__ == "__main__": neg_prompt = args.negative_prompts height = 512 # default height of Stable Diffusion width = 512 # default width of Stable Diffusion - if args.version == "v2.1" and args.variant == "stablediffusion": + if args.version == "v2_1": height = 768 width = 768 @@ -91,7 +91,7 @@ if __name__ == "__main__": subfolder="scheduler", ) cpu_scheduling = True - if args.version == "v2.1": + if args.version == "v2_1": tokenizer = CLIPTokenizer.from_pretrained( "stabilityai/stable-diffusion-2-1", subfolder="tokenizer" ) @@ -101,7 +101,7 @@ if __name__ == "__main__": subfolder="scheduler", ) - if args.version == "v2.1base": + if args.version == "v2_1base": tokenizer = CLIPTokenizer.from_pretrained( "stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer" ) diff --git a/shark/examples/shark_inference/stable_diffusion/model_wrappers.py b/shark/examples/shark_inference/stable_diffusion/model_wrappers.py index c32f3f44..082d64fe 100644 --- a/shark/examples/shark_inference/stable_diffusion/model_wrappers.py +++ b/shark/examples/shark_inference/stable_diffusion/model_wrappers.py @@ -5,9 +5,9 @@ from stable_args import args import torch model_config = { - "v2.1": "stabilityai/stable-diffusion-2-1", - "v2.1base": "stabilityai/stable-diffusion-2-1-base", - "v1.4": 
"CompVis/stable-diffusion-v1-4", + "v2_1": "stabilityai/stable-diffusion-2-1", + "v2_1base": "stabilityai/stable-diffusion-2-1-base", + "v1_4": "CompVis/stable-diffusion-v1-4", } # clip has 2 variants of max length 77 or 64. @@ -24,7 +24,7 @@ model_variant = { } model_input = { - "v2.1": { + "v2_1": { "clip": (torch.randint(1, 2, (2, model_clip_max_length)),), "vae": (torch.randn(1, 4, 96, 96),), "unet": ( @@ -34,7 +34,7 @@ model_input = { torch.tensor(1).to(torch.float32), # guidance_scale ), }, - "v2.1base": { + "v2_1base": { "clip": (torch.randint(1, 2, (2, model_clip_max_length)),), "vae": (torch.randn(1, 4, 64, 64),), "unet": ( @@ -44,7 +44,7 @@ model_input = { torch.tensor(1).to(torch.float32), # guidance_scale ), }, - "v1.4": { + "v1_4": { "clip": (torch.randint(1, 2, (2, model_clip_max_length)),), "vae": (torch.randn(1, 4, 64, 64),), "unet": ( @@ -70,7 +70,7 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]): "openai/clip-vit-large-patch14" ) if args.variant == "stablediffusion": - if args.version != "v1.4": + if args.version != "v1_4": text_encoder = CLIPTextModel.from_pretrained( model_config[args.version], subfolder="text_encoder" ) @@ -102,6 +102,54 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]): return shark_clip +def get_base_vae_mlir(model_name="vae", extra_args=[]): + class BaseVaeModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.vae = AutoencoderKL.from_pretrained( + model_config[args.version] + if args.variant == "stablediffusion" + else model_variant[args.variant], + subfolder="vae", + revision=model_revision[args.variant], + ) + + def forward(self, input): + x = self.vae.decode(input, return_dict=False)[0] + return (x / 2 + 0.5).clamp(0, 1) + + vae = BaseVaeModel() + if args.variant == "stablediffusion": + if args.precision == "fp16": + vae = vae.half().cuda() + inputs = tuple( + [ + inputs.half().cuda() + for inputs in model_input[args.version]["vae"] + ] + ) + else: + inputs = model_input[args.version]["vae"] + elif args.variant in ["anythingv3", "analogdiffusion"]: + if args.precision == "fp16": + vae = vae.half().cuda() + inputs = tuple( + [inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]] + ) + else: + inputs = model_input["v1_4"]["vae"] + else: + raise (f"{args.variant} not yet added") + + shark_vae = compile_through_fx( + vae, + inputs, + model_name=model_name, + extra_args=extra_args, + ) + return shark_vae + + def get_vae_mlir(model_name="vae", extra_args=[]): class VaeModel(torch.nn.Module): def __init__(self): @@ -137,10 +185,10 @@ def get_vae_mlir(model_name="vae", extra_args=[]): if args.precision == "fp16": vae = vae.half().cuda() inputs = tuple( - [inputs.half().cuda() for inputs in model_input["v1.4"]["vae"]] + [inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]] ) else: - inputs = model_input["v1.4"]["vae"] + inputs = model_input["v1_4"]["vae"] else: raise (f"{args.variant} not yet added") @@ -197,11 +245,11 @@ def get_unet_mlir(model_name="unet", extra_args=[]): inputs = tuple( [ inputs.half().cuda() if len(inputs.shape) != 0 else inputs - for inputs in model_input["v1.4"]["unet"] + for inputs in model_input["v1_4"]["unet"] ] ) else: - inputs = model_input["v1.4"]["unet"] + inputs = model_input["v1_4"]["unet"] else: raise (f"{args.variant} is not yet added") shark_unet = compile_through_fx( diff --git a/shark/examples/shark_inference/stable_diffusion/opt_params.py b/shark/examples/shark_inference/stable_diffusion/opt_params.py index d564d3ce..f461e19d 100644 --- 
a/shark/examples/shark_inference/stable_diffusion/opt_params.py +++ b/shark/examples/shark_inference/stable_diffusion/opt_params.py @@ -1,9 +1,11 @@ import sys from model_wrappers import ( + get_base_vae_mlir, get_vae_mlir, get_unet_mlir, get_clip_mlir, ) +from resources import models_db from stable_args import args from utils import get_shark_model @@ -11,222 +13,110 @@ BATCH_SIZE = len(args.prompts) if BATCH_SIZE != 1: sys.exit("Only batch size 1 is supported.") -def get_unet(): + +def get_params(model_key): iree_flags = [] if len(args.iree_vulkan_target_triple) > 0: iree_flags.append( f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" ) + # Disable bindings fusion to work with moltenVK. if sys.platform == "darwin": iree_flags.append("-iree-stream-fuse-binding=false") - if args.variant == "stablediffusion": - # Tuned model is present for `fp16` precision. - if args.precision == "fp16": - if args.use_tuned: - bucket = "gs://shark_tank/vivian" - if args.version == "v1.4": - model_name = "unet_1dec_fp16_tuned" - if args.version == "v2.1base": - if args.max_length == 64: - model_name = "unet_19dec_v2p1base_fp16_64_tuned" - else: - model_name = "unet2base_8dec_fp16_tuned_v2" - return get_shark_model(bucket, model_name, iree_flags) - else: - bucket = "gs://shark_tank/stable_diffusion" - model_name = "unet_8dec_fp16" - if args.version == "v2.1base": - if args.max_length == 64: - model_name = "unet_19dec_v2p1base_fp16_64" - else: - model_name = "unet2base_8dec_fp16" - if args.version == "v2.1": - model_name = "unet2_14dec_fp16" - iree_flags += [ - "--iree-flow-enable-padding-linalg-ops", - "--iree-flow-linalg-ops-padding-size=32", - "--iree-flow-enable-conv-img2col-transform", - ] - if args.import_mlir: - return get_unet_mlir(model_name, iree_flags) - return get_shark_model(bucket, model_name, iree_flags) + try: + model_name = models_db[model_key] + except KeyError: + raise Exception(f"{model_key} is not present in the models database") - # Tuned model is not present for `fp32` case. - if args.precision == "fp32": - bucket = "gs://shark_tank/stable_diffusion" - model_name = "unet_1dec_fp32" + return model_name, iree_flags + + +def get_unet(): + # Tuned model is present only for `fp16` precision. 
+ is_tuned = "/tuned" if args.use_tuned else "/untuned" + variant_version = args.variant + model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}{is_tuned}" + model_name, iree_flags = get_params(model_key) + if args.use_tuned: + bucket = "gs://shark_tank/vivian" + return get_shark_model(bucket, model_name, iree_flags) + else: + bucket = "gs://shark_tank/stable_diffusion" + if args.variant == "anythingv3": + bucket = "gs://shark_tank/sd_anythingv3" + elif args.variant == "analogdiffusion": + bucket = "gs://shark_tank/sd_analog_diffusion" + if args.precision == "fp16": + iree_flags += [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32", + "--iree-flow-enable-conv-img2col-transform", + ] + elif args.precision == "fp32": iree_flags += [ "--iree-flow-enable-conv-nchw-to-nhwc-transform", "--iree-flow-enable-padding-linalg-ops", "--iree-flow-linalg-ops-padding-size=16", ] - if args.import_mlir: - return get_unet_mlir(model_name, iree_flags) - return get_shark_model(bucket, model_name, iree_flags) - - if args.precision == "int8": - bucket = "gs://shark_tank/prashant_nod" - model_name = "unet_int8" - iree_flags += [ - "--iree-flow-enable-padding-linalg-ops", - "--iree-flow-linalg-ops-padding-size=32", - ] - sys.exit("int8 model is currently in maintenance.") - # # TODO: Pass iree_flags to the exported model. - # if args.import_mlir: - # sys.exit( - # "--import_mlir is not supported for the int8 model, try --no-import_mlir flag." - # ) - # return get_shark_model(bucket, model_name, iree_flags) - - else: - iree_flags += [ - "--iree-flow-enable-padding-linalg-ops", - "--iree-flow-linalg-ops-padding-size=32", - "--iree-flow-enable-conv-img2col-transform", - ] - if args.variant == "anythingv3": - bucket = "gs://shark_tank/sd_anythingv3" - model_name = f"av3_unet_19dec_{args.precision}" - elif args.variant == "analogdiffusion": - bucket = "gs://shark_tank/sd_analog_diffusion" - model_name = f"ad_unet_19dec_{args.precision}" - else: - sys.exit(f"{args.variant} variant of SD is currently unsupported") - if args.import_mlir: return get_unet_mlir(model_name, iree_flags) return get_shark_model(bucket, model_name, iree_flags) def get_vae(): - iree_flags = [] - if len(args.iree_vulkan_target_triple) > 0: - iree_flags.append( - f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" - ) - # Disable bindings fusion to work with moltenVK. 
- if sys.platform == "darwin": - iree_flags.append("-iree-stream-fuse-binding=false") - - if args.variant == "stablediffusion": - if args.precision in ["fp16", "int8"]: - if args.use_tuned: - bucket = "gs://shark_tank/vivian" - if args.version == "v2.1base": - model_name = "vae2base_19dec_fp16_tuned" - iree_flags += [ - "--iree-flow-enable-padding-linalg-ops", - "--iree-flow-linalg-ops-padding-size=32", - "--iree-flow-enable-conv-img2col-transform", - "--iree-flow-enable-conv-winograd-transform", - ] - return get_shark_model(bucket, model_name, iree_flags) - else: - bucket = "gs://shark_tank/stable_diffusion" - model_name = "vae_19dec_fp16" - if args.version == "v2.1base": - model_name = "vae2base_19dec_fp16" - if args.version == "v2.1": - model_name = "vae2_19dec_fp16" - iree_flags += [ - "--iree-flow-enable-padding-linalg-ops", - "--iree-flow-linalg-ops-padding-size=32", - "--iree-flow-enable-conv-img2col-transform", - ] - if args.import_mlir: - return get_vae_mlir(model_name, iree_flags) - return get_shark_model(bucket, model_name, iree_flags) - - if args.precision == "fp32": - bucket = "gs://shark_tank/stable_diffusion" - model_name = "vae_1dec_fp32" + # Tuned model is present only for `fp16` precision. + is_tuned = "/tuned" if args.use_tuned else "/untuned" + is_base = "/base" if args.use_base_vae else "" + model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77{is_tuned}{is_base}" + model_name, iree_flags = get_params(model_key) + if args.use_tuned: + bucket = "gs://shark_tank/vivian" + iree_flags += [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32", + "--iree-flow-enable-conv-img2col-transform", + "--iree-flow-enable-conv-winograd-transform", + ] + return get_shark_model(bucket, model_name, iree_flags) + else: + bucket = "gs://shark_tank/stable_diffusion" + if args.variant == "anythingv3": + bucket = "gs://shark_tank/sd_anythingv3" + elif args.variant == "analogdiffusion": + bucket = "gs://shark_tank/sd_analog_diffusion" + if args.precision == "fp16": + iree_flags += [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32", + "--iree-flow-enable-conv-img2col-transform", + ] + elif args.precision == "fp32": iree_flags += [ "--iree-flow-enable-conv-nchw-to-nhwc-transform", "--iree-flow-enable-padding-linalg-ops", "--iree-flow-linalg-ops-padding-size=16", ] - if args.import_mlir: - return get_vae_mlir(model_name, iree_flags) - return get_shark_model(bucket, model_name, iree_flags) - - else: - iree_flags += [ - "--iree-flow-enable-padding-linalg-ops", - ] - if args.precision == "fp16": - iree_flags += [ - "--iree-flow-linalg-ops-padding-size=16", - "--iree-flow-enable-conv-img2col-transform", - ] - elif args.precision == "fp32": - iree_flags += [ - "--iree-flow-linalg-ops-padding-size=32", - "--iree-flow-enable-conv-nchw-to-nhwc-transform", - ] - else: - sys.exit("int8 precision is currently in not supported.") - - if args.variant == "anythingv3": - bucket = "gs://shark_tank/sd_anythingv3" - model_name = f"av3_vae_19dec_{args.precision}" - - elif args.variant == "analogdiffusion": - bucket = "gs://shark_tank/sd_analog_diffusion" - model_name = f"ad_vae_19dec_{args.precision}" - - else: - sys.exit(f"{args.variant} variant of SD is currently unsupported") - if args.import_mlir: - return get_unet_mlir(model_name, iree_flags) + if args.use_base_vae: + return get_base_vae_mlir(model_name, iree_flags) + return get_vae_mlir(model_name, iree_flags) return get_shark_model(bucket, model_name, iree_flags) def 
get_clip(): - iree_flags = [] - if len(args.iree_vulkan_target_triple) > 0: - iree_flags.append( - f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" - ) - # Disable bindings fusion to work with moltenVK. - if sys.platform == "darwin": - iree_flags.append("-iree-stream-fuse-binding=false") - - if args.variant == "stablediffusion": - bucket = "gs://shark_tank/stable_diffusion" - model_name = "clip_18dec_fp32" - if args.version == "v2.1base": - if args.max_length == 64: - model_name = "clip_19dec_v2p1base_fp32_64" - else: - model_name = "clip2base_18dec_fp32" - if args.version == "v2.1": - model_name = "clip2_18dec_fp32" - iree_flags += [ - "--iree-flow-linalg-ops-padding-size=16", - "--iree-flow-enable-padding-linalg-ops", - ] - if args.import_mlir: - return get_clip_mlir(model_name, iree_flags) - return get_shark_model(bucket, model_name, iree_flags) - + model_key = f"{args.variant}/{args.version}/clip/fp32/length_{args.max_length}/untuned" + model_name, iree_flags = get_params(model_key) + bucket = "gs://shark_tank/stable_diffusion" if args.variant == "anythingv3": bucket = "gs://shark_tank/sd_anythingv3" - model_name = "av3_clip_19dec_fp32" elif args.variant == "analogdiffusion": bucket = "gs://shark_tank/sd_analog_diffusion" - model_name = "ad_clip_19dec_fp32" - iree_flags += [ - "--iree-flow-enable-padding-linalg-ops", - "--iree-flow-linalg-ops-padding-size=16", - "--iree-flow-enable-conv-img2col-transform", - ] - else: - sys.exit(f"{args.variant} variant of SD is currently unsupported") - + iree_flags += [ + "--iree-flow-linalg-ops-padding-size=16", + "--iree-flow-enable-padding-linalg-ops", + ] if args.import_mlir: - return get_unet_mlir(model_name, iree_flags) + return get_clip_mlir(model_name, iree_flags) return get_shark_model(bucket, model_name, iree_flags) diff --git a/shark/examples/shark_inference/stable_diffusion/resources.py b/shark/examples/shark_inference/stable_diffusion/resources.py new file mode 100644 index 00000000..3f0ee67f --- /dev/null +++ b/shark/examples/shark_inference/stable_diffusion/resources.py @@ -0,0 +1,25 @@ +import os +import json +import sys + + +def resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr( + sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) + ) + return os.path.join(base_path, relative_path) + + +prompt_examples = [] +prompts_loc = resource_path("resources/prompts.json") +if os.path.exists(prompts_loc): + with open(prompts_loc, encoding="utf-8") as fopen: + prompt_examples = json.load(fopen) + + +models_db = dict() +models_loc = resource_path("resources/model_db.json") +if os.path.exists(models_loc): + with open(models_loc, encoding="utf-8") as fopen: + models_db = json.load(fopen) diff --git a/shark/examples/shark_inference/stable_diffusion/resources/model_db.json b/shark/examples/shark_inference/stable_diffusion/resources/model_db.json new file mode 100644 index 00000000..b54a32b3 --- /dev/null +++ b/shark/examples/shark_inference/stable_diffusion/resources/model_db.json @@ -0,0 +1,33 @@ +{ + "stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16", + "stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_1dec_fp16_tuned", + "stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32", + "stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16", + "stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16", + "stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32", + 
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32", + "stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16", + "stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2", + "stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64", + "stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned", + "stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16", + "stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned", + "stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16", + "stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned", + "stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32", + "stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64", + "stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16", + "stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae2_19dec_fp16", + "stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16", + "stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32", + "anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16", + "anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32", + "anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16", + "anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32", + "anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32", + "analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16", + "analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32", + "analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16", + "analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32", + "analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32" +} diff --git a/shark/examples/shark_inference/stable_diffusion/resources/prompts.json b/shark/examples/shark_inference/stable_diffusion/resources/prompts.json new file mode 100644 index 00000000..4c8370db --- /dev/null +++ b/shark/examples/shark_inference/stable_diffusion/resources/prompts.json @@ -0,0 +1,8 @@ +[["A high tech solarpunk utopia in the Amazon rainforest"], +["A pikachu fine dining with a view to the Eiffel Tower"], +["A mecha robot in a favela in expressionist style"], +["an insect robot preparing a delicious meal"], +["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"], +["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"], +["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"], +["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]] diff --git a/shark/examples/shark_inference/stable_diffusion/stable_args.py b/shark/examples/shark_inference/stable_diffusion/stable_args.py index 04de38d0..79a51a4d 100644 --- a/shark/examples/shark_inference/stable_diffusion/stable_args.py +++ b/shark/examples/shark_inference/stable_diffusion/stable_args.py @@ -61,7 +61,7 @@ p.add_argument( p.add_argument( "--version", type=str, - default="v2.1base", + default="v2_1base", help="Specify version of stable diffusion model", ) @@ -97,11 +97,19 @@ p.add_argument( help="Download and use the tuned version of the model if available", 
)
 
+p.add_argument(
+    "--use_base_vae",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Do conversion from the VAE output to pixel space on cpu.",
+)
+
 p.add_argument(
     "--variant",
     default="stablediffusion",
     help="We now support multiple vairants of SD finetuned for different dataset. you can use the following anythingv3, ...",  # TODO add more once supported
 )
+
 ##############################################################################
 ### IREE - Vulkan supported flags
 ##############################################################################
diff --git a/web/models/stable_diffusion/main.py b/web/models/stable_diffusion/main.py
index 7cd95d1d..19f8e30a 100644
--- a/web/models/stable_diffusion/main.py
+++ b/web/models/stable_diffusion/main.py
@@ -158,7 +158,7 @@ def stable_diff_inf(
     text_output = f"prompt={args.prompts}"
     text_output += f"\nnegative prompt={args.negative_prompts}"
     text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
-    text_output += f"\nAverage step time: {avg_ms:.2f}ms/it"
-    text_output += f"\nTotal image generation time: {total_time:.2f}sec"
+    text_output += f"\nAverage step time: {avg_ms:.4f}ms/it"
+    text_output += f"\nTotal image generation time: {total_time:.4f}sec"
 
     return pil_images[0], text_output
diff --git a/web/models/stable_diffusion/model_wrappers.py b/web/models/stable_diffusion/model_wrappers.py
index 17974fac..1d69e8b8 100644
--- a/web/models/stable_diffusion/model_wrappers.py
+++ b/web/models/stable_diffusion/model_wrappers.py
@@ -12,6 +12,16 @@ model_config = {
 # clip has 2 variants of max length 77 or 64.
 model_clip_max_length = 64 if args.max_length == 64 else 77
 
+if args.variant != "stablediffusion":
+    model_clip_max_length = 77
+
+model_variant = {
+    "stablediffusion": "SD",
+    "anythingv3": "Linaqruf/anything-v3.0",
+    "dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
+    "openjourney": "prompthero/openjourney",
+    "analogdiffusion": "wavymulder/Analog-Diffusion",
+}
 
 model_input = {
     "v2_1": {
@@ -47,7 +57,11 @@
 }
 
 # revision param for from_pretrained defaults to "main" => fp32
-model_revision = "fp16" if args.precision == "fp16" else "main"
+model_revision = {
+    "stablediffusion": "fp16" if args.precision == "fp16" else "main",
+    "anythingv3": "diffusers",
+    "analogdiffusion": "main",
+}
 
 
 def get_clip_mlir(model_name="clip_text", extra_args=[]):
@@ -55,10 +69,20 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
     text_encoder = CLIPTextModel.from_pretrained(
         "openai/clip-vit-large-patch14"
    )
-    if args.version != "v1_4":
+    if args.variant == "stablediffusion":
+        if args.version != "v1_4":
+            text_encoder = CLIPTextModel.from_pretrained(
+                model_config[args.version], subfolder="text_encoder"
+            )
+
+    elif args.variant in ["anythingv3", "analogdiffusion"]:
         text_encoder = CLIPTextModel.from_pretrained(
-            model_config[args.version], subfolder="text_encoder"
+            model_variant[args.variant],
+            subfolder="text_encoder",
+            revision=model_revision[args.variant],
         )
+    else:
+        raise ValueError(f"{args.variant} not yet added")
 
     class CLIPText(torch.nn.Module):
         def __init__(self):
@@ -83,9 +107,11 @@ def get_base_vae_mlir(model_name="vae", extra_args=[]):
         def __init__(self):
             super().__init__()
             self.vae = AutoencoderKL.from_pretrained(
-                model_config[args.version],
+                model_config[args.version]
+                if args.variant == "stablediffusion"
+                else model_variant[args.variant],
                 subfolder="vae",
-                revision=model_revision,
+                revision=model_revision[args.variant],
             )
 
         def forward(self, input):
@@ -93,16 +119,27 @@ def get_base_vae_mlir(model_name="vae", extra_args=[]):
             return (x / 2 + 0.5).clamp(0, 1)
 
     vae = BaseVaeModel()
-    if args.precision == "fp16":
-        vae = vae.half().cuda()
-        inputs = tuple(
-            [
-                inputs.half().cuda()
-                for inputs in model_input[args.version]["vae"]
-            ]
-        )
+    if args.variant == "stablediffusion":
+        if args.precision == "fp16":
+            vae = vae.half().cuda()
+            inputs = tuple(
+                [
+                    inputs.half().cuda()
+                    for inputs in model_input[args.version]["vae"]
+                ]
+            )
+        else:
+            inputs = model_input[args.version]["vae"]
+    elif args.variant in ["anythingv3", "analogdiffusion"]:
+        if args.precision == "fp16":
+            vae = vae.half().cuda()
+            inputs = tuple(
+                [inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
+            )
+        else:
+            inputs = model_input["v1_4"]["vae"]
     else:
-        inputs = model_input[args.version]["vae"]
+        raise ValueError(f"{args.variant} not yet added")
 
     shark_vae = compile_through_fx(
         vae,
@@ -118,9 +155,11 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
         def __init__(self):
             super().__init__()
             self.vae = AutoencoderKL.from_pretrained(
-                model_config[args.version],
+                model_config[args.version]
+                if args.variant == "stablediffusion"
+                else model_variant[args.variant],
                 subfolder="vae",
-                revision=model_revision,
+                revision=model_revision[args.variant],
             )
 
         def forward(self, input):
@@ -131,16 +170,27 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
         return x.round()
 
     vae = VaeModel()
-    if args.precision == "fp16":
-        vae = vae.half().cuda()
-        inputs = tuple(
-            [
-                inputs.half().cuda()
-                for inputs in model_input[args.version]["vae"]
-            ]
-        )
+    if args.variant == "stablediffusion":
+        if args.precision == "fp16":
+            vae = vae.half().cuda()
+            inputs = tuple(
+                [
+                    inputs.half().cuda()
+                    for inputs in model_input[args.version]["vae"]
+                ]
+            )
+        else:
+            inputs = model_input[args.version]["vae"]
+    elif args.variant in ["anythingv3", "analogdiffusion"]:
+        if args.precision == "fp16":
+            vae = vae.half().cuda()
+            inputs = tuple(
+                [inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
+            )
+        else:
+            inputs = model_input["v1_4"]["vae"]
     else:
-        inputs = model_input[args.version]["vae"]
+        raise ValueError(f"{args.variant} not yet added")
 
     shark_vae = compile_through_fx(
         vae,
@@ -156,9 +206,11 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
         def __init__(self):
             super().__init__()
             self.unet = UNet2DConditionModel.from_pretrained(
-                model_config[args.version],
+                model_config[args.version]
+                if args.variant == "stablediffusion"
+                else model_variant[args.variant],
                 subfolder="unet",
-                revision=model_revision,
+                revision=model_revision[args.variant],
             )
             self.in_channels = self.unet.in_channels
             self.train(False)
@@ -176,16 +228,30 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
         return noise_pred
 
     unet = UnetModel()
-    if args.precision == "fp16":
-        unet = unet.half().cuda()
-        inputs = tuple(
-            [
-                inputs.half().cuda() if len(inputs.shape) != 0 else inputs
-                for inputs in model_input[args.version]["unet"]
-            ]
-        )
+    if args.variant == "stablediffusion":
+        if args.precision == "fp16":
+            unet = unet.half().cuda()
+            inputs = tuple(
+                [
+                    inputs.half().cuda() if len(inputs.shape) != 0 else inputs
+                    for inputs in model_input[args.version]["unet"]
+                ]
+            )
+        else:
+            inputs = model_input[args.version]["unet"]
+    elif args.variant in ["anythingv3", "analogdiffusion"]:
+        if args.precision == "fp16":
+            unet = unet.half().cuda()
+            inputs = tuple(
+                [
+                    inputs.half().cuda() if len(inputs.shape) != 0 else inputs
+                    for inputs in model_input["v1_4"]["unet"]
+                ]
+            )
+        else:
+            inputs = model_input["v1_4"]["unet"]
     else:
-        inputs = model_input[args.version]["unet"]
+        raise ValueError(f"{args.variant} is not yet added")
 
     shark_unet = compile_through_fx(
         unet,
         inputs,
diff --git a/web/models/stable_diffusion/stable_args.py b/web/models/stable_diffusion/stable_args.py
index 328f359f..97114f5b 100644
--- a/web/models/stable_diffusion/stable_args.py
+++ b/web/models/stable_diffusion/stable_args.py
@@ -97,12 +97,6 @@ p.add_argument(
     help="Download and use the tuned version of the model if available",
 )
 
-p.add_argument(
-    "--variant",
-    default="stablediffusion",
-    help="We now support multiple variants of SD finetuned for different dataset. you can use the following anythingv3, ...",  # TODO add more once supported
-)
-
 p.add_argument(
     "--use_base_vae",
     default=False,
@@ -110,6 +104,12 @@ p.add_argument(
     help="Do conversion from the VAE output to pixel space on cpu.",
 )
 
+p.add_argument(
+    "--variant",
+    default="stablediffusion",
+    help="We now support multiple variants of SD finetuned for different datasets. You can use the following: anythingv3, ...",  # TODO add more once supported
+)
+
 ##############################################################################
 ### IREE - Vulkan supported flags
 ##############################################################################
@@ -145,6 +145,13 @@ p.add_argument(
 ### Misc. Debug and Optimization flags
 ##############################################################################
 
+p.add_argument(
+    "--use_compiled_scheduler",
+    default=True,
+    action=argparse.BooleanOptionalAction,
+    help="use the default scheduler precompiled into the model if available",
+)
+
 p.add_argument(
     "--local_tank_cache",
     default="",
@@ -179,11 +186,18 @@ p.add_argument(
 
 p.add_argument(
     "--hide_steps",
-    default=True,
+    default=False,
     action=argparse.BooleanOptionalAction,
     help="flag for hiding the details of iteration/sec for each step.",
 )
 
+p.add_argument(
+    "--warmup_count",
+    type=int,
+    default=0,
+    help="flag setting warmup count for clip and vae [>= 0].",
+)
+
 ##############################################################################
 ### Web UI flags
 ##############################################################################
diff --git a/web/models/stable_diffusion/utils.py b/web/models/stable_diffusion/utils.py
index cb636d80..e0f3285b 100644
--- a/web/models/stable_diffusion/utils.py
+++ b/web/models/stable_diffusion/utils.py
@@ -4,7 +4,10 @@ import torch
 from shark.shark_inference import SharkInference
 from models.stable_diffusion.stable_args import args
 from shark.shark_importer import import_with_fx
-from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
+from shark.iree_utils.vulkan_utils import (
+    set_iree_vulkan_runtime_flags,
+    get_vulkan_target_triple,
+)
 
 
 def _compile_module(shark_module, model_name, extra_args=[]):
@@ -86,3 +89,102 @@ def set_iree_runtime_flags():
     set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
 
     return
+
+
+def set_init_device_flags():
+    def get_all_devices(driver_name):
+        """
+        Inputs: driver_name
+        Returns a list of all the available devices for a given driver, sorted
+        by the iree path names of the devices as in the --list_devices option
+        in iree.
+        """
+        from iree.runtime import get_driver
+
+        driver = get_driver(driver_name)
+        device_list_src = driver.query_available_devices()
+        device_list_src.sort(key=lambda d: d["path"])
+        return device_list_src
+
+    def get_device_mapping(driver, key_combination=3):
+        """This method ensures consistent device ordering when choosing
+        specific devices for execution
+        Args:
+            driver (str): execution driver (vulkan, cuda, rocm, etc)
+            key_combination (int, optional): choice for mapping value for device name.
+                1 : path
+                2 : name
+                3 : (name, path)
+                Defaults to 3.
+        Returns:
+            dict: map to possible device names user can input mapped to desired combination of name/path.
+        """
+        from shark.iree_utils._common import iree_device_map
+
+        driver = iree_device_map(driver)
+        device_list = get_all_devices(driver)
+        device_map = dict()
+
+        def get_output_value(dev_dict):
+            if key_combination == 1:
+                return f"{driver}://{dev_dict['path']}"
+            if key_combination == 2:
+                return dev_dict["name"]
+            if key_combination == 3:
+                return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
+
+        # mapping driver name to default device (driver://0)
+        device_map[f"{driver}"] = get_output_value(device_list[0])
+        for i, device in enumerate(device_list):
+            # mapping with index
+            device_map[f"{driver}://{i}"] = get_output_value(device)
+            # mapping with full path
+            device_map[f"{driver}://{device['path']}"] = get_output_value(
+                device
+            )
+        return device_map
+
+    def map_device_to_name_path(device, key_combination=3):
+        """Gives the appropriate device data (supported name/path) for the
+        user-selected execution device
+        Args:
+            device (str): user-selected device
+            key_combination (int, optional): choice for mapping value for device name.
+                1 : path
+                2 : name
+                3 : (name, path)
+                Defaults to 3.
+        Raises:
+            ValueError: if the requested device is not a valid device.
+        Returns:
+            str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
+        """
+        driver = device.split("://")[0]
+        device_map = get_device_mapping(driver, key_combination)
+        try:
+            device_mapping = device_map[device]
+        except KeyError:
+            raise ValueError(f"Device '{device}' is not a valid device.")
+        return device_mapping
+
+    triple = None
+    if "vulkan" in args.device:
+        name, args.device = map_device_to_name_path(args.device)
+        triple = get_vulkan_target_triple(name)
+        print(f"Found device {name}. Using target triple {triple}")
+        # set triple flag to avoid multiple calls to get_vulkan_triple_flag
+        if args.iree_vulkan_target_triple == "" and triple is not None:
+            args.iree_vulkan_target_triple = triple
+
+    # use tuned models only in the case of rdna3 cards.
+    tuned_requested = args.use_tuned
+    if not args.iree_vulkan_target_triple:
+        if triple is not None and "rdna3" not in triple:
+            args.use_tuned = False
+    elif "rdna3" not in args.iree_vulkan_target_triple:
+        args.use_tuned = False
+
+    if args.use_tuned:
+        print("Using tuned models for rdna3 card")
+    elif tuned_requested:
+        print("Tuned models not currently supported for device")
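

Reviewer note: opt_params.py now resolves every artifact name through slash-separated keys in resources/model_db.json instead of nested if/else chains. Below is a minimal standalone sketch of that lookup, assuming the resources/ layout added in this patch; the helper name lookup_model is ours, for illustration only, and mirrors how get_unet()/get_vae()/get_clip() build their keys for get_params():

    import json
    import os


    def resource_path(relative_path):
        # Same resolution rule as resources.py, minus the PyInstaller branch.
        base_path = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(base_path, relative_path)


    with open(resource_path("resources/model_db.json"), encoding="utf-8") as f:
        models_db = json.load(f)


    def lookup_model(variant, version, submodel, precision, max_length,
                     use_tuned=False, use_base_vae=False):
        """Build the slash-separated key used by get_params() and resolve it."""
        key = f"{variant}/{version}/{submodel}/{precision}/length_{max_length}"
        key += "/tuned" if use_tuned else "/untuned"
        if use_base_vae:
            key += "/base"  # the "/base" suffix follows the tuned/untuned part
        try:
            return models_db[key]
        except KeyError:
            raise KeyError(f"{key} is not present in the models database")


    # Resolves to "unet2base_8dec_fp16" with the model_db.json added above.
    print(lookup_model("stablediffusion", "v2_1base", "unet", "fp16", 77))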
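
Reviewer note: set_init_device_flags() builds its device map from IREE's driver.query_available_devices(). The sketch below reproduces the key_combination=3 mapping logic against a hand-written device list so it runs without IREE installed; the device names and paths are made up:

    def build_device_map(driver, device_list):
        """key_combination=3 behaviour: map every alias to (name, path)."""
        device_map = {}

        def value(dev):
            return (dev["name"], f"{driver}://{dev['path']}")

        # bare driver name -> first device in sorted order
        device_map[driver] = value(device_list[0])
        for i, dev in enumerate(device_list):
            device_map[f"{driver}://{i}"] = value(dev)            # by index
            device_map[f"{driver}://{dev['path']}"] = value(dev)  # by path

        return device_map


    devices = [
        {"name": "AMD Radeon RX 7900 XTX", "path": "7f3a"},
        {"name": "NVIDIA GeForce RTX 3090", "path": "9c1b"},
    ]
    dmap = build_device_map("vulkan", devices)
    # "vulkan", "vulkan://0" and "vulkan://7f3a" all resolve to the first GPU:
    assert dmap["vulkan"] == dmap["vulkan://0"] == dmap["vulkan://7f3a"]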
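
Reviewer note: with --use_base_vae, BaseVaeModel stops at images clamped to [0, 1], so the conversion to uint8 pixels moves to the host afterwards, per the flag's help text. The patch does not show that host-side step; the following is a hedged sketch of the standard Stable Diffusion postprocessing it implies (our assumption, not code from this patch):

    import numpy as np


    def to_uint8_images(base_vae_output: np.ndarray) -> np.ndarray:
        """NCHW floats in [0, 1] -> NHWC uint8 pixels (assumed host-side step)."""
        images = (base_vae_output * 255.0).round().astype("uint8")
        return images.transpose(0, 2, 3, 1)  # NCHW -> NHWC for PIL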