[SD][WEB] Avoid passing args to utils APIs

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-12-16 19:50:03 +05:30
parent a14a47af12
commit bba06d0142
8 changed files with 233 additions and 140 deletions

View File

@@ -18,7 +18,7 @@ p.add_argument(
p.add_argument(
"--negative-prompts",
nargs="+",
default=["trees, green"],
default=[""],
help="text you don't want to see in the generated image.",
)
@@ -55,7 +55,7 @@ p.add_argument(
##############################################################################
p.add_argument(
"--device", type=str, default="cpu", help="device to run the model."
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(

View File

@@ -98,7 +98,7 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
)
with gr.Row():
steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
guidance = gr.Slider(
guidance_scale = gr.Slider(
0,
50,
value=7.5,
@@ -156,7 +156,7 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
prompt,
negative_prompt,
steps,
guidance,
guidance_scale,
seed,
scheduler_key,
],
@@ -168,7 +168,7 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
prompt,
negative_prompt,
steps,
guidance,
guidance_scale,
seed,
scheduler_key,
],

View File

@@ -40,12 +40,8 @@ schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
subfolder="scheduler",
)
# set use_tuned
if "rdna3" not in get_vulkan_triple_flag():
args.use_tuned = False
# set iree-runtime flags
set_iree_runtime_flags(args)
set_iree_runtime_flags()
cache_obj = dict()
# cache vae, unet and clip.
@@ -53,7 +49,7 @@ cache_obj = dict()
cache_obj["vae"],
cache_obj["unet"],
cache_obj["clip"],
) = (get_vae(args), get_unet(args), get_clip(args))
) = (get_vae(), get_unet(), get_clip())
# cache tokenizer
cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(

View File

@@ -11,22 +11,19 @@ import numpy as np
import time
def set_ui_params(
prompt, negative_prompt, steps, guidance, seed, scheduler_key
):
def set_ui_params(prompt, negative_prompt, steps, guidance_scale, seed):
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.steps = steps
args.guidance = guidance
args.guidance_scale = guidance_scale
args.seed = seed
args.scheduler = scheduler_key
def stable_diff_inf(
prompt: str,
negative_prompt: str,
steps: int,
guidance: float,
guidance_scale: float,
seed: int,
scheduler_key: str,
):
@@ -37,14 +34,20 @@ def stable_diff_inf(
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
set_ui_params(
prompt, negative_prompt, steps, guidance, seed, scheduler_key
)
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
set_ui_params(prompt, negative_prompt, steps, guidance_scale, seed)
dtype = torch.float32 if args.precision == "fp32" else torch.half
generator = torch.manual_seed(
args.seed
) # Seed generator to create the initial latent noise
guidance_scale = torch.tensor(args.guidance).to(torch.float32)
# set height and width.
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
if args.version == "v2.1":
height = 768
width = 768
# Initialize vae and unet models.
vae, unet, clip, tokenizer = (
cache_obj["vae"],
@@ -52,7 +55,7 @@ def stable_diff_inf(
cache_obj["clip"],
cache_obj["tokenizer"],
)
scheduler = schedulers[args.scheduler]
scheduler = schedulers[scheduler_key]
start = time.time()
text_input = tokenizer(
@@ -84,7 +87,7 @@ def stable_diff_inf(
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
latents = torch.randn(
(1, 4, args.height // 8, args.width // 8),
(1, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
@@ -109,7 +112,7 @@ def stable_diff_inf(
latents_numpy,
timestep,
text_embeddings_numpy,
guidance_scale,
args.guidance_scale,
)
)
noise_pred = torch.from_numpy(noise_pred)
@@ -136,7 +139,7 @@ def stable_diff_inf(
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}, version={args.version}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
print(f"\nAverage step time: {avg_ms}ms/it")
text_output += "\nTotal image generation time: {0:.2f}sec".format(

View File

@@ -1,16 +1,17 @@
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from models.stable_diffusion.utils import compile_through_fx
from models.stable_diffusion.stable_args import args
import torch
model_config = {
"v2": "stabilityai/stable-diffusion-2",
"v2.1": "stabilityai/stable-diffusion-2-1",
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
"v1.4": "CompVis/stable-diffusion-v1-4",
}
model_input = {
"v2": {
"v2.1": {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 96, 96),),
"unet": (
@@ -42,13 +43,16 @@ model_input = {
},
}
# revision param for from_pretrained defaults to "main" => fp32
model_revision = "fp16" if args.precision == "fp16" else "main"
def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
def get_clip_mlir(model_name="clip_text", extra_args=[]):
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
if args.version == "v2":
if args.version != "v1.4":
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
)
@@ -63,7 +67,6 @@ def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
clip_model = CLIPText()
shark_clip = compile_through_fx(
args,
clip_model,
model_input[args.version]["clip"],
model_name=model_name,
@@ -72,10 +75,7 @@ def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
return shark_clip
def get_vae_mlir(args, model_name="vae", extra_args=[]):
# revision param for from_pretrained defaults to "main" => fp32
model_revision = "fp16" if args.precision == "fp16" else "main"
def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -102,7 +102,6 @@ def get_vae_mlir(args, model_name="vae", extra_args=[]):
inputs = model_input[args.version]["vae"]
shark_vae = compile_through_fx(
args,
vae,
inputs,
model_name=model_name,
@@ -111,7 +110,7 @@ def get_vae_mlir(args, model_name="vae", extra_args=[]):
return shark_vae
def get_vae_encode_mlir(args, model_name="vae_encode", extra_args=[]):
def get_vae_encode_mlir(model_name="vae_encode", extra_args=[]):
class VaeEncodeModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -131,7 +130,6 @@ def get_vae_encode_mlir(args, model_name="vae_encode", extra_args=[]):
[inputs.half().cuda() for inputs in model_input[args.version]["vae"]]
)
shark_vae = compile_through_fx(
args,
vae,
inputs,
model_name=model_name,
@@ -140,9 +138,7 @@ def get_vae_encode_mlir(args, model_name="vae_encode", extra_args=[]):
return shark_vae
def get_unet_mlir(args, model_name="unet", extra_args=[]):
model_revision = "fp16" if args.precision == "fp16" else "main"
def get_unet_mlir(model_name="unet", extra_args=[]):
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -178,7 +174,6 @@ def get_unet_mlir(args, model_name="unet", extra_args=[]):
else:
inputs = model_input[args.version]["unet"]
shark_unet = compile_through_fx(
args,
unet,
inputs,
model_name=model_name,

View File

@@ -5,15 +5,23 @@ from models.stable_diffusion.model_wrappers import (
get_unet_mlir,
get_clip_mlir,
)
from models.stable_diffusion.stable_args import args
from models.stable_diffusion.utils import get_shark_model
BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
sys.exit("Only batch size 1 is supported.")
def get_unet(args):
def get_unet():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
# Tuned model is present for `fp16` precision.
if args.precision == "fp16":
if args.use_tuned:
@@ -22,20 +30,22 @@ def get_unet(args):
model_name = "unet_1dec_fp16_tuned"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16_tuned"
return get_shark_model(args, bucket, model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_8dec_fp16"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16"
if args.version == "v2.1":
model_name = "unet2_14dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_unet_mlir(args, model_name, iree_flags)
return get_shark_model(args, bucket, model_name, iree_flags)
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
# Tuned model is not present for `fp32` case.
if args.precision == "fp32":
@@ -47,29 +57,49 @@ def get_unet(args):
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_unet_mlir(args, model_name, iree_flags)
return get_shark_model(args, bucket, model_name, iree_flags)
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "int8":
bucket = "gs://shark_tank/prashant_nod"
model_name = "unet_int8"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
]
sys.exit("int8 model is currently in maintenance.")
# # TODO: Pass iree_flags to the exported model.
# if args.import_mlir:
# sys.exit(
# "--import_mlir is not supported for the int8 model, try --no-import_mlir flag."
# )
# return get_shark_model(bucket, model_name, iree_flags)
def get_vae(args):
def get_vae():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
if args.precision == "fp16":
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.precision in ["fp16", "int8"]:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_8dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16"
if args.version == "v2.1":
model_name = "vae2_14dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_vae_mlir(args, model_name, iree_flags)
return get_shark_model(args, bucket, model_name, iree_flags)
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
@@ -80,23 +110,25 @@ def get_vae(args):
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_vae_mlir(args, model_name, iree_flags)
return get_shark_model(args, bucket, model_name, iree_flags)
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae_encode(args):
def get_vae_encode():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
if args.precision == "fp16":
if args.precision in ["fp16", "int8"]:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_encode_1dec_fp16"
if args.version == "v2":
model_name = "vae2_encode_29nov_fp16"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_vae_encode_mlir(model_name, iree_flags)
@@ -115,20 +147,25 @@ def get_vae_encode(args):
return get_shark_model(bucket, model_name, iree_flags)
def get_clip(args):
def get_clip():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
bucket = "gs://shark_tank/stable_diffusion"
model_name = "clip_8dec_fp32"
if args.version == "v2.1base":
model_name = "clip2base_8dec_fp32"
if args.version == "v2.1":
model_name = "clip2_14dec_fp32"
iree_flags += [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops",
]
if args.import_mlir:
return get_clip_mlir(args, model_name, iree_flags)
return get_shark_model(args, bucket, model_name, iree_flags)
return get_clip_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)

View File

@@ -4,6 +4,10 @@ p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Stable Diffusion Params
##############################################################################
p.add_argument(
"--prompts",
nargs="+",
@@ -14,14 +18,10 @@ p.add_argument(
p.add_argument(
"--negative-prompts",
nargs="+",
default=["trees, green"],
default=[""],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--steps",
type=int,
@@ -29,13 +29,6 @@ p.add_argument(
help="the no. of steps to do the sampling.",
)
p.add_argument(
"--version",
type=str,
default="v2.1base",
help="Specify version of stable diffusion model",
)
p.add_argument(
"--seed",
type=int,
@@ -43,20 +36,6 @@ p.add_argument(
help="the seed to use.",
)
p.add_argument(
"--height",
type=int,
default=512,
help="the height to use.",
)
p.add_argument(
"--width",
type=int,
default=512,
help="the width to use.",
)
p.add_argument(
"--guidance_scale",
type=float,
@@ -65,10 +44,29 @@ p.add_argument(
)
p.add_argument(
"--scheduler",
"--max_length",
type=int,
default=77,
help="max length of the tokenizer output.",
)
##############################################################################
### Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--version",
type=str,
default="EulerDiscrete",
help="can be [PNDM, LMSDiscrete, DDIM, DPMSolverMultistep, EulerDiscrete]",
default="v2.1base",
help="Specify version of stable diffusion model",
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(
@@ -79,23 +77,30 @@ p.add_argument(
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(
"--max_length",
type=int,
default=77,
help="max length of the tokenizer output.",
)
p.add_argument(
"--cache",
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
)
p.add_argument(
"--use_tuned",
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree-vulkan-target-triple",
type=str,
@@ -104,16 +109,10 @@ p.add_argument(
)
p.add_argument(
"--use_tuned",
default=True,
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
@@ -122,4 +121,46 @@ p.add_argument(
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
)
##############################################################################
### Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchmark data on. use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
)
args = p.parse_args()

View File

@@ -2,41 +2,43 @@ import os
import torch
from shark.shark_inference import SharkInference
from models.stable_diffusion.stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
def set_iree_runtime_flags(args):
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
]
if "vulkan" in args.device:
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
return
def _compile_module(args, shark_module, model_name, extra_args=[]):
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
if args.cache:
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
if os.path.isfile(vmfb_path):
print("Loading flatbuffer from {}".format(vmfb_path))
shark_module.load_module(vmfb_path)
return shark_module
print("No vmfb found. Compiling and saving to {}".format(vmfb_path))
path = shark_module.save_module(os.getcwd(), extended_name, extra_args)
shark_module.load_module(path)
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), extended_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(args, tank_url, model_name, extra_args=[]):
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_model
from shark.parser import shark_args
@@ -44,16 +46,18 @@ def get_shark_model(args, tank_url, model_name, extra_args=[]):
shark_args.local_tank_cache = args.local_tank_cache
mlir_model, func_name, inputs, golden_out = download_model(
model_name, tank_url=tank_url, frontend="torch"
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
)
return _compile_module(args, shark_module, model_name, extra_args)
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into shark_module.
def compile_through_fx(args, model, inputs, model_name, extra_args=[]):
# Converts the torch-module into a shark_module.
def compile_through_fx(model, inputs, model_name, extra_args=[]):
mlir_module, func_name = import_with_fx(model, inputs)
@@ -64,4 +68,21 @@ def compile_through_fx(args, model, inputs, model_name, extra_args=[]):
mlir_dialect="linalg",
)
return _compile_module(args, shark_module, model_name, extra_args)
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if "vulkan" in args.device:
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
return