[SD] Update SD CLI to use model_db.json
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
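For context: the heart of this change is replacing hand-written per-version if/else ladders in the SD CLI with a flat lookup table shipped as resources/model_db.json. A minimal sketch of the new flow, using one real key/value pair from the diff below; note that the commit's actual get_params also assembles IREE flags and the JSON load happens in resources.py, so this standalone version is illustrative only:

import json

# Keys follow {variant}/{version}/{submodel}/{precision}/length_{n}/{tuned|untuned}[/base]
with open("resources/model_db.json", encoding="utf-8") as fopen:
    models_db = json.load(fopen)

def get_params(model_key):
    try:
        # Map the descriptive key to a shark_tank artifact name.
        return models_db[model_key]
    except KeyError:
        raise Exception(f"{model_key} is not present in the models database")

# e.g. the tuned fp16 unet for SD 2.1-base with max_length 64:
print(get_params("stablediffusion/v2_1base/unet/fp16/length_64/tuned"))
# -> unet_19dec_v2p1base_fp16_64_tuned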
@@ -36,7 +36,9 @@
"    from torchdynamo.optimizations.backends import create_backend\n",
"    from torchdynamo.optimizations.subgraph import SubGraph\n",
"except ModuleNotFoundError:\n",
"    print(\"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\")\n",
"    print(\n",
"        \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
"    )\n",
"    exit()\n",
"\n",
"# torch-mlir imports for compiling\n",

@@ -97,7 +99,9 @@
"\n",
"    for node in fx_g.graph.nodes:\n",
"        if node.op == \"output\":\n",
"            assert len(node.args) == 1, \"Output node must have a single argument\"\n",
"            assert (\n",
"                len(node.args) == 1\n",
"            ), \"Output node must have a single argument\"\n",
"            node_arg = node.args[0]\n",
"            if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
"                node.args = (node_arg[0],)\n",

@@ -116,8 +120,12 @@
"    if len(args) == 1 and isinstance(args[0], list):\n",
"        args = args[0]\n",
"\n",
"    linalg_module = compile(ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS)\n",
"    callable, _ = get_iree_compiled_module(linalg_module, \"cuda\", func_name=\"forward\")\n",
"    linalg_module = compile(\n",
"        ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
"    )\n",
"    callable, _ = get_iree_compiled_module(\n",
"        linalg_module, \"cuda\", func_name=\"forward\"\n",
"    )\n",
"\n",
"    def forward(*inputs):\n",
"        return callable(*inputs)\n",

@@ -212,6 +220,7 @@
"    assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
"    return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
"\n",
"\n",
"@torchdynamo.optimize(\"torch_mlir\")\n",
"def toy_example2(*args):\n",
"    a, b = args\n",
@@ -51,7 +51,7 @@ if __name__ == "__main__":
    neg_prompt = args.negative_prompts
    height = 512  # default height of Stable Diffusion
    width = 512  # default width of Stable Diffusion
    if args.version == "v2.1" and args.variant == "stablediffusion":
    if args.version == "v2_1":
        height = 768
        width = 768

@@ -91,7 +91,7 @@ if __name__ == "__main__":
        subfolder="scheduler",
    )
    cpu_scheduling = True
    if args.version == "v2.1":
    if args.version == "v2_1":
        tokenizer = CLIPTokenizer.from_pretrained(
            "stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
        )

@@ -101,7 +101,7 @@ if __name__ == "__main__":
        subfolder="scheduler",
    )

    if args.version == "v2.1base":
    if args.version == "v2_1base":
        tokenizer = CLIPTokenizer.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
        )
@@ -5,9 +5,9 @@ from stable_args import args
import torch

model_config = {
    "v2.1": "stabilityai/stable-diffusion-2-1",
    "v2.1base": "stabilityai/stable-diffusion-2-1-base",
    "v1.4": "CompVis/stable-diffusion-v1-4",
    "v2_1": "stabilityai/stable-diffusion-2-1",
    "v2_1base": "stabilityai/stable-diffusion-2-1-base",
    "v1_4": "CompVis/stable-diffusion-v1-4",
}

# clip has 2 variants of max length 77 or 64.

@@ -24,7 +24,7 @@ model_variant = {
}

model_input = {
    "v2.1": {
    "v2_1": {
        "clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
        "vae": (torch.randn(1, 4, 96, 96),),
        "unet": (

@@ -34,7 +34,7 @@ model_input = {
            torch.tensor(1).to(torch.float32),  # guidance_scale
        ),
    },
    "v2.1base": {
    "v2_1base": {
        "clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
        "vae": (torch.randn(1, 4, 64, 64),),
        "unet": (

@@ -44,7 +44,7 @@ model_input = {
            torch.tensor(1).to(torch.float32),  # guidance_scale
        ),
    },
    "v1.4": {
    "v1_4": {
        "clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
        "vae": (torch.randn(1, 4, 64, 64),),
        "unet": (

@@ -70,7 +70,7 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
        "openai/clip-vit-large-patch14"
    )
    if args.variant == "stablediffusion":
        if args.version != "v1.4":
        if args.version != "v1_4":
            text_encoder = CLIPTextModel.from_pretrained(
                model_config[args.version], subfolder="text_encoder"
            )

@@ -102,6 +102,54 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
    return shark_clip


def get_base_vae_mlir(model_name="vae", extra_args=[]):
    class BaseVaeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.vae = AutoencoderKL.from_pretrained(
                model_config[args.version]
                if args.variant == "stablediffusion"
                else model_variant[args.variant],
                subfolder="vae",
                revision=model_revision[args.variant],
            )

        def forward(self, input):
            x = self.vae.decode(input, return_dict=False)[0]
            return (x / 2 + 0.5).clamp(0, 1)

    vae = BaseVaeModel()
    if args.variant == "stablediffusion":
        if args.precision == "fp16":
            vae = vae.half().cuda()
            inputs = tuple(
                [
                    inputs.half().cuda()
                    for inputs in model_input[args.version]["vae"]
                ]
            )
        else:
            inputs = model_input[args.version]["vae"]
    elif args.variant in ["anythingv3", "analogdiffusion"]:
        if args.precision == "fp16":
            vae = vae.half().cuda()
            inputs = tuple(
                [inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
            )
        else:
            inputs = model_input["v1_4"]["vae"]
    else:
        raise (f"{args.variant} not yet added")

    shark_vae = compile_through_fx(
        vae,
        inputs,
        model_name=model_name,
        extra_args=extra_args,
    )
    return shark_vae


def get_vae_mlir(model_name="vae", extra_args=[]):
    class VaeModel(torch.nn.Module):
        def __init__(self):

@@ -137,10 +185,10 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
        if args.precision == "fp16":
            vae = vae.half().cuda()
            inputs = tuple(
                [inputs.half().cuda() for inputs in model_input["v1.4"]["vae"]]
                [inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
            )
        else:
            inputs = model_input["v1.4"]["vae"]
            inputs = model_input["v1_4"]["vae"]
    else:
        raise (f"{args.variant} not yet added")

@@ -197,11 +245,11 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
            inputs = tuple(
                [
                    inputs.half().cuda() if len(inputs.shape) != 0 else inputs
                    for inputs in model_input["v1.4"]["unet"]
                    for inputs in model_input["v1_4"]["unet"]
                ]
            )
        else:
            inputs = model_input["v1.4"]["unet"]
            inputs = model_input["v1_4"]["unet"]
    else:
        raise (f"{args.variant} is not yet added")
    shark_unet = compile_through_fx(
@@ -1,9 +1,11 @@
import sys
from model_wrappers import (
    get_base_vae_mlir,
    get_vae_mlir,
    get_unet_mlir,
    get_clip_mlir,
)
from resources import models_db
from stable_args import args
from utils import get_shark_model

@@ -11,222 +13,110 @@ BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
    sys.exit("Only batch size 1 is supported.")

def get_unet():

def get_params(model_key):
    iree_flags = []
    if len(args.iree_vulkan_target_triple) > 0:
        iree_flags.append(
            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
        )

    # Disable bindings fusion to work with moltenVK.
    if sys.platform == "darwin":
        iree_flags.append("-iree-stream-fuse-binding=false")

    if args.variant == "stablediffusion":
        # Tuned model is present for `fp16` precision.
        if args.precision == "fp16":
            if args.use_tuned:
                bucket = "gs://shark_tank/vivian"
                if args.version == "v1.4":
                    model_name = "unet_1dec_fp16_tuned"
                if args.version == "v2.1base":
                    if args.max_length == 64:
                        model_name = "unet_19dec_v2p1base_fp16_64_tuned"
                    else:
                        model_name = "unet2base_8dec_fp16_tuned_v2"
                return get_shark_model(bucket, model_name, iree_flags)
            else:
                bucket = "gs://shark_tank/stable_diffusion"
                model_name = "unet_8dec_fp16"
                if args.version == "v2.1base":
                    if args.max_length == 64:
                        model_name = "unet_19dec_v2p1base_fp16_64"
                    else:
                        model_name = "unet2base_8dec_fp16"
                if args.version == "v2.1":
                    model_name = "unet2_14dec_fp16"
                iree_flags += [
                    "--iree-flow-enable-padding-linalg-ops",
                    "--iree-flow-linalg-ops-padding-size=32",
                    "--iree-flow-enable-conv-img2col-transform",
                ]
                if args.import_mlir:
                    return get_unet_mlir(model_name, iree_flags)
                return get_shark_model(bucket, model_name, iree_flags)
    try:
        model_name = models_db[model_key]
    except KeyError:
        raise Exception(f"{model_key} is not present in the models database")

        # Tuned model is not present for `fp32` case.
        if args.precision == "fp32":
            bucket = "gs://shark_tank/stable_diffusion"
            model_name = "unet_1dec_fp32"
    return model_name, iree_flags


def get_unet():
    # Tuned model is present only for `fp16` precision.
    is_tuned = "/tuned" if args.use_tuned else "/untuned"
    variant_version = args.variant
    model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}{is_tuned}"
    model_name, iree_flags = get_params(model_key)
    if args.use_tuned:
        bucket = "gs://shark_tank/vivian"
        return get_shark_model(bucket, model_name, iree_flags)
    else:
        bucket = "gs://shark_tank/stable_diffusion"
        if args.variant == "anythingv3":
            bucket = "gs://shark_tank/sd_anythingv3"
        elif args.variant == "analogdiffusion":
            bucket = "gs://shark_tank/sd_analog_diffusion"
        if args.precision == "fp16":
            iree_flags += [
                "--iree-flow-enable-padding-linalg-ops",
                "--iree-flow-linalg-ops-padding-size=32",
                "--iree-flow-enable-conv-img2col-transform",
            ]
        elif args.precision == "fp32":
            iree_flags += [
                "--iree-flow-enable-conv-nchw-to-nhwc-transform",
                "--iree-flow-enable-padding-linalg-ops",
                "--iree-flow-linalg-ops-padding-size=16",
            ]
        if args.import_mlir:
            return get_unet_mlir(model_name, iree_flags)
        return get_shark_model(bucket, model_name, iree_flags)

        if args.precision == "int8":
            bucket = "gs://shark_tank/prashant_nod"
            model_name = "unet_int8"
            iree_flags += [
                "--iree-flow-enable-padding-linalg-ops",
                "--iree-flow-linalg-ops-padding-size=32",
            ]
            sys.exit("int8 model is currently in maintenance.")
            # # TODO: Pass iree_flags to the exported model.
            # if args.import_mlir:
            #     sys.exit(
            #         "--import_mlir is not supported for the int8 model, try --no-import_mlir flag."
            #     )
            # return get_shark_model(bucket, model_name, iree_flags)

    else:
        iree_flags += [
            "--iree-flow-enable-padding-linalg-ops",
            "--iree-flow-linalg-ops-padding-size=32",
            "--iree-flow-enable-conv-img2col-transform",
        ]
        if args.variant == "anythingv3":
            bucket = "gs://shark_tank/sd_anythingv3"
            model_name = f"av3_unet_19dec_{args.precision}"
        elif args.variant == "analogdiffusion":
            bucket = "gs://shark_tank/sd_analog_diffusion"
            model_name = f"ad_unet_19dec_{args.precision}"
        else:
            sys.exit(f"{args.variant} variant of SD is currently unsupported")

    if args.import_mlir:
        return get_unet_mlir(model_name, iree_flags)
    return get_shark_model(bucket, model_name, iree_flags)


def get_vae():
    iree_flags = []
    if len(args.iree_vulkan_target_triple) > 0:
        iree_flags.append(
            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
        )
    # Disable bindings fusion to work with moltenVK.
    if sys.platform == "darwin":
        iree_flags.append("-iree-stream-fuse-binding=false")

    if args.variant == "stablediffusion":
        if args.precision in ["fp16", "int8"]:
            if args.use_tuned:
                bucket = "gs://shark_tank/vivian"
                if args.version == "v2.1base":
                    model_name = "vae2base_19dec_fp16_tuned"
                iree_flags += [
                    "--iree-flow-enable-padding-linalg-ops",
                    "--iree-flow-linalg-ops-padding-size=32",
                    "--iree-flow-enable-conv-img2col-transform",
                    "--iree-flow-enable-conv-winograd-transform",
                ]
                return get_shark_model(bucket, model_name, iree_flags)
            else:
                bucket = "gs://shark_tank/stable_diffusion"
                model_name = "vae_19dec_fp16"
                if args.version == "v2.1base":
                    model_name = "vae2base_19dec_fp16"
                if args.version == "v2.1":
                    model_name = "vae2_19dec_fp16"
                iree_flags += [
                    "--iree-flow-enable-padding-linalg-ops",
                    "--iree-flow-linalg-ops-padding-size=32",
                    "--iree-flow-enable-conv-img2col-transform",
                ]
                if args.import_mlir:
                    return get_vae_mlir(model_name, iree_flags)
                return get_shark_model(bucket, model_name, iree_flags)

        if args.precision == "fp32":
            bucket = "gs://shark_tank/stable_diffusion"
            model_name = "vae_1dec_fp32"
    # Tuned model is present only for `fp16` precision.
    is_tuned = "/tuned" if args.use_tuned else "/untuned"
    is_base = "/base" if args.use_base_vae else ""
    model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77{is_tuned}{is_base}"
    model_name, iree_flags = get_params(model_key)
    if args.use_tuned:
        bucket = "gs://shark_tank/vivian"
        iree_flags += [
            "--iree-flow-enable-padding-linalg-ops",
            "--iree-flow-linalg-ops-padding-size=32",
            "--iree-flow-enable-conv-img2col-transform",
            "--iree-flow-enable-conv-winograd-transform",
        ]
        return get_shark_model(bucket, model_name, iree_flags)
    else:
        bucket = "gs://shark_tank/stable_diffusion"
        if args.variant == "anythingv3":
            bucket = "gs://shark_tank/sd_anythingv3"
        elif args.variant == "analogdiffusion":
            bucket = "gs://shark_tank/sd_analog_diffusion"
        if args.precision == "fp16":
            iree_flags += [
                "--iree-flow-enable-padding-linalg-ops",
                "--iree-flow-linalg-ops-padding-size=32",
                "--iree-flow-enable-conv-img2col-transform",
            ]
        elif args.precision == "fp32":
            iree_flags += [
                "--iree-flow-enable-conv-nchw-to-nhwc-transform",
                "--iree-flow-enable-padding-linalg-ops",
                "--iree-flow-linalg-ops-padding-size=16",
            ]
        if args.import_mlir:
            return get_vae_mlir(model_name, iree_flags)
        return get_shark_model(bucket, model_name, iree_flags)

    else:
        iree_flags += [
            "--iree-flow-enable-padding-linalg-ops",
        ]
        if args.precision == "fp16":
            iree_flags += [
                "--iree-flow-linalg-ops-padding-size=16",
                "--iree-flow-enable-conv-img2col-transform",
            ]
        elif args.precision == "fp32":
            iree_flags += [
                "--iree-flow-linalg-ops-padding-size=32",
                "--iree-flow-enable-conv-nchw-to-nhwc-transform",
            ]
        else:
            sys.exit("int8 precision is currently in not supported.")

        if args.variant == "anythingv3":
            bucket = "gs://shark_tank/sd_anythingv3"
            model_name = f"av3_vae_19dec_{args.precision}"

        elif args.variant == "analogdiffusion":
            bucket = "gs://shark_tank/sd_analog_diffusion"
            model_name = f"ad_vae_19dec_{args.precision}"

        else:
            sys.exit(f"{args.variant} variant of SD is currently unsupported")

    if args.import_mlir:
        return get_unet_mlir(model_name, iree_flags)
        if args.use_base_vae:
            return get_base_vae_mlir(model_name, iree_flags)
        return get_vae_mlir(model_name, iree_flags)
    return get_shark_model(bucket, model_name, iree_flags)


def get_clip():
    iree_flags = []
    if len(args.iree_vulkan_target_triple) > 0:
        iree_flags.append(
            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
        )
    # Disable bindings fusion to work with moltenVK.
    if sys.platform == "darwin":
        iree_flags.append("-iree-stream-fuse-binding=false")

    if args.variant == "stablediffusion":
        bucket = "gs://shark_tank/stable_diffusion"
        model_name = "clip_18dec_fp32"
        if args.version == "v2.1base":
            if args.max_length == 64:
                model_name = "clip_19dec_v2p1base_fp32_64"
            else:
                model_name = "clip2base_18dec_fp32"
        if args.version == "v2.1":
            model_name = "clip2_18dec_fp32"
        iree_flags += [
            "--iree-flow-linalg-ops-padding-size=16",
            "--iree-flow-enable-padding-linalg-ops",
        ]
        if args.import_mlir:
            return get_clip_mlir(model_name, iree_flags)
        return get_shark_model(bucket, model_name, iree_flags)

    model_key = f"{args.variant}/{args.version}/clip/fp32/length_{args.max_length}/untuned"
    model_name, iree_flags = get_params(model_key)
    bucket = "gs://shark_tank/stable_diffusion"
    if args.variant == "anythingv3":
        bucket = "gs://shark_tank/sd_anythingv3"
        model_name = "av3_clip_19dec_fp32"
    elif args.variant == "analogdiffusion":
        bucket = "gs://shark_tank/sd_analog_diffusion"
        model_name = "ad_clip_19dec_fp32"
        iree_flags += [
            "--iree-flow-enable-padding-linalg-ops",
            "--iree-flow-linalg-ops-padding-size=16",
            "--iree-flow-enable-conv-img2col-transform",
        ]
    else:
        sys.exit(f"{args.variant} variant of SD is currently unsupported")

    iree_flags += [
        "--iree-flow-linalg-ops-padding-size=16",
        "--iree-flow-enable-padding-linalg-ops",
    ]
    if args.import_mlir:
        return get_unet_mlir(model_name, iree_flags)
        return get_clip_mlir(model_name, iree_flags)
    return get_shark_model(bucket, model_name, iree_flags)
shark/examples/shark_inference/stable_diffusion/resources.py (new file, 25 additions)
@@ -0,0 +1,25 @@
import os
import json
import sys


def resource_path(relative_path):
    """Get absolute path to resource, works for dev and for PyInstaller"""
    base_path = getattr(
        sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
    )
    return os.path.join(base_path, relative_path)


prompt_examples = []
prompts_loc = resource_path("resources/prompts.json")
if os.path.exists(prompts_loc):
    with open(prompts_loc, encoding="utf-8") as fopen:
        prompt_examples = json.load(fopen)


models_db = dict()
models_loc = resource_path("resources/model_db.json")
if os.path.exists(models_loc):
    with open(models_loc, encoding="utf-8") as fopen:
        models_db = json.load(fopen)
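A note on resource_path above: PyInstaller one-file builds unpack bundled data files into a temporary directory and expose it as sys._MEIPASS, so the getattr falls back to the module's own directory in a plain dev checkout. A small self-contained illustration of the same pattern; the bundled/dev split described in the comments is the assumption being illustrated, not extra behavior from the diff:

import os
import sys

def resource_path(relative_path):
    # Frozen app (PyInstaller): sys._MEIPASS points at the unpack dir.
    # Dev checkout: fall back to the directory containing this file.
    base_path = getattr(
        sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
    )
    return os.path.join(base_path, relative_path)

# In a dev checkout this resolves relative to the source tree:
print(resource_path("resources/model_db.json"))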
@@ -0,0 +1,33 @@
{
    "stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
    "stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_1dec_fp16_tuned",
    "stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
    "stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
    "stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
    "stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
    "stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
    "stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16",
    "stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
    "stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
    "stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
    "stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
    "stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
    "stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
    "stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
    "stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32",
    "stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64",
    "stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16",
    "stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae2_19dec_fp16",
    "stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
    "stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32",
    "anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
    "anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
    "anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
    "anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
    "anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
    "analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
    "analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
    "analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
    "analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
    "analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32"
}
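The keys above all follow one grammar that the CLI rebuilds from its flags: {variant}/{version}/{submodel}/{precision}/length_{max_length}/{tuned|untuned}, with a trailing /base for the base-VAE entries. A worked example matching how get_vae() composes its key in the main.py hunk above; the flag values here are illustrative:

# Reassemble a vae key the way get_vae() does; flag values are examples.
variant, version, precision = "stablediffusion", "v2_1base", "fp16"
use_tuned, use_base_vae = True, True

is_tuned = "/tuned" if use_tuned else "/untuned"
is_base = "/base" if use_base_vae else ""
model_key = f"{variant}/{version}/vae/{precision}/length_77{is_tuned}{is_base}"

assert model_key == "stablediffusion/v2_1base/vae/fp16/length_77/tuned/base"
# model_db.json maps this key to "vae2base_8dec_fp16_tuned".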
@@ -0,0 +1,8 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]
@@ -61,7 +61,7 @@ p.add_argument(
p.add_argument(
    "--version",
    type=str,
    default="v2.1base",
    default="v2_1base",
    help="Specify version of stable diffusion model",
)

@@ -97,11 +97,19 @@ p.add_argument(
    help="Download and use the tuned version of the model if available",
)

p.add_argument(
    "--use_base_vae",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Do conversion from the VAE output to pixel space on cpu.",
)

p.add_argument(
    "--variant",
    default="stablediffusion",
    help="We now support multiple vairants of SD finetuned for different dataset. you can use the following anythingv3, ...",  # TODO add more once supported
)

##############################################################################
### IREE - Vulkan supported flags
##############################################################################
@@ -158,7 +158,7 @@ def stable_diff_inf(
    text_output = f"prompt={args.prompts}"
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
    text_output += f"\nAverage step time: {avg_ms:.2f}ms/it"
    text_output += f"\nTotal image generation time: {total_time:.2f}sec"
    text_output += f"\nAverage step time: {avg_ms:.4f}ms/it"
    text_output += f"\nTotal image generation time: {total_time:.4f}sec"

    return pil_images[0], text_output
@@ -12,6 +12,16 @@ model_config = {

# clip has 2 variants of max length 77 or 64.
model_clip_max_length = 64 if args.max_length == 64 else 77
if args.variant != "stablediffusion":
    model_clip_max_length = 77

model_variant = {
    "stablediffusion": "SD",
    "anythingv3": "Linaqruf/anything-v3.0",
    "dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
    "openjourney": "prompthero/openjourney",
    "analogdiffusion": "wavymulder/Analog-Diffusion",
}

model_input = {
    "v2_1": {

@@ -47,7 +57,11 @@ model_input = {
}

# revision param for from_pretrained defaults to "main" => fp32
model_revision = "fp16" if args.precision == "fp16" else "main"
model_revision = {
    "stablediffusion": "fp16" if args.precision == "fp16" else "main",
    "anythingv3": "diffusers",
    "analogdiffusion": "main",
}


def get_clip_mlir(model_name="clip_text", extra_args=[]):

@@ -55,10 +69,20 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
    text_encoder = CLIPTextModel.from_pretrained(
        "openai/clip-vit-large-patch14"
    )
    if args.version != "v1_4":
    if args.variant == "stablediffusion":
        if args.version != "v1_4":
            text_encoder = CLIPTextModel.from_pretrained(
                model_config[args.version], subfolder="text_encoder"
            )

    elif args.variant in ["anythingv3", "analogdiffusion"]:
        text_encoder = CLIPTextModel.from_pretrained(
            model_config[args.version], subfolder="text_encoder"
            model_variant[args.variant],
            subfolder="text_encoder",
            revision=model_revision[args.variant],
        )
    else:
        raise (f"{args.variant} not yet added")

    class CLIPText(torch.nn.Module):
        def __init__(self):

@@ -83,9 +107,11 @@ def get_base_vae_mlir(model_name="vae", extra_args=[]):
        def __init__(self):
            super().__init__()
            self.vae = AutoencoderKL.from_pretrained(
                model_config[args.version],
                model_config[args.version]
                if args.variant == "stablediffusion"
                else model_variant[args.variant],
                subfolder="vae",
                revision=model_revision,
                revision=model_revision[args.variant],
            )

        def forward(self, input):

@@ -93,16 +119,27 @@ def get_base_vae_mlir(model_name="vae", extra_args=[]):
            return (x / 2 + 0.5).clamp(0, 1)

    vae = BaseVaeModel()
    if args.precision == "fp16":
        vae = vae.half().cuda()
        inputs = tuple(
            [
                inputs.half().cuda()
                for inputs in model_input[args.version]["vae"]
            ]
        )
    if args.variant == "stablediffusion":
        if args.precision == "fp16":
            vae = vae.half().cuda()
            inputs = tuple(
                [
                    inputs.half().cuda()
                    for inputs in model_input[args.version]["vae"]
                ]
            )
        else:
            inputs = model_input[args.version]["vae"]
    elif args.variant in ["anythingv3", "analogdiffusion"]:
        if args.precision == "fp16":
            vae = vae.half().cuda()
            inputs = tuple(
                [inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
            )
        else:
            inputs = model_input["v1_4"]["vae"]
    else:
        inputs = model_input[args.version]["vae"]
        raise (f"{args.variant} not yet added")

    shark_vae = compile_through_fx(
        vae,

@@ -118,9 +155,11 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
        def __init__(self):
            super().__init__()
            self.vae = AutoencoderKL.from_pretrained(
                model_config[args.version],
                model_config[args.version]
                if args.variant == "stablediffusion"
                else model_variant[args.variant],
                subfolder="vae",
                revision=model_revision,
                revision=model_revision[args.variant],
            )

        def forward(self, input):

@@ -131,16 +170,27 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
            return x.round()

    vae = VaeModel()
    if args.precision == "fp16":
        vae = vae.half().cuda()
        inputs = tuple(
            [
                inputs.half().cuda()
                for inputs in model_input[args.version]["vae"]
            ]
        )
    if args.variant == "stablediffusion":
        if args.precision == "fp16":
            vae = vae.half().cuda()
            inputs = tuple(
                [
                    inputs.half().cuda()
                    for inputs in model_input[args.version]["vae"]
                ]
            )
        else:
            inputs = model_input[args.version]["vae"]
    elif args.variant in ["anythingv3", "analogdiffusion"]:
        if args.precision == "fp16":
            vae = vae.half().cuda()
            inputs = tuple(
                [inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
            )
        else:
            inputs = model_input["v1_4"]["vae"]
    else:
        inputs = model_input[args.version]["vae"]
        raise (f"{args.variant} not yet added")

    shark_vae = compile_through_fx(
        vae,

@@ -156,9 +206,11 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
        def __init__(self):
            super().__init__()
            self.unet = UNet2DConditionModel.from_pretrained(
                model_config[args.version],
                model_config[args.version]
                if args.variant == "stablediffusion"
                else model_variant[args.variant],
                subfolder="unet",
                revision=model_revision,
                revision=model_revision[args.variant],
            )
            self.in_channels = self.unet.in_channels
            self.train(False)

@@ -176,16 +228,30 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
            return noise_pred

    unet = UnetModel()
    if args.precision == "fp16":
        unet = unet.half().cuda()
        inputs = tuple(
            [
                inputs.half().cuda() if len(inputs.shape) != 0 else inputs
                for inputs in model_input[args.version]["unet"]
            ]
        )
    if args.variant == "stablediffusion":
        if args.precision == "fp16":
            unet = unet.half().cuda()
            inputs = tuple(
                [
                    inputs.half().cuda() if len(inputs.shape) != 0 else inputs
                    for inputs in model_input[args.version]["unet"]
                ]
            )
        else:
            inputs = model_input[args.version]["unet"]
    elif args.variant in ["anythingv3", "analogdiffusion"]:
        if args.precision == "fp16":
            unet = unet.half().cuda()
            inputs = tuple(
                [
                    inputs.half().cuda() if len(inputs.shape) != 0 else inputs
                    for inputs in model_input["v1_4"]["unet"]
                ]
            )
        else:
            inputs = model_input["v1_4"]["unet"]
    else:
        inputs = model_input[args.version]["unet"]
        raise (f"{args.variant} is not yet added")
    shark_unet = compile_through_fx(
        unet,
        inputs,
@@ -97,12 +97,6 @@ p.add_argument(
    help="Download and use the tuned version of the model if available",
)

p.add_argument(
    "--variant",
    default="stablediffusion",
    help="We now support multiple variants of SD finetuned for different dataset. you can use the following anythingv3, ...",  # TODO add more once supported
)

p.add_argument(
    "--use_base_vae",
    default=False,

@@ -110,6 +104,12 @@ p.add_argument(
    help="Do conversion from the VAE output to pixel space on cpu.",
)

p.add_argument(
    "--variant",
    default="stablediffusion",
    help="We now support multiple vairants of SD finetuned for different dataset. you can use the following anythingv3, ...",  # TODO add more once supported
)

##############################################################################
### IREE - Vulkan supported flags
##############################################################################

@@ -145,6 +145,13 @@ p.add_argument(
### Misc. Debug and Optimization flags
##############################################################################

p.add_argument(
    "--use_compiled_scheduler",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="use the default scheduler precompiled into the model if available",
)

p.add_argument(
    "--local_tank_cache",
    default="",

@@ -179,11 +186,18 @@ p.add_argument(

p.add_argument(
    "--hide_steps",
    default=True,
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag for hiding the details of iteration/sec for each step.",
)

p.add_argument(
    "--warmup_count",
    type=int,
    default=0,
    help="flag setting warmup count for clip and vae [>= 0].",
)

##############################################################################
### Web UI flags
##############################################################################
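Worth noting for the boolean flags throughout these stable_args.py hunks: argparse.BooleanOptionalAction (standard library, Python 3.9+) generates a matching --no- prefixed form for every such flag, which is why the commented-out int8 path earlier suggests --no-import_mlir. A minimal sketch using the --hide_steps flag whose default this commit flips to False:

import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--hide_steps",
    default=False,
    action=argparse.BooleanOptionalAction,
)

# One declaration yields both spellings:
assert p.parse_args(["--hide_steps"]).hide_steps is True
assert p.parse_args(["--no-hide_steps"]).hide_steps is False
assert p.parse_args([]).hide_steps is False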
@@ -4,7 +4,10 @@ import torch
from shark.shark_inference import SharkInference
from models.stable_diffusion.stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
from shark.iree_utils.vulkan_utils import (
    set_iree_vulkan_runtime_flags,
    get_vulkan_target_triple,
)


def _compile_module(shark_module, model_name, extra_args=[]):

@@ -86,3 +89,102 @@ def set_iree_runtime_flags():
        set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)

    return


def set_init_device_flags():
    def get_all_devices(driver_name):
        """
        Inputs: driver_name
        Returns a list of all the available devices for a given driver sorted by
        the iree path names of the device as in --list_devices option in iree.
        Set `full_dict` flag to True to get a dict
        with `path`, `name` and `device_id` for all devices
        """
        from iree.runtime import get_driver

        driver = get_driver(driver_name)
        device_list_src = driver.query_available_devices()
        device_list_src.sort(key=lambda d: d["path"])
        return device_list_src

    def get_device_mapping(driver, key_combination=3):
        """This method ensures consistent device ordering when choosing
        specific devices for execution
        Args:
            driver (str): execution driver (vulkan, cuda, rocm, etc)
            key_combination (int, optional): choice for mapping value for device name.
                1 : path
                2 : name
                3 : (name, path)
                Defaults to 3.
        Returns:
            dict: map to possible device names user can input mapped to desired combination of name/path.
        """
        from shark.iree_utils._common import iree_device_map

        driver = iree_device_map(driver)
        device_list = get_all_devices(driver)
        device_map = dict()

        def get_output_value(dev_dict):
            if key_combination == 1:
                return f"{driver}://{dev_dict['path']}"
            if key_combination == 2:
                return dev_dict["name"]
            if key_combination == 3:
                return (dev_dict["name"], f"{driver}://{dev_dict['path']}")

        # mapping driver name to default device (driver://0)
        device_map[f"{driver}"] = get_output_value(device_list[0])
        for i, device in enumerate(device_list):
            # mapping with index
            device_map[f"{driver}://{i}"] = get_output_value(device)
            # mapping with full path
            device_map[f"{driver}://{device['path']}"] = get_output_value(
                device
            )
        return device_map

    def map_device_to_name_path(device, key_combination=3):
        """Gives the appropriate device data (supported name/path) for user selected execution device
        Args:
            device (str): user
            key_combination (int, optional): choice for mapping value for device name.
                1 : path
                2 : name
                3 : (name, path)
                Defaults to 3.
        Raises:
            ValueError:
        Returns:
            str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
        """
        driver = device.split("://")[0]
        device_map = get_device_mapping(driver, key_combination)
        try:
            device_mapping = device_map[device]
        except KeyError:
            raise ValueError(f"Device '{device}' is not a valid device.")
        return device_mapping

    if "vulkan" in args.device:
        name, args.device = map_device_to_name_path(args.device)
        triple = get_vulkan_target_triple(name)
        print(f"Found device {name}. Using target triple {triple}")
        # set triple flag to avoid multiple calls to get_vulkan_triple_flag
        if args.iree_vulkan_target_triple == "" and triple is not None:
            args.iree_vulkan_target_triple = triple

        # use tuned models only in the case of rdna3 cards.
        if not args.iree_vulkan_target_triple:
            if triple is not None and "rdna3" not in triple:
                args.use_tuned = False
        elif "rdna3" not in args.iree_vulkan_target_triple:
            args.use_tuned = False

        if args.use_tuned:
            print("Using tuned models for rdna3 card")
    else:
        if args.use_tuned:
            print("Tuned models not currently supported for device")
            args.use_tuned = False
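For reference, the shape of the alias map that get_device_mapping builds with the default key_combination=3: each device is reachable by bare driver name (resolving to device 0), by index, or by full path, and every alias maps to a (name, path) tuple that set_init_device_flags unpacks into (name, args.device). The adapter names and paths below are invented for illustration:

# Hypothetical result of get_device_mapping("vulkan") on a two-GPU machine.
device_map = {
    "vulkan": ("GPU-A", "vulkan://0000:03:00.0"),  # bare driver -> device 0
    "vulkan://0": ("GPU-A", "vulkan://0000:03:00.0"),  # by index
    "vulkan://0000:03:00.0": ("GPU-A", "vulkan://0000:03:00.0"),  # by path
    "vulkan://1": ("GPU-B", "vulkan://0000:04:00.0"),
    "vulkan://0000:04:00.0": ("GPU-B", "vulkan://0000:04:00.0"),
}

name, path = device_map["vulkan://1"]  # what map_device_to_name_path returns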