[SD] Update SD CLI to use model_db.json

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>

Author: Gaurav Shukla
Date: 2022-12-21 22:44:05 +05:30
parent 1595254eab
commit dfd6ba67b3
12 changed files with 450 additions and 247 deletions
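
The gist of the change: instead of hard-coding artifact names in per-model if/else ladders, the CLI now composes a key from variant, version, submodel, precision, CLIP max length, and tuning status, and looks that key up in resources/model_db.json. A minimal sketch of the scheme, using a real entry from this commit; the helper name build_model_key is hypothetical, for illustration only:

import json

def build_model_key(
    variant, version, submodel, precision, max_length, use_tuned, use_base_vae=False
):
    # Mirrors the key construction in get_unet/get_vae/get_clip below.
    tuned = "tuned" if use_tuned else "untuned"
    key = f"{variant}/{version}/{submodel}/{precision}/length_{max_length}/{tuned}"
    if use_base_vae:
        key += "/base"  # base-VAE artifacts carry an extra suffix
    return key

with open("resources/model_db.json", encoding="utf-8") as fopen:
    models_db = json.load(fopen)

key = build_model_key("stablediffusion", "v2_1base", "unet", "fp16", 64, use_tuned=True)
print(models_db[key])  # -> "unet_19dec_v2p1base_fp16_64_tuned"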

View File

@@ -36,7 +36,9 @@
" from torchdynamo.optimizations.backends import create_backend\n",
" from torchdynamo.optimizations.subgraph import SubGraph\n",
"except ModuleNotFoundError:\n",
" print(\"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\")\n",
" print(\n",
" \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
" )\n",
" exit()\n",
"\n",
"# torch-mlir imports for compiling\n",
@@ -97,7 +99,9 @@
"\n",
" for node in fx_g.graph.nodes:\n",
" if node.op == \"output\":\n",
" assert len(node.args) == 1, \"Output node must have a single argument\"\n",
" assert (\n",
" len(node.args) == 1\n",
" ), \"Output node must have a single argument\"\n",
" node_arg = node.args[0]\n",
" if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
" node.args = (node_arg[0],)\n",
@@ -116,8 +120,12 @@
" if len(args) == 1 and isinstance(args[0], list):\n",
" args = args[0]\n",
"\n",
" linalg_module = compile(ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS)\n",
" callable, _ = get_iree_compiled_module(linalg_module, \"cuda\", func_name=\"forward\")\n",
" linalg_module = compile(\n",
" ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
" )\n",
" callable, _ = get_iree_compiled_module(\n",
" linalg_module, \"cuda\", func_name=\"forward\"\n",
" )\n",
"\n",
" def forward(*inputs):\n",
" return callable(*inputs)\n",
@@ -212,6 +220,7 @@
" assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
" return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
"\n",
"\n",
"@torchdynamo.optimize(\"torch_mlir\")\n",
"def toy_example2(*args):\n",
" a, b = args\n",

View File

@@ -51,7 +51,7 @@ if __name__ == "__main__":
neg_prompt = args.negative_prompts
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
if args.version == "v2.1" and args.variant == "stablediffusion":
if args.version == "v2_1":
height = 768
width = 768
@@ -91,7 +91,7 @@ if __name__ == "__main__":
subfolder="scheduler",
)
cpu_scheduling = True
if args.version == "v2.1":
if args.version == "v2_1":
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
)
@@ -101,7 +101,7 @@ if __name__ == "__main__":
subfolder="scheduler",
)
if args.version == "v2.1base":
if args.version == "v2_1base":
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
)

View File

@@ -5,9 +5,9 @@ from stable_args import args
import torch
model_config = {
"v2.1": "stabilityai/stable-diffusion-2-1",
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
"v1.4": "CompVis/stable-diffusion-v1-4",
"v2_1": "stabilityai/stable-diffusion-2-1",
"v2_1base": "stabilityai/stable-diffusion-2-1-base",
"v1_4": "CompVis/stable-diffusion-v1-4",
}
# CLIP has 2 variants, with max length 77 or 64.
@@ -24,7 +24,7 @@ model_variant = {
}
model_input = {
"v2.1": {
"v2_1": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 96, 96),),
"unet": (
@@ -34,7 +34,7 @@ model_input = {
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v2.1base": {
"v2_1base": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
@@ -44,7 +44,7 @@ model_input = {
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v1.4": {
"v1_4": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
@@ -70,7 +70,7 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
"openai/clip-vit-large-patch14"
)
if args.variant == "stablediffusion":
if args.version != "v1.4":
if args.version != "v1_4":
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
)
@@ -102,6 +102,54 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
return shark_clip
def get_base_vae_mlir(model_name="vae", extra_args=[]):
class BaseVaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
revision=model_revision[args.variant],
)
def forward(self, input):
x = self.vae.decode(input, return_dict=False)[0]
return (x / 2 + 0.5).clamp(0, 1)
vae = BaseVaeModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in ["anythingv3", "analogdiffusion"]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
raise (f"{args.variant} not yet added")
shark_vae = compile_through_fx(
vae,
inputs,
model_name=model_name,
extra_args=extra_args,
)
return shark_vae
def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def __init__(self):
@@ -137,10 +185,10 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1.4"]["vae"]]
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1.4"]["vae"]
inputs = model_input["v1_4"]["vae"]
else:
raise (f"{args.variant} not yet added")
@@ -197,11 +245,11 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input["v1.4"]["unet"]
for inputs in model_input["v1_4"]["unet"]
]
)
else:
inputs = model_input["v1.4"]["unet"]
inputs = model_input["v1_4"]["unet"]
else:
raise (f"{args.variant} is not yet added")
shark_unet = compile_through_fx(

View File

@@ -1,9 +1,11 @@
import sys
from model_wrappers import (
get_base_vae_mlir,
get_vae_mlir,
get_unet_mlir,
get_clip_mlir,
)
from resources import models_db
from stable_args import args
from utils import get_shark_model
@@ -11,222 +13,110 @@ BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
sys.exit("Only batch size 1 is supported.")
def get_unet():
def get_params(model_key):
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with MoltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.variant == "stablediffusion":
# Tuned model is present for `fp16` precision.
if args.precision == "fp16":
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
if args.version == "v1.4":
model_name = "unet_1dec_fp16_tuned"
if args.version == "v2.1base":
if args.max_length == 64:
model_name = "unet_19dec_v2p1base_fp16_64_tuned"
else:
model_name = "unet2base_8dec_fp16_tuned_v2"
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_8dec_fp16"
if args.version == "v2.1base":
if args.max_length == 64:
model_name = "unet_19dec_v2p1base_fp16_64"
else:
model_name = "unet2base_8dec_fp16"
if args.version == "v2.1":
model_name = "unet2_14dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
try:
model_name = models_db[model_key]
except KeyError:
raise Exception(f"{model_key} is not present in the models database")
# Tuned model is not present for `fp32` case.
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_1dec_fp32"
return model_name, iree_flags
def get_unet():
# Tuned model is present only for `fp16` precision.
is_tuned = "/tuned" if args.use_tuned else "/untuned"
variant_version = args.variant
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}{is_tuned}"
model_name, iree_flags = get_params(model_key)
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
if args.variant == "anythingv3":
bucket = "gs://shark_tank/sd_anythingv3"
elif args.variant == "analogdiffusion":
bucket = "gs://shark_tank/sd_analog_diffusion"
if args.precision == "fp16":
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
elif args.precision == "fp32":
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "int8":
bucket = "gs://shark_tank/prashant_nod"
model_name = "unet_int8"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
]
sys.exit("int8 model is currently in maintenance.")
# # TODO: Pass iree_flags to the exported model.
# if args.import_mlir:
# sys.exit(
# "--import_mlir is not supported for the int8 model, try --no-import_mlir flag."
# )
# return get_shark_model(bucket, model_name, iree_flags)
else:
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.variant == "anythingv3":
bucket = "gs://shark_tank/sd_anythingv3"
model_name = f"av3_unet_19dec_{args.precision}"
elif args.variant == "analogdiffusion":
bucket = "gs://shark_tank/sd_analog_diffusion"
model_name = f"ad_unet_19dec_{args.precision}"
else:
sys.exit(f"{args.variant} variant of SD is currently unsupported")
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with MoltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.variant == "stablediffusion":
if args.precision in ["fp16", "int8"]:
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
if args.version == "v2.1base":
model_name = "vae2base_19dec_fp16_tuned"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
"--iree-flow-enable-conv-winograd-transform",
]
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_19dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_19dec_fp16"
if args.version == "v2.1":
model_name = "vae2_19dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_1dec_fp32"
# Tuned model is present only for `fp16` precision.
is_tuned = "/tuned" if args.use_tuned else "/untuned"
is_base = "/base" if args.use_base_vae else ""
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77{is_tuned}{is_base}"
model_name, iree_flags = get_params(model_key)
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
"--iree-flow-enable-conv-winograd-transform",
]
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
if args.variant == "anythingv3":
bucket = "gs://shark_tank/sd_anythingv3"
elif args.variant == "analogdiffusion":
bucket = "gs://shark_tank/sd_analog_diffusion"
if args.precision == "fp16":
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
elif args.precision == "fp32":
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
else:
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
]
if args.precision == "fp16":
iree_flags += [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-conv-img2col-transform",
]
elif args.precision == "fp32":
iree_flags += [
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
]
else:
sys.exit("int8 precision is currently in not supported.")
if args.variant == "anythingv3":
bucket = "gs://shark_tank/sd_anythingv3"
model_name = f"av3_vae_19dec_{args.precision}"
elif args.variant == "analogdiffusion":
bucket = "gs://shark_tank/sd_analog_diffusion"
model_name = f"ad_vae_19dec_{args.precision}"
else:
sys.exit(f"{args.variant} variant of SD is currently unsupported")
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
if args.use_base_vae:
return get_base_vae_mlir(model_name, iree_flags)
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with MoltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.variant == "stablediffusion":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "clip_18dec_fp32"
if args.version == "v2.1base":
if args.max_length == 64:
model_name = "clip_19dec_v2p1base_fp32_64"
else:
model_name = "clip2base_18dec_fp32"
if args.version == "v2.1":
model_name = "clip2_18dec_fp32"
iree_flags += [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops",
]
if args.import_mlir:
return get_clip_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
model_key = f"{args.variant}/{args.version}/clip/fp32/length_{args.max_length}/untuned"
model_name, iree_flags = get_params(model_key)
bucket = "gs://shark_tank/stable_diffusion"
if args.variant == "anythingv3":
bucket = "gs://shark_tank/sd_anythingv3"
model_name = "av3_clip_19dec_fp32"
elif args.variant == "analogdiffusion":
bucket = "gs://shark_tank/sd_analog_diffusion"
model_name = "ad_clip_19dec_fp32"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-conv-img2col-transform",
]
else:
sys.exit(f"{args.variant} variant of SD is currently unsupported")
iree_flags += [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops",
]
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_clip_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)

View File

@@ -0,0 +1,25 @@
import os
import json
import sys
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
prompt_examples = []
prompts_loc = resource_path("resources/prompts.json")
if os.path.exists(prompts_loc):
with open(prompts_loc, encoding="utf-8") as fopen:
prompt_examples = json.load(fopen)
models_db = dict()
models_loc = resource_path("resources/model_db.json")
if os.path.exists(models_loc):
with open(models_loc, encoding="utf-8") as fopen:
models_db = json.load(fopen)
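
A hedged consumption sketch for the loader above, mirroring the guarded lookup that get_params() performs in the CLI; the key is a real entry from model_db.json in this commit:

from resources import models_db

model_key = "stablediffusion/v1_4/clip/fp32/length_77/untuned"
try:
    model_name = models_db[model_key]  # -> "clip_18dec_fp32"
except KeyError:
    raise Exception(f"{model_key} is not present in the models database")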

View File

@@ -0,0 +1,33 @@
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_1dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae2_19dec_fp16",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32",
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32"
}

View File

@@ -0,0 +1,8 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]

View File

@@ -61,7 +61,7 @@ p.add_argument(
p.add_argument(
"--version",
type=str,
default="v2.1base",
default="v2_1base",
help="Specify version of stable diffusion model",
)
@@ -97,11 +97,19 @@ p.add_argument(
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--variant",
default="stablediffusion",
help="We now support multiple vairants of SD finetuned for different dataset. you can use the following anythingv3, ...", # TODO add more once supported
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################

View File

@@ -158,7 +158,7 @@ def stable_diff_inf(
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
text_output += f"\nAverage step time: {avg_ms:.2f}ms/it"
text_output += f"\nTotal image generation time: {total_time:.2f}sec"
text_output += f"\nAverage step time: {avg_ms:.4f}ms/it"
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
return pil_images[0], text_output

View File

@@ -12,6 +12,16 @@ model_config = {
# CLIP has 2 variants, with max length 77 or 64.
model_clip_max_length = 64 if args.max_length == 64 else 77
if args.variant != "stablediffusion":
model_clip_max_length = 77
model_variant = {
"stablediffusion": "SD",
"anythingv3": "Linaqruf/anything-v3.0",
"dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
"openjourney": "prompthero/openjourney",
"analogdiffusion": "wavymulder/Analog-Diffusion",
}
model_input = {
"v2_1": {
@@ -47,7 +57,11 @@ model_input = {
}
# revision param for from_pretrained defaults to "main" => fp32
model_revision = "fp16" if args.precision == "fp16" else "main"
model_revision = {
"stablediffusion": "fp16" if args.precision == "fp16" else "main",
"anythingv3": "diffusers",
"analogdiffusion": "main",
}
def get_clip_mlir(model_name="clip_text", extra_args=[]):
@@ -55,10 +69,20 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
if args.version != "v1_4":
if args.variant == "stablediffusion":
if args.version != "v1_4":
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
)
elif args.variant in ["anythingv3", "analogdiffusion"]:
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
model_variant[args.variant],
subfolder="text_encoder",
revision=model_revision[args.variant],
)
else:
raise (f"{args.variant} not yet added")
class CLIPText(torch.nn.Module):
def __init__(self):
@@ -83,9 +107,11 @@ def get_base_vae_mlir(model_name="vae", extra_args=[]):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version],
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
revision=model_revision,
revision=model_revision[args.variant],
)
def forward(self, input):
@@ -93,16 +119,27 @@ def get_base_vae_mlir(model_name="vae", extra_args=[]):
return (x / 2 + 0.5).clamp(0, 1)
vae = BaseVaeModel()
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in ["anythingv3", "analogdiffusion"]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
inputs = model_input[args.version]["vae"]
raise (f"{args.variant} not yet added")
shark_vae = compile_through_fx(
vae,
@@ -118,9 +155,11 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version],
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
revision=model_revision,
revision=model_revision[args.variant],
)
def forward(self, input):
@@ -131,16 +170,27 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
return x.round()
vae = VaeModel()
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in ["anythingv3", "analogdiffusion"]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
inputs = model_input[args.version]["vae"]
raise (f"{args.variant} not yet added")
shark_vae = compile_through_fx(
vae,
@@ -156,9 +206,11 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_config[args.version],
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="unet",
revision=model_revision,
revision=model_revision[args.variant],
)
self.in_channels = self.unet.in_channels
self.train(False)
@@ -176,16 +228,30 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
return noise_pred
unet = UnetModel()
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
if args.variant == "stablediffusion":
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
else:
inputs = model_input[args.version]["unet"]
elif args.variant in ["anythingv3", "analogdiffusion"]:
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input["v1_4"]["unet"]
]
)
else:
inputs = model_input["v1_4"]["unet"]
else:
inputs = model_input[args.version]["unet"]
raise (f"{args.variant} is not yet added")
shark_unet = compile_through_fx(
unet,
inputs,
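
As a reference for the branches above, a small sketch of how a wrapper resolves the Hugging Face repo and revision for a given --variant; every value comes from the model_config, model_variant, and model_revision dicts in this diff:

variant = "anythingv3"
repo = model_variant[variant]       # "Linaqruf/anything-v3.0"
revision = model_revision[variant]  # "diffusers"
# For variant == "stablediffusion" the repo comes from model_config keyed
# by version instead, e.g. model_config["v2_1base"] ==
# "stabilityai/stable-diffusion-2-1-base", with revision "fp16" or "main"
# depending on --precision.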

View File

@@ -97,12 +97,6 @@ p.add_argument(
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--variant",
default="stablediffusion",
help="We now support multiple variants of SD finetuned for different dataset. you can use the following anythingv3, ...", # TODO add more once supported
)
p.add_argument(
"--use_base_vae",
default=False,
@@ -110,6 +104,12 @@ p.add_argument(
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--variant",
default="stablediffusion",
help="We now support multiple vairants of SD finetuned for different dataset. you can use the following anythingv3, ...", # TODO add more once supported
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
@@ -145,6 +145,13 @@ p.add_argument(
### Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
@@ -179,11 +186,18 @@ p.add_argument(
p.add_argument(
"--hide_steps",
default=True,
default=False,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
)
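
Putting the flags added here together, a hypothetical invocation (the entry-point script name is an assumption; --precision and --max_length are defined elsewhere in stable_args.py):

python main.py --version v2_1base --precision fp16 --max_length 64 \
    --variant stablediffusion --use_tuned --use_base_vae \
    --no-use_compiled_scheduler --warmup_count 2 --no-hide_steps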
##############################################################################
### Web UI flags
##############################################################################

View File

@@ -4,7 +4,10 @@ import torch
from shark.shark_inference import SharkInference
from models.stable_diffusion.stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
def _compile_module(shark_module, model_name, extra_args=[]):
@@ -86,3 +89,102 @@ def set_iree_runtime_flags():
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
return
def set_init_device_flags():
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver, sorted by
the IREE path names of the devices, as shown by IREE's --list_devices option.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice of mapping value for the device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map from the possible device names a user can input to the desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(
device
)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user-selected execution device
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError: if the given device is not a valid device.
Returns:
str / tuple: the mapping str, or a tuple of mapping strs, for the device depending on the key_combination value.
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
if "vulkan" in args.device:
name, args.device = map_device_to_name_path(args.device)
triple = get_vulkan_target_triple(name)
print(f"Found device {name}. Using target triple {triple}")
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
if args.iree_vulkan_target_triple == "" and triple is not None:
args.iree_vulkan_target_triple = triple
# use tuned models only in the case of rdna3 cards.
if not args.iree_vulkan_target_triple:
if triple is not None and "rdna3" not in triple:
args.use_tuned = False
elif "rdna3" not in args.iree_vulkan_target_triple:
args.use_tuned = False
if args.use_tuned:
print("Using tuned models for rdna3 card")
else:
if args.use_tuned:
print("Tuned models not currently supported for device")
args.use_tuned = False
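
An illustrative trace of the vulkan path above; the device name, path, and triple are machine-dependent examples, not guaranteed values:

# Inside set_init_device_flags(), for args.device == "vulkan://0":
name, args.device = map_device_to_name_path("vulkan://0")
# name        -> e.g. "AMD Radeon RX 7900 XTX"
# args.device -> e.g. "vulkan://<device-path>"
triple = get_vulkan_target_triple(name)  # e.g. an "rdna3-..." triple
# Tuned models stay enabled only when the resolved triple contains
# "rdna3"; otherwise args.use_tuned is forced to False.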